yangql / composable_kernel-1 · Commit ccc4a1d3

Unverified commit ccc4a1d3, authored Aug 16, 2021 by Chao Liu, committed by GitHub on Aug 16, 2021

Merge pull request #8 from ROCmSoftwarePlatform/miopen_downstream_init_integration

Parents: 3b866461, 16effa76

Changes: 144 — showing 20 changed files with 1046 additions and 1159 deletions (+1046 −1159)
composable_kernel/include/tensor_description/tensor_adaptor.hpp                     +3   −5
composable_kernel/include/tensor_description/tensor_descriptor.hpp                  +53  −52
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp           +22  −23
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp            +45  −47
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp            +10  −10
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp              +13  −18
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp                +55  −59
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp      +36  −37
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp   +34  −36
composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp      +58  −63
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp             +119 −136
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp             +76  −97
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp               +90  −95
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp            +127 −151
composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp         +12  −13
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp             +3   −5
composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp          +8   −8
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp     +165 −177
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp  +107 −117
composable_kernel/include/tensor_operation/xdlops_gemm.hpp                          +10  −10
composable_kernel/include/tensor_description/tensor_adaptor.hpp  View file @ ccc4a1d3

@@ -2,8 +2,8 @@
 #define CK_TENSOR_ADAPTOR_HPP
 
 #include "common_header.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
 
 namespace ck {
@@ -454,9 +454,7 @@ __host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transf
         remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
 }
 
 template <typename X,
           typename... Xs,
-          typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
+          typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
 __host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
 {
     return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
...
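The hunk above names make_single_stage_tensor_adaptor and the variadic chain_tensor_adaptors overload. The following is a minimal usage sketch, not taken from the commit: it assumes this tree's headers are on the include path, and the merge/unmerge transforms, the Sequence<> dimension-id tuples and the example lengths are illustrative assumptions about the usual composable_kernel calling pattern.

// Hedged sketch (not from the diff): build two single-stage adaptors and chain them.
#include "tensor_adaptor.hpp"

namespace ck_example {
using namespace ck;

__host__ __device__ constexpr auto make_chained_adaptor_example()
{
    // stage 1: merge a (2, 3) two-dimensional top index into a single dimension
    const auto merge_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(Number<2>{}, Number<3>{}))),
        make_tuple(Sequence<0, 1>{}), // bottom dimension ids consumed (assumed layout)
        make_tuple(Sequence<0>{}));   // top dimension ids produced (assumed layout)

    // stage 2: split that dimension back into (3, 2)
    const auto unmerge_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_unmerge_transform(make_tuple(Number<3>{}, Number<2>{}))),
        make_tuple(Sequence<0>{}),
        make_tuple(Sequence<0, 1>{}));

    // the overload constrained in the hunk above handles 3+ adaptors by recursion;
    // the two-adaptor form simply composes stage 1 with stage 2
    return chain_tensor_adaptors(merge_adaptor, unmerge_adaptor);
}
} // namespace ck_example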
composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp → composable_kernel/include/tensor_description/tensor_descriptor.hpp  View file @ ccc4a1d3

-#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP
-#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP
+#ifndef CK_TENSOR_DESCRIPTOR_HPP
+#define CK_TENSOR_DESCRIPTOR_HPP
 
 #include "common_header.hpp"
-#include "dynamic_multi_index_transform.hpp"
+#include "multi_index_transform.hpp"
 
 namespace ck {
 
 template <index_t NDimHidden, typename VisibleDimensionIds>
-struct DynamicTensorCoordinate;
+struct TensorCoordinate;
 
 template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
-struct DynamicTensorCoordinateIterator;
+struct TensorCoordinateStep;
 
 // Transforms: Tuple<transforms...>
 // LowerDimensionIdss : Tuple<Sequence<...>, ...>
...
@@ -21,7 +21,7 @@ template <typename Transforms,
           typename UpperDimensionIdss,
           typename VisibleDimensionIds,
           typename ElementSpaceSize>
-struct DynamicTensorDescriptor
+struct TensorDescriptor
 {
     // TODO make these private
     __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }
...
@@ -105,16 +105,16 @@ struct DynamicTensorDescriptor
     using VisibleIndex = MultiIndex<ndim_visible_>;
     using HiddenIndex  = MultiIndex<ndim_hidden_>;
 
-    using Coordinate = DynamicTensorCoordinate<ndim_hidden_, VisibleDimensionIds>;
+    using Coordinate = TensorCoordinate<ndim_hidden_, VisibleDimensionIds>;
 
    // may be index_t or Number<>
    using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
 
    public:
-    __host__ __device__ constexpr DynamicTensorDescriptor() = default;
+    __host__ __device__ constexpr TensorDescriptor() = default;
 
-    __host__ __device__ constexpr DynamicTensorDescriptor(const Transforms& transforms,
-                                                          ElementSpaceSize element_space_size)
+    __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
+                                                   ElementSpaceSize element_space_size)
        : transforms_{transforms},
          element_size_{InitializeElementSize(transforms)},
          element_space_size_{element_space_size}
...
@@ -159,7 +159,7 @@ struct DynamicTensorDescriptor
    {
        static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension");
 
-        return make_dynamic_tensor_coordinate(*this, idx).GetOffset();
+        return make_tensor_coordinate(*this, idx).GetOffset();
    }
 
    // TODO make these private
...
@@ -196,7 +196,7 @@ struct DynamicTensorDescriptor
    __host__ __device__ void Print() const
    {
        printf("{");
-        printf("DynamicTensorDescriptor, ");
+        printf("TensorDescriptor, ");
        static_for<0, ntransform_, 1>{}([&](auto i) {
            printf("transforms: ");
            transforms_[i].Print();
...
@@ -217,7 +217,7 @@ struct DynamicTensorDescriptor
 };
 
 template <index_t NDimHidden, typename VisibleDimensionIds>
-struct DynamicTensorCoordinate
+struct TensorCoordinate
 {
    // TODO make these private
    static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size();
...
@@ -226,9 +226,9 @@ struct DynamicTensorCoordinate
    using VisibleIndex = MultiIndex<ndim_visible_>;
 
    public:
-    __host__ __device__ constexpr DynamicTensorCoordinate() = default;
+    __host__ __device__ constexpr TensorCoordinate() = default;
 
-    __host__ __device__ constexpr DynamicTensorCoordinate(const HiddenIndex& idx_hidden)
+    __host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden)
        : idx_hidden_{idx_hidden}
    {
    }
...
@@ -252,16 +252,16 @@ struct DynamicTensorCoordinate
 };
 
 template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
-struct DynamicTensorCoordinateIterator
+struct TensorCoordinateStep
 {
    // TODO make these private
    using VisibleIndex = MultiIndex<NDimVisible>;
 
    public:
-    __host__ __device__ constexpr DynamicTensorCoordinateIterator() = default;
+    __host__ __device__ constexpr TensorCoordinateStep() = default;
 
-    __host__ __device__ constexpr DynamicTensorCoordinateIterator(
-        const VisibleIndex& idx_diff_visible,
-        const MultiIndex<NTransform>& do_transforms)
+    __host__ __device__ constexpr TensorCoordinateStep(
+        const VisibleIndex& idx_diff_visible,
+        const MultiIndex<NTransform>& do_transforms)
        : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
    {
    }
...
@@ -283,7 +283,7 @@ struct DynamicTensorCoordinateIterator
 // TODO: How to fix this? It uses an struct instead of lambda because lambda
 // doesn't have constructor, and to put it outside the scope where it is used
-// (transform_dynamic_tensor_descriptor) because template cannot be defined inside a function
+// (transform_tensor_descriptor) because template cannot be defined inside a function
 // template
 template <typename NewTransforms>
 struct lambda_get_up_dim_num
...
@@ -301,10 +301,10 @@ template <typename OldTensorDescriptor,
           typename NewLowerDimensionOldVisibleIdss,
           typename NewUpperDimensionNewVisibleIdss>
 __host__ __device__ constexpr auto
-transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
-                                    const NewTransforms& new_transforms,
-                                    NewLowerDimensionOldVisibleIdss,
-                                    NewUpperDimensionNewVisibleIdss)
+transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
+                            const NewTransforms& new_transforms,
+                            NewLowerDimensionOldVisibleIdss,
+                            NewUpperDimensionNewVisibleIdss)
 {
    // sanity check
    {
...
@@ -376,17 +376,17 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
    const auto element_space_size = old_tensor_desc.GetElementSpaceSize();
 
-    return DynamicTensorDescriptor<remove_cv_t<decltype(all_transforms)>,
-                                   remove_cv_t<decltype(all_low_dim_hidden_idss)>,
-                                   remove_cv_t<decltype(all_up_dim_hidden_idss)>,
-                                   remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
-                                   remove_cv_t<decltype(element_space_size)>>{all_transforms,
-                                                                              element_space_size};
+    return TensorDescriptor<remove_cv_t<decltype(all_transforms)>,
+                            remove_cv_t<decltype(all_low_dim_hidden_idss)>,
+                            remove_cv_t<decltype(all_up_dim_hidden_idss)>,
+                            remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
+                            remove_cv_t<decltype(element_space_size)>>{all_transforms,
+                                                                       element_space_size};
 }
 
 template <typename TensorDesc, typename VisibleIndex>
-__host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDesc& tensor_desc,
-                                                                  const VisibleIndex& idx_visible)
+__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
+                                                          const VisibleIndex& idx_visible)
 {
    static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
                  "wrong! # of dimension inconsistent");
...
@@ -416,14 +416,15 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe
        set_container_subset(idx_hidden, dims_low, idx_low);
    });
 
-    return DynamicTensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
+    return TensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
 }
 
 // UpdateLowerIndexHack: Sequence<...>
 // HACK: control UpdateLowerIndex
 template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
-__host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator(
-    const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack)
+__host__ __device__ constexpr auto make_tensor_coordinate_step(
+    const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack)
 {
    static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
                  "wrong! # of dimension inconsistent");
...
@@ -470,23 +471,24 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator(
        set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
    });
 
-    return DynamicTensorCoordinateIterator<ntransform, ndim_visible, UpdateLowerIndexHack>{
-        idx_diff_visible, do_transforms};
+    return TensorCoordinateStep<ntransform, ndim_visible, UpdateLowerIndexHack>{idx_diff_visible,
+                                                                                do_transforms};
 }
 
 template <typename TensorDesc, typename VisibleIndex>
 __host__ __device__ constexpr auto
-make_dynamic_tensor_coordinate_iterator(const TensorDesc&, const VisibleIndex& idx_diff_visible)
+make_tensor_coordinate_step(const TensorDesc&, const VisibleIndex& idx_diff_visible)
 {
    constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
 
-    return make_dynamic_tensor_coordinate_iterator(
+    return make_tensor_coordinate_step(
        TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
 }
 
-template <typename TensorDesc, typename TensorCoord, typename TensorCoordIterator>
-__host__ __device__ constexpr void move_dynamic_tensor_coordinate(
-    const TensorDesc& tensor_desc, TensorCoord& coord, const TensorCoordIterator& coord_iterator)
+template <typename TensorDesc, typename TensorCoord, typename TensorCoordStep>
+__host__ __device__ constexpr void move_tensor_coordinate(
+    const TensorDesc& tensor_desc, TensorCoord& coord, const TensorCoordStep& coord_step)
 {
    constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
    constexpr index_t ntransform  = TensorDesc::GetNumOfTransform();
...
@@ -495,9 +497,8 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
    auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
 
    // initialize visible index diff
-    set_container_subset(idx_diff_hidden,
-                         TensorDesc::GetVisibleDimensionIds(),
-                         coord_iterator.GetVisibleIndexDiff());
+    set_container_subset(
+        idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());
 
    // this is what needs to be updated
    auto& idx_hidden = coord.GetHiddenIndex();
...
@@ -506,13 +507,13 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
    auto idx_hidden_pick_visible =
        get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());
 
-    idx_hidden_pick_visible += coord_iterator.GetIndexDiff();
+    idx_hidden_pick_visible += coord_step.GetIndexDiff();
 
    set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);
 
    // update rest of hidden index
    static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
-        if(coord_iterator.do_transforms_[itran])
+        if(coord_step.do_transforms_[itran])
        {
            const auto& tran        = tensor_desc.GetTransforms().At(itran);
            constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
...
@@ -524,8 +525,8 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
            MultiIndex<dims_low.Size()> idx_diff_low;
 
-            // HACK: control UpdateLowerIndex for DynamicMerge using hack
-            constexpr index_t Hack = decltype(coord_iterator.update_lower_index_hack_)::At(itran);
+            // HACK: control UpdateLowerIndex for Merge using hack
+            constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran);
 
            tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
...
@@ -585,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
 }
 
 template <typename TensorDesc>
-using DynamicTensorCoordinate_t = decltype(make_dynamic_tensor_coordinate(
+using TensorCoordinate_t = decltype(make_tensor_coordinate(
    TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
 
 template <typename TensorDesc>
-using DynamicTensorCoordinateIterator_t = decltype(make_dynamic_tensor_coordinate_iterator(
+using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
    TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
 
 } // namespace ck
...
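To make the direction of the renames above concrete, here is a small usage sketch of the coordinate API after this commit. It is not part of the diff: the 2x8 shape, the index values and the expected offset are illustration-only, and it assumes the renamed headers from this tree are on the include path.

// Hedged sketch (not from the diff): walk one column with the renamed coordinate API.
#include "tensor_descriptor_helper.hpp"

namespace ck_example {
using namespace ck;

__host__ __device__ inline index_t walk_one_column()
{
    // 2x8 row-major packed descriptor: offset(i, j) = i * 8 + j
    const auto desc = make_naive_tensor_descriptor_packed(make_tuple(Number<2>{}, Number<8>{}));

    // was: make_dynamic_tensor_coordinate / make_dynamic_tensor_coordinate_iterator /
    //      move_dynamic_tensor_coordinate
    auto coord      = make_tensor_coordinate(desc, make_multi_index(1, 0));
    const auto step = make_tensor_coordinate_step(desc, make_multi_index(0, 1));

    move_tensor_coordinate(desc, coord, step); // coord now points at (1, 1)

    return coord.GetOffset(); // 9 for this packed layout
}
} // namespace ck_example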
composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp → composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp  View file @ ccc4a1d3

-#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP
-#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP
+#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
+#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
 
 #include "common_header.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_multi_index_transform_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "multi_index_transform_helper.hpp"
 
 namespace ck {
...
@@ -37,10 +37,9 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
 template <typename... Lengths,
           typename... Strides,
-          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
+          typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
 __host__ __device__ constexpr auto
-make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
-                                        const Tuple<Strides...>& strides)
+make_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths, const Tuple<Strides...>& strides)
 {
    constexpr index_t N = sizeof...(Lengths);
...
@@ -75,12 +74,12 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
        calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
 #endif
 
-    return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>,
-                                   remove_cv_t<decltype(low_dim_hidden_idss)>,
-                                   remove_cv_t<decltype(up_dim_hidden_idss)>,
-                                   remove_cv_t<decltype(visible_dim_hidden_ids)>,
-                                   remove_cv_t<decltype(element_space_size)>>{transforms,
-                                                                              element_space_size};
+    return TensorDescriptor<remove_cv_t<decltype(transforms)>,
+                            remove_cv_t<decltype(low_dim_hidden_idss)>,
+                            remove_cv_t<decltype(up_dim_hidden_idss)>,
+                            remove_cv_t<decltype(visible_dim_hidden_ids)>,
+                            remove_cv_t<decltype(element_space_size)>>{transforms,
+                                                                       element_space_size};
 }
 
 // Lengths... can be:
...
@@ -88,7 +87,7 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
 // 2) Number<>, which is known at compile-time
 template <typename... Lengths>
 __host__ __device__ constexpr auto
-make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)
+make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
 {
    constexpr index_t N = sizeof...(Lengths);
...
@@ -103,17 +102,17 @@ make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)
    const auto element_space_size = container_reduce(lengths, math::multiplies_v2{}, Number<1>{});
 
-    return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>,
-                                   remove_cv_t<decltype(low_dim_hidden_idss)>,
-                                   remove_cv_t<decltype(up_dim_hidden_idss)>,
-                                   remove_cv_t<decltype(visible_dim_hidden_ids)>,
-                                   remove_cv_t<decltype(element_space_size)>>{transforms,
-                                                                              element_space_size};
+    return TensorDescriptor<remove_cv_t<decltype(transforms)>,
+                            remove_cv_t<decltype(low_dim_hidden_idss)>,
+                            remove_cv_t<decltype(up_dim_hidden_idss)>,
+                            remove_cv_t<decltype(visible_dim_hidden_ids)>,
+                            remove_cv_t<decltype(element_space_size)>>{transforms,
+                                                                       element_space_size};
 }
 
 template <typename... Lengths, typename Align>
 __host__ __device__ constexpr auto
-make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths, Align align)
+make_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths, Align align)
 {
    constexpr auto I1 = Number<1>{};
...
@@ -143,7 +142,7 @@ make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths
        },
        Number<N>{});
 
-    return make_dynamic_naive_tensor_descriptor_v2(lengths, strides);
+    return make_naive_tensor_descriptor_v2(lengths, strides);
 }
 
 } // namespace ck
...
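The three factory functions renamed in this file are the usual entry points for building descriptors. A minimal sketch of the post-rename spellings follows; it is not part of the commit, the shapes/strides/alignment are illustration values, and it assumes this tree's headers are on the include path.

// Hedged sketch (not from the diff): the renamed naive-descriptor factories.
#include "tensor_descriptor_helper.hpp"

namespace ck_example {
using namespace ck;

__host__ __device__ constexpr auto make_example_descriptors()
{
    // was make_dynamic_naive_tensor_descriptor_v2: explicit lengths + strides,
    // each entry either a run-time index_t or a compile-time Number<>
    const auto strided = make_naive_tensor_descriptor_v2(make_tuple(Number<4>{}, Number<8>{}),
                                                         make_tuple(Number<16>{}, Number<1>{}));

    // was make_dynamic_naive_tensor_descriptor_packed_v2: strides derived from the lengths
    const auto packed = make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}));

    // was make_dynamic_naive_tensor_descriptor_aligned_v2: strides derived with the given alignment
    const auto aligned =
        make_naive_tensor_descriptor_aligned_v2(make_tuple(Number<4>{}, Number<8>{}), Number<32>{});

    return make_tuple(strided, packed, aligned);
}
} // namespace ck_example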
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp  View file @ ccc4a1d3

@@ -3,7 +3,7 @@
 #include "common_header.hpp"
 #include "tensor_adaptor.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
 #include "threadwise_contraction_dlops.hpp"
 
 namespace ck {
...
@@ -22,24 +22,24 @@ namespace ck {
 // 2. CThreadBuffer is StaticBuffer
 // Also assume:
 //   M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
 [template-parameter list of BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 reordered/rewrapped
  (BlockSize, FloatA/FloatB/FloatC, AKMBlockDesc, BKNBlockDesc, M1PerThreadM11, N1PerThreadN11,
  KPerThread, M1N1ThreadClusterM100/N100/M101/N101, AThreadCopyScalarPerVector_M11,
  BThreadCopyScalarPerVector_N11); its IsKnownAtCompileTime() constraint changes from
  std::enable_if to enable_if]
 struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
 {
    using AIndex = MultiIndex<3>;
...
@@ -71,9 +71,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
    static constexpr index_t N0 = N / N1;
 
    __host__ __device__ static constexpr auto
-    MakeAKM0M1BlockDescriptor(const AKMBlockDesc& a_k_m_block_desc)
+    MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */)
    {
-        const auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
+        const auto a_k_m0_m1_block_desc = transform_tensor_descriptor(
            AKMBlockDesc{},
            make_tuple(make_pass_through_transform(Number<K>{}),
                       make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
...
@@ -84,9 +84,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
    }
 
    __host__ __device__ static constexpr auto
-    MakeBKN0N1BlockDescriptor(const BKNBlockDesc& b_k_n_block_desc)
+    MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */)
    {
-        const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
+        const auto b_k_n0_n1_block_desc = transform_tensor_descriptor(
            BKNBlockDesc{},
            make_tuple(make_pass_through_transform(Number<K>{}),
                       make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))),
...
@@ -194,7 +194,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename CThreadBuffer>
-    __device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc,
+    __device__ void Run(const CM0M1N0N1ThreadDesc& /* c_m0_m1_n0_n1_thread_desc */,
                        const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
...
@@ -357,34 +357,32 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
    private:
    // A[K, M0, M1]
-    static constexpr auto a_k_m0_m1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+    static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}));
 
    // B[K, N0, N1]
-    static constexpr auto b_k_n0_n1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+    static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}));
 
    using AThreadCopy =
-        ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
-                                                FloatA,
-                                                decltype(a_k_m0_m1_block_desc_),
-                                                decltype(a_k_m0_m1_thread_desc_),
-                                                Sequence<KPerThread, 1, M1PerThreadM11>,
-                                                Sequence<0, 1, 2>,
-                                                2,
-                                                AThreadCopyScalarPerVector_M11,
-                                                1>;
+        ThreadwiseTensorSliceTransfer_v4<FloatA,
+                                         FloatA,
+                                         decltype(a_k_m0_m1_block_desc_),
+                                         decltype(a_k_m0_m1_thread_desc_),
+                                         Sequence<KPerThread, 1, M1PerThreadM11>,
+                                         Sequence<0, 1, 2>,
+                                         2,
+                                         AThreadCopyScalarPerVector_M11,
+                                         1>;
 
    using BThreadCopy =
-        ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
-                                                FloatB,
-                                                decltype(b_k_n0_n1_block_desc_),
-                                                decltype(b_k_n0_n1_thread_desc_),
-                                                Sequence<KPerThread, 1, N1PerThreadN11>,
-                                                Sequence<0, 1, 2>,
-                                                2,
-                                                BThreadCopyScalarPerVector_N11,
-                                                1>;
+        ThreadwiseTensorSliceTransfer_v4<FloatB,
+                                         FloatB,
+                                         decltype(b_k_n0_n1_block_desc_),
+                                         decltype(b_k_n0_n1_thread_desc_),
+                                         Sequence<KPerThread, 1, N1PerThreadN11>,
+                                         Sequence<0, 1, 2>,
+                                         2,
+                                         BThreadCopyScalarPerVector_N11,
+                                         1>;
 
    CIndex c_thread_origin_data_idx_;
...
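MakeAKM0M1BlockDescriptor above is the typical place this commit's transform_tensor_descriptor rename shows up. The sketch below reproduces that K x M -> K x M0 x M1 unmerge pattern outside the class; it is not from the commit, and the dimension-id tuples (which the diff view truncates) follow the usual composable_kernel convention and are assumptions here.

// Hedged sketch (not from the diff): split the M dimension of a [K, M] descriptor into (M0, M1).
#include "tensor_descriptor_helper.hpp"

namespace ck_example {
using namespace ck;

template <index_t K, index_t M0, index_t M1>
__host__ __device__ constexpr auto make_k_m0_m1_desc()
{
    // [K, M] packed block descriptor, with M = M0 * M1
    constexpr auto a_k_m_desc =
        make_naive_tensor_descriptor_packed(make_tuple(Number<K>{}, Number<M0 * M1>{}));

    // was transform_dynamic_tensor_descriptor: keep K, unmerge M into (M0, M1)
    return transform_tensor_descriptor(
        a_k_m_desc,
        make_tuple(make_pass_through_transform(Number<K>{}),
                   make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
        make_tuple(Sequence<0>{}, Sequence<1>{}),     // old dims consumed by each transform (assumed)
        make_tuple(Sequence<0>{}, Sequence<1, 2>{})); // new dims produced by each transform (assumed)
}
} // namespace ck_example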
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp  View file @ ccc4a1d3

@@ -3,7 +3,7 @@
 #include "common_header.hpp"
 #include "tensor_adaptor.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
+#include "threadwise_tensor_slice_transfer_v2.hpp"
 #include "threadwise_contraction_dlops.hpp"
 
 namespace ck {
...
@@ -38,9 +38,9 @@ template <index_t BlockSize,
           // BM10BN10ThreadClusterBN101, ...>
           index_t AThreadCopyScalarPerVector_BM11,
           index_t BThreadCopyScalarPerVector_BN11,
-          typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
-                                      BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
-                                  bool>::type = false>
+          typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
+                                 BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
+                             bool>::type = false>
 struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
 {
    using AIndex = MultiIndex<3>;
...
@@ -75,7 +75,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
    __host__ __device__ static constexpr auto
    MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
    {
-        const auto a_block_bk0_bm0_bm1_bk1 = transform_dynamic_tensor_descriptor(
+        const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
            a_block_desc_bk0_bm_bk1,
            make_tuple(make_pass_through_transform(Number<BK0>{}),
                       make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
...
@@ -89,7 +89,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
    __host__ __device__ static constexpr auto
    MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
    {
-        const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_dynamic_tensor_descriptor(
+        const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
            b_block_desc_bk0_bn_bk1,
            make_tuple(make_pass_through_transform(Number<BK0>{}),
                       make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
...
@@ -372,15 +372,15 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
    private:
    // A[BK0, BM0, BM1, BK1]
    static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
-        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+        make_naive_tensor_descriptor_packed(make_tuple(
            Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{}));
 
    // B[BK0, BN0, BN1, BK1]
    static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
-        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+        make_naive_tensor_descriptor_packed(make_tuple(
            Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{}));
 
-    using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<FloatA,
-                                                                  FloatA,
-                                                                  decltype(a_block_desc_bk0_bm0_bm1_bk1_),
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<FloatA,
+                                                           FloatA,
+                                                           decltype(a_block_desc_bk0_bm0_bm1_bk1_),
...
@@ -390,7 +390,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
                       Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths
                       Sequence<0, 1, 2, 3>>;                 // SrcVectorTensorContiguousDimOrder
 
-    using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<FloatB,
-                                                                  FloatB,
-                                                                  decltype(b_block_desc_bk0_bn0_bn1_bk1_),
+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<FloatB,
+                                                           FloatB,
+                                                           decltype(b_block_desc_bk0_bn0_bn1_bk1_),
...
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp  View file @ ccc4a1d3

@@ -31,25 +31,24 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
    // HACK: fix this @Jing Zhang
    static constexpr index_t KPerThreadSubC = 4;
 
-    static constexpr auto a_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+    static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
 
-    static constexpr auto b_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+    static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
        Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
 
-    static constexpr auto c_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+    static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
        Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
 
-    using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
-                                                                FloatA,
-                                                                BlockMatrixA,
-                                                                decltype(a_thread_mtx_),
-                                                                Sequence<EPerThreadLoop, KPerThreadSubC>,
-                                                                Sequence<0, 1>,
-                                                                1,
-                                                                ThreadGemmADataPerRead_K,
-                                                                1>;
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
+                                                         FloatA,
+                                                         BlockMatrixA,
+                                                         decltype(a_thread_mtx_),
+                                                         Sequence<EPerThreadLoop, KPerThreadSubC>,
+                                                         Sequence<0, 1>,
+                                                         1,
+                                                         ThreadGemmADataPerRead_K,
+                                                         1>;
 
    __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
        : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
...
@@ -69,7 +68,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
                      "wrong! K dimension not consistent\n");
 
        constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed
-        constexpr index_t N = BlockMatrixB{}.GetLength(I1);
        constexpr index_t H = BlockMatrixB{}.GetLength(I2);
        constexpr index_t W = BlockMatrixB{}.GetLength(I3);
...
@@ -121,9 +119,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
                      "wrong! inconsistent type");
 
        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
 
        constexpr auto a_block_mtx = BlockMatrixA{};
...
@@ -138,7 +133,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
        static_assert(WPerThread % WoPerThreadSubC == 0, "");
 
        // thread A buffer for GEMM
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()>
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
            a_thread_buf;
 
        constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
...
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp  View file @ ccc4a1d3

@@ -2,7 +2,7 @@
 #define CK_BLOCKWISE_GEMM_XDLOPS_HPP
 
 #include "common_header.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
 #include "xdlops_gemm.hpp"
 
 namespace ck {
...
@@ -52,7 +52,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
        const index_t waveId   = thread_id / WaveSize;
        const index_t laneId   = thread_id % WaveSize;
        const index_t waveId_m = waveId / NWaves;
-        const index_t waveId_n = waveId % NWaves;
 
        if constexpr(xdlops_gemm.IsKReduction)
        {
...
@@ -73,7 +72,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
        const index_t thread_id = get_thread_local_1d_id();
        const index_t waveId    = thread_id / WaveSize;
        const index_t laneId    = thread_id % WaveSize;
-        const index_t waveId_m  = waveId / NWaves;
        const index_t waveId_n  = waveId % NWaves;
 
        if constexpr(xdlops_gemm.IsKReduction)
...
@@ -193,35 +191,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
    private:
    // A[K, M]
-    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
+    static constexpr auto a_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
 
    // B[K, N]
-    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
+    static constexpr auto b_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
 
-    static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
+    static constexpr auto c_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
 
-    using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
-                                                                FloatAB,
-                                                                ABlockDesc,
-                                                                decltype(a_thread_desc_),
-                                                                Sequence<1, MRepeat, 1, K1>,
-                                                                Sequence<0, 1, 2, 3>,
-                                                                3,
-                                                                K1,
-                                                                1>;
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
+                                                         FloatAB,
+                                                         ABlockDesc,
+                                                         decltype(a_thread_desc_),
+                                                         Sequence<1, MRepeat, 1, K1>,
+                                                         Sequence<0, 1, 2, 3>,
+                                                         3,
+                                                         K1,
+                                                         1>;
 
-    using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
-                                                                FloatAB,
-                                                                BBlockDesc,
-                                                                decltype(b_thread_desc_),
-                                                                Sequence<1, NRepeat, 1, K1>,
-                                                                Sequence<0, 1, 2, 3>,
-                                                                3,
-                                                                K1,
-                                                                1>;
+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
+                                                         FloatAB,
+                                                         BBlockDesc,
+                                                         decltype(b_thread_desc_),
+                                                         Sequence<1, NRepeat, 1, K1>,
+                                                         Sequence<0, 1, 2, 3>,
+                                                         3,
+                                                         K1,
+                                                         1>;
 
    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
...
@@ -272,7 +270,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
        const index_t waveId   = thread_id / WaveSize;
        const index_t laneId   = thread_id % WaveSize;
        const index_t waveId_m = waveId / NWaves;
-        const index_t waveId_n = waveId % NWaves;
 
        if constexpr(xdlops_gemm.IsKReduction)
        {
...
@@ -293,7 +290,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
        const index_t thread_id = get_thread_local_1d_id();
        const index_t waveId    = thread_id / WaveSize;
        const index_t laneId    = thread_id % WaveSize;
-        const index_t waveId_m  = waveId / NWaves;
        const index_t waveId_n  = waveId % NWaves;
 
        if constexpr(xdlops_gemm.IsKReduction)
...
@@ -490,35 +486,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
    [same renames as in the -193 hunk above, applied to the 2x2pipeline variant: the three
    make_dynamic_naive_tensor_descriptor_packed_v2 calls become make_naive_tensor_descriptor_packed,
    and the two ThreadwiseDynamicTensorSliceTransfer_v4 thread copies become
    ThreadwiseTensorSliceTransfer_v4 (here with slice lengths Sequence<1, 1, 1, K1> and
    scalar-per-vector 1, // K1,)]
 
    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
...
composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp → composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp  View file @ ccc4a1d3

-#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
-#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
+#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
+#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
 
 #include "common_header.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
 #include "cluster_descriptor.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
 
 namespace ck {
 
 // this version does following things to avoid scratch memory issue
 // 1. Use StaticallyIndexedArray instead of C array for thread buffer
-// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
-// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
+// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
+// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
 template <index_t BlockSize,
           InMemoryDataOperationEnum_t DstInMemOp,
           typename BlockSliceLengths,
...
@@ -33,16 +33,16 @@ template <index_t BlockSize,
           index_t DstScalarStrideInVector,
           bool ThreadTransferSrcResetCoordinateAfterRun,
           bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseDynamicTensorSliceTransfer_v4
+struct BlockwiseTensorSliceTransfer_v4
 {
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
 
    using Index = MultiIndex<nDim>;
 
-    __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4(const SrcDesc& src_desc,
-                                                                const Index& src_block_slice_origin,
-                                                                const DstDesc& dst_desc,
-                                                                const Index& dst_block_slice_origin)
+    __device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc,
+                                                         const Index& src_block_slice_origin,
+                                                         const DstDesc& dst_desc,
+                                                         const Index& dst_block_slice_origin)
        : threadwise_transfer_(
              src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
...
@@ -77,15 +77,14 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
        }
    }
 
-    template <typename SrcBuffer, typename SrcIteratorHacks>
-    __device__ void RunRead(const SrcDesc& src_desc,
-                            const SrcBuffer& src_buf,
-                            const SrcIteratorHacks& src_iterator_hacks)
+    template <typename SrcBuffer, typename SrcStepHacks>
+    __device__ void
+    RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
-            threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
+            threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
        }
    }
...
@@ -118,18 +117,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
        }
    }
 
-    // SrcMoveSliceWindowIteratorHack to control index calculation move slice window
-    template <typename SrcMoveSliceWindowIteratorHack>
+    // SrcMoveSliceWindowStepHack to control index calculation move slice window
+    template <typename SrcMoveSliceWindowStepHack>
    __device__ void
    MoveSrcSliceWindow(const SrcDesc& src_desc,
                       const Index& step,
-                       const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
+                       const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(
-                src_desc, step, src_move_slice_window_iterator_hack);
+                src_desc, step, src_move_slice_window_step_hack);
        }
    }
...
@@ -147,22 +146,22 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
        make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
 
    using ThreadwiseTransfer =
-        ThreadwiseDynamicTensorSliceTransfer_v3<ThreadSliceLengths,
-                                                DstInMemOp,
-                                                SrcData,
-                                                DstData,
-                                                SrcDesc,
-                                                DstDesc,
-                                                SrcDimAccessOrder,
-                                                DstDimAccessOrder,
-                                                SrcVectorDim,
-                                                DstVectorDim,
-                                                SrcScalarPerVector,
-                                                DstScalarPerVector,
-                                                SrcScalarStrideInVector,
-                                                DstScalarStrideInVector,
-                                                ThreadTransferSrcResetCoordinateAfterRun,
-                                                ThreadTransferDstResetCoordinateAfterRun>;
+        ThreadwiseTensorSliceTransfer_v3<ThreadSliceLengths,
+                                         DstInMemOp,
+                                         SrcData,
+                                         DstData,
+                                         SrcDesc,
+                                         DstDesc,
+                                         SrcDimAccessOrder,
+                                         DstDimAccessOrder,
+                                         SrcVectorDim,
+                                         DstVectorDim,
+                                         SrcScalarPerVector,
+                                         DstScalarPerVector,
+                                         SrcScalarStrideInVector,
+                                         DstScalarStrideInVector,
+                                         ThreadTransferSrcResetCoordinateAfterRun,
+                                         ThreadTransferDstResetCoordinateAfterRun>;
 
    ThreadwiseTransfer threadwise_transfer_;
 };
...
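Only the *IteratorHack -> *StepHack names change in this file's public surface; the call shape of RunRead and MoveSrcSliceWindow stays as shown above. The sketch below illustrates how a caller drives the renamed blockwise copy; it is not from the commit, and BlockwiseCopy stands for whatever BlockwiseTensorSliceTransfer_v4 instantiation the surrounding gridwise kernel already built.

// Hedged sketch (not from the diff): driving the renamed blockwise transfer.
namespace ck_example {

template <typename BlockwiseCopy,
          typename SrcDesc,
          typename SrcBuffer,
          typename SrcStepHacks,
          typename MoveStepHack,
          typename Index>
__device__ void load_and_advance(BlockwiseCopy& blockwise_copy,
                                 const SrcDesc& src_desc,
                                 const SrcBuffer& src_buf,
                                 const SrcStepHacks& src_step_hacks,
                                 const Index& slice_window_step,
                                 const MoveStepHack& move_step_hack)
{
    // was RunRead(..., src_iterator_hacks)
    blockwise_copy.RunRead(src_desc, src_buf, src_step_hacks);

    // was MoveSrcSliceWindow(..., src_move_slice_window_iterator_hack)
    blockwise_copy.MoveSrcSliceWindow(src_desc, slice_window_step, move_step_hack);
}

} // namespace ck_example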
composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer_v2.hpp → composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp  View file @ ccc4a1d3

-#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
-#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
+#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
+#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
 
 #include "common_header.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
 #include "cluster_descriptor.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
+#include "threadwise_tensor_slice_transfer_v2.hpp"
 
 namespace ck {
 
 // this version does following things to avoid scratch memory issue
 // 1. Use StaticallyIndexedArray instead of C array for thread buffer
-// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
-// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
+// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
+// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
 template <index_t BlockSize,
           InMemoryDataOperationEnum_t DstInMemOp,
           typename BlockSliceLengths,
...
@@ -31,17 +31,16 @@ template <index_t BlockSize,
           typename DstVectorTensorContiguousDimOrder,
           bool ThreadTransferSrcResetCoordinateAfterRun,
           bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseDynamicTensorSliceTransfer_v4r1
+struct BlockwiseTensorSliceTransfer_v4r1
 {
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
 
    using Index = MultiIndex<nDim>;
 
-    __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4r1(
+    __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(
        const SrcDesc& src_desc,
        const Index& src_block_slice_origin,
        const DstDesc& dst_desc,
        const Index& dst_block_slice_origin)
        : threadwise_transfer_(
              src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
...
@@ -76,15 +75,14 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
        }
    }
 
-    template <typename SrcBuffer, typename SrcIteratorHacks>
-    __device__ void RunRead(const SrcDesc& src_desc,
-                            const SrcBuffer& src_buf,
-                            const SrcIteratorHacks& src_iterator_hacks)
+    template <typename SrcBuffer, typename SrcStepHacks>
+    __device__ void
+    RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize
())
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
{
{
threadwise_transfer_
.
RunRead
(
src_desc
,
src_buf
,
src_
i
te
rator
_hacks
);
threadwise_transfer_
.
RunRead
(
src_desc
,
src_buf
,
src_
s
te
p
_hacks
);
}
}
}
}
...
@@ -107,18 +105,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
...
@@ -107,18 +105,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
}
}
}
}
// SrcMoveSliceWindow
I
te
rator
Hack to control index calculation move slice window
// SrcMoveSliceWindow
S
te
p
Hack to control index calculation move slice window
template
<
typename
SrcMoveSliceWindow
I
te
rator
Hack
>
template
<
typename
SrcMoveSliceWindow
S
te
p
Hack
>
__device__
void
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
step
,
const
Index
&
step
,
const
SrcMoveSliceWindow
I
te
rator
Hack
&
src_move_slice_window_
i
te
rator
_hack
)
const
SrcMoveSliceWindow
S
te
p
Hack
&
src_move_slice_window_
s
te
p
_hack
)
{
{
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
{
{
threadwise_transfer_
.
MoveSrcSliceWindow
(
threadwise_transfer_
.
MoveSrcSliceWindow
(
src_desc
,
step
,
src_move_slice_window_
i
te
rator
_hack
);
src_desc
,
step
,
src_move_slice_window_
s
te
p
_hack
);
}
}
}
}
...
@@ -136,20 +134,20 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
...
@@ -136,20 +134,20 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
make_cluster_descriptor_v2
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
make_cluster_descriptor_v2
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
using
ThreadwiseTransfer
=
Threadwise
Dynamic
TensorSliceTransfer_v3r1
<
ThreadSliceLengths
,
ThreadwiseTensorSliceTransfer_v3r1
<
ThreadSliceLengths
,
DstInMemOp
,
DstInMemOp
,
SrcData
,
SrcData
,
DstData
,
DstData
,
SrcDesc
,
SrcDesc
,
DstDesc
,
DstDesc
,
SrcDimAccessOrder
,
SrcDimAccessOrder
,
DstDimAccessOrder
,
DstDimAccessOrder
,
SrcVectorTensorLengths
,
SrcVectorTensorLengths
,
DstVectorTensorLengths
,
DstVectorTensorLengths
,
SrcVectorTensorContiguousDimOrder
,
SrcVectorTensorContiguousDimOrder
,
DstVectorTensorContiguousDimOrder
,
DstVectorTensorContiguousDimOrder
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferDstResetCoordinateAfterRun
>
;
ThreadTransferDstResetCoordinateAfterRun
>
;
ThreadwiseTransfer
threadwise_transfer_
;
ThreadwiseTransfer
threadwise_transfer_
;
};
};
...
...
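The struct above only decides which sub-tile of the block slice each thread copies; the per-element work is delegated to the threadwise transfer it wraps. Each thread's sub-tile origin comes from its position in the thread cluster, scaled by the per-thread slice lengths. A rough host-side analogue of that decomposition (not the CK implementation) is sketched below; the sizes are assumptions chosen for illustration: a 2D block slice of 8x128 elements, a 2x64 thread cluster, and a 4x2 per-thread slice.

#include <array>
#include <cstdio>

// Example numbers only: BlockSliceLengths = {8, 128}, ThreadClusterLengths = {2, 64},
// ThreadSliceLengths = BlockSliceLengths / ThreadClusterLengths = {4, 2}.
constexpr std::array<int, 2> kThreadClusterLengths{2, 64};
constexpr std::array<int, 2> kThreadSliceLengths{4, 2};

// Roughly what the cluster descriptor provides: map a flat thread id to a
// cluster coordinate (row-major here; the real code also applies
// ThreadClusterArrangeOrder), then scale it to a per-thread slice origin.
std::array<int, 2> thread_slice_origin(int thread_id)
{
    const int col = thread_id % kThreadClusterLengths[1];
    const int row = thread_id / kThreadClusterLengths[1];
    return {row * kThreadSliceLengths[0], col * kThreadSliceLengths[1]};
}

int main()
{
    for(int tid : {0, 1, 63, 64, 127})
    {
        const auto o = thread_slice_origin(tid);
        std::printf("thread %3d copies a %dx%d tile at (%d, %d)\n",
                    tid, kThreadSliceLengths[0], kThreadSliceLengths[1], o[0], o[1]);
    }
}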
composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_dlops_v1r2.hpp → composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp  View file @ ccc4a1d3

-#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
-#define CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
+#ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
+#define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP

 #include "common_header.hpp"
-#include "dynamic_multi_index_transform_helper.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "multi_index_transform_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_dlops_v2r3.hpp"
-#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer.hpp"
-#include "threadwise_dynamic_tensor_slice_set.hpp"
+#include "blockwise_tensor_slice_transfer_v2.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_set.hpp"

 namespace ck {
...
@@ -25,7 +25,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_dynamic_contraction_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
-                                          const FloatAB* __restrict__ p_b_grid,
-                                          FloatC* __restrict__ p_c_grid,
+    kernel_contraction_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
+                                  const FloatAB* __restrict__ p_b_grid,
+                                  FloatC* __restrict__ p_c_grid,
...
@@ -84,12 +84,12 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector,
-          typename AGridIteratorHacks,
-          typename BGridIteratorHacks,
-          typename CGridIteratorHacks,
-          typename AGridMoveSliceWindowIteratorHacks,
-          typename BGridMoveSliceWindowIteratorHacks>
-struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
+          typename AGridStepHacks,
+          typename BGridStepHacks,
+          typename CGridStepHacks,
+          typename AGridMoveSliceWindowStepHacks,
+          typename BGridMoveSliceWindowStepHacks>
+struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
@@ -110,17 +110,15 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 =
-            make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_naive_tensor_descriptor_aligned_v2(
                 make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
                 max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 =
-            make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_naive_tensor_descriptor_aligned_v2(
                 make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
                 max_lds_align);

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
...
@@ -201,7 +199,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         const auto GM11 = Number<GM1PerBlockGM11>{};
         const auto GM10 = GM1 / GM11;

-        const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_dynamic_tensor_descriptor(
+        const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_tensor_descriptor(
             a_grid_desc_gk0_gm0_gm1_gk1,
             make_tuple(make_pass_through_transform(GK0),
                        make_pass_through_transform(GM0),
...
@@ -222,7 +220,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         const auto GN11 = Number<GN1PerBlockGN11>{};
         const auto GN10 = GN1 / GN11;

-        const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_dynamic_tensor_descriptor(
+        const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_tensor_descriptor(
             b_grid_desc_gk0_gn0_gn1_gk1,
             make_tuple(make_pass_through_transform(GK0),
                        make_pass_through_transform(GN0),
...
@@ -259,7 +257,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         constexpr auto BM0 = BM / BM1;
         constexpr auto BN0 = BN / BN1;

-        const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_tensor_descriptor(
             c_grid_desc_gm0_gm1_gn0_gn1,
             make_tuple(make_pass_through_transform(GM0),
                        make_unmerge_transform(make_tuple(GM10, GM11)),
...
@@ -268,7 +266,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));

-        const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto c_gm10_bm_gn10_bn_grid_desc = transform_tensor_descriptor(
             c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
             make_tuple(make_pass_through_transform(GM10),
                        make_merge_transform(make_tuple(GM0, GM11)),
...
@@ -277,7 +275,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
             make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

-        const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_dynamic_tensor_descriptor(
+        const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_tensor_descriptor(
             c_gm10_bm_gn10_bn_grid_desc,
             make_tuple(make_pass_through_transform(GM10),
                        make_unmerge_transform(make_tuple(BM0, BM1)),
...
@@ -356,26 +354,24 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 =
-            make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_naive_tensor_descriptor_aligned_v2(
                 make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
                 max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 =
-            make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_naive_tensor_descriptor_aligned_v2(
                 make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
                 max_lds_align);

         // A matrix in LDS memory for blockwise GEMM
         // be careful of LDS alignment
-        constexpr auto a_block_desc_gk0_bm_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);

         // B matrix in LDS memory for blockwise GEMM
         // be careful of LDS alignment
-        constexpr auto b_block_desc_gk0_bn_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);

         static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() ==
...
@@ -385,7 +381,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
             "wrong!");

         // A matrix blockwise copy
-        auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
+        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
             BlockSize,
             InMemoryDataOperationEnum_t::Set,
             Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
...
@@ -409,7 +405,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
             make_multi_index(0, 0, 0, 0, 0));

         // B matrix blockwise copy
-        auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
+        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
             BlockSize,
             InMemoryDataOperationEnum_t::Set,
             Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
...
@@ -457,9 +453,8 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 =
             decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();

-        constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 =
-            make_dynamic_naive_tensor_descriptor_packed_v2(
-                sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));
+        constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed(
+            sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
...
@@ -475,9 +470,9 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
             c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());

-        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
-                                           decltype(c_thread_desc_bm0_bm1_bn0_bn1),
-                                           decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
+        ThreadwiseTensorSliceSet_v1<FloatAcc,
+                                    decltype(c_thread_desc_bm0_bm1_bn0_bn1),
+                                    decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
             .Run(c_thread_desc_bm0_bm1_bn0_bn1,
                  make_tuple(I0, I0, I0, I0),
                  c_thread_buf,
...
@@ -501,9 +496,9 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(
-                a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
-            b_blockwise_copy.RunRead(
-                b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
+            a_blockwise_copy.RunRead(
+                a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
+            b_blockwise_copy.RunRead(
+                b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});

             a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
             b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
...
@@ -520,18 +515,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
             // even iteration
             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
                                                 a_block_slice_copy_step,
-                                                AGridMoveSliceWindowIteratorHacks{});
+                                                AGridMoveSliceWindowStepHacks{});
             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
                                                 b_block_slice_copy_step,
-                                                BGridMoveSliceWindowIteratorHacks{});
+                                                BGridMoveSliceWindowStepHacks{});

             __syncthreads();

             // LDS doubel buffer: load next data from device mem
-            a_blockwise_copy.RunRead(
-                a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
-            b_blockwise_copy.RunRead(
-                b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
+            a_blockwise_copy.RunRead(
+                a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
+            b_blockwise_copy.RunRead(
+                b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
...
@@ -546,18 +541,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
             // odd iteration
             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
                                                 a_block_slice_copy_step,
-                                                AGridMoveSliceWindowIteratorHacks{});
+                                                AGridMoveSliceWindowStepHacks{});
             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
                                                 b_block_slice_copy_step,
-                                                BGridMoveSliceWindowIteratorHacks{});
+                                                BGridMoveSliceWindowStepHacks{});

             __syncthreads();

             // LDS doubel buffer: load next data from device mem
-            a_blockwise_copy.RunRead(
-                a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
-            b_blockwise_copy.RunRead(
-                b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
+            a_blockwise_copy.RunRead(
+                a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
+            b_blockwise_copy.RunRead(
+                b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(
...
@@ -576,18 +571,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         {
             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
                                                 a_block_slice_copy_step,
-                                                AGridMoveSliceWindowIteratorHacks{});
+                                                AGridMoveSliceWindowStepHacks{});
             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
                                                 b_block_slice_copy_step,
-                                                BGridMoveSliceWindowIteratorHacks{});
+                                                BGridMoveSliceWindowStepHacks{});

             __syncthreads();

             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(
-                a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
-            b_blockwise_copy.RunRead(
-                b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
+            a_blockwise_copy.RunRead(
+                a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
+            b_blockwise_copy.RunRead(
+                b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});

             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(
...
@@ -615,7 +610,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         // output: register to global memory
         {
             constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
-                make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+                make_naive_tensor_descriptor_packed(make_tuple(
                     I1,
                     Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{},
                     Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{},
...
@@ -627,7 +622,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
                 blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
                     get_thread_local_1d_id());

-            ThreadwiseDynamicTensorSliceTransfer_v1r3<
+            ThreadwiseTensorSliceTransfer_v1r3<
                 FloatAcc,
                 FloatC,
                 decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1),
...
@@ -655,7 +650,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
                      c_thread_buf,
                      c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
                      c_grid_buf,
-                     CGridIteratorHacks{});
+                     CGridStepHacks{});
         }
     }
 };
...
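The "LDS double buffer" comments in the hunks above describe a ping-pong pipeline: while the blockwise GEMM consumes the tile already resident in one LDS buffer, the blockwise copies move the slice window and fetch the next K-slice into the other buffer, with a tail that handles the last two slices outside the loop. A simplified host-side analogue of that even/odd structure is sketched below; it is not the kernel itself, and the sizes (K = 32, KPerBlock = 4) are assumptions for illustration.

#include <cstdio>
#include <vector>

int main()
{
    constexpr int KPerBlock = 4;
    constexpr int K         = 32; // assumed a multiple of 2 * KPerBlock

    std::vector<int> even_buf(KPerBlock), odd_buf(KPerBlock);

    auto load = [](std::vector<int>& buf, int k0) {            // stands in for RunRead + RunWrite
        for(int i = 0; i < KPerBlock; ++i) buf[i] = k0 + i;
    };
    auto gemm = [](const std::vector<int>& buf) {              // stands in for blockwise_gemm.Run
        long acc = 0;
        for(int v : buf) acc += v;
        return acc;
    };

    long c = 0;
    load(even_buf, 0);                  // preload into LDS (even buffer)

    for(int k0 = 0; k0 + 2 * KPerBlock < K; k0 += 2 * KPerBlock)
    {
        load(odd_buf, k0 + KPerBlock);  // even iteration: move window, load next
        c += gemm(even_buf);            // GEMM on current (even) data
        load(even_buf, k0 + 2 * KPerBlock); // odd iteration: load next
        c += gemm(odd_buf);             // GEMM on current (odd) data
    }

    load(odd_buf, K - KPerBlock);       // tail: load last slice
    c += gemm(even_buf);                // GEMM on 2nd-last data
    c += gemm(odd_buf);                 // GEMM on last data

    std::printf("checksum over K = %d: %ld\n", K, c);
}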
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r2.hpp → composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp  View file @ ccc4a1d3

-#ifndef CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP
-#define CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP
+#ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
+#define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP

 #include "common_header.hpp"
-#include "dynamic_multi_index_transform_helper.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "multi_index_transform_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_dlops_v2r2.hpp"
-#include "blockwise_dynamic_tensor_slice_transfer.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer.hpp"
-#include "threadwise_dynamic_tensor_slice_set.hpp"
+#include "blockwise_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_set.hpp"

 namespace ck {
...
@@ -26,7 +26,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_dynamic_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
-                                   const FloatAB* __restrict__ p_b_grid,
-                                   FloatC* __restrict__ p_c_grid,
+    kernel_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
+                           const FloatAB* __restrict__ p_b_grid,
+                           FloatC* __restrict__ p_c_grid,
...
@@ -68,28 +68,27 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_dynamic_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
+    kernel_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
                            const FloatAB* __restrict__ p_b_grid,
                            FloatC* __restrict__ p_c_grid,
                            const void CONSTANT* p_a_k_m0_m1_grid_desc,
                            const void CONSTANT* p_b_k_n0_n1_grid_desc,
                            const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
                            const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
 {
     // first cast void CONSTANT void* to void*
     // second cast void* to Desc*
     // the copy constructor of tensor descriptor doesn't take address_space(4)
-    const auto a_k_m0_m1_grid_desc =
-        *reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc);
-    const auto b_k_n0_n1_grid_desc =
-        *reinterpret_cast<const BKN0N1GridDesc*>((const void*)p_b_k_n0_n1_grid_desc);
-    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
-        *reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
-            (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
-    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
-        *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
-            (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
+    const auto a_k_m0_m1_grid_desc = *reinterpret_cast<const AKM0M1GridDesc*>(
+        cast_pointer_to_generic_address_space(p_a_k_m0_m1_grid_desc));
+    const auto b_k_n0_n1_grid_desc = *reinterpret_cast<const BKN0N1GridDesc*>(
+        cast_pointer_to_generic_address_space(p_b_k_n0_n1_grid_desc));
+    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
+        *reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
+            cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
+    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
+        *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
+            cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));

     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...
@@ -146,12 +145,12 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector,
-          typename AGridIteratorHacks,
-          typename BGridIteratorHacks,
-          typename CGridIteratorHacks,
-          typename AGridMoveSliceWindowIteratorHacks,
-          typename BGridMoveSliceWindowIteratorHacks>
-struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
+          typename AGridStepHacks,
+          typename BGridStepHacks,
+          typename CGridStepHacks,
+          typename AGridMoveSliceWindowStepHacks,
+          typename BGridMoveSliceWindowStepHacks>
+struct GridwiseGemmDlops_km_kn_mn_v1r2
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
@@ -167,12 +166,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);

         // LDS allocation for A and B: be careful of alignment
...
@@ -230,7 +229,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         const auto M1 = Number<MPerBlockM1>{};
         const auto M0 = M / M1;

-        const auto a_k_m0_m1_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto a_k_m0_m1_grid_desc = transform_tensor_descriptor(
             a_k_m_grid_desc,
             make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
...
@@ -248,7 +247,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         const auto N1 = Number<NPerBlockN1>{};
         const auto N0 = N / N1;

-        const auto b_k_n0_n1_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto b_k_n0_n1_grid_desc = transform_tensor_descriptor(
             b_k_n_grid_desc,
             make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
...
@@ -277,7 +276,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         constexpr auto M10 = M1 / M11;
         constexpr auto N10 = N1 / N11;

-        const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
             c_m_n_grid_desc,
             make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
                        make_unmerge_transform(make_tuple(N0, N10, N11))),
...
@@ -352,75 +351,75 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);

         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_k_m0_m1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_k_n0_n1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align);

         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
+            BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
                                             Sequence<KPerBlock, 1, MPerBlockM1>,
                                             ABlockTransferThreadSliceLengths_K_M0_M1,
                                             ABlockTransferThreadClusterLengths_K_M0_M1,
                                             ABlockTransferThreadClusterArrangeOrder,
                                             FloatAB,
                                             FloatAB,
                                             decltype(a_k_m0_m1_grid_desc),
                                             decltype(a_k_m0_m1_block_desc),
                                             ABlockTransferSrcAccessOrder,
                                             Sequence<0, 1, 2>,
                                             ABlockTransferSrcVectorDim,
                                             2,
                                             ABlockTransferSrcScalarPerVector,
                                             ABlockTransferDstScalarPerVector_M1,
                                             1,
                                             1,
                                             AThreadTransferSrcResetCoordinateAfterRun,
                                             true>(a_k_m0_m1_grid_desc,
                                                   make_multi_index(0, im0, 0),
                                                   a_k_m0_m1_block_desc,
                                                   make_multi_index(0, 0, 0));

         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
+            BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
                                             Sequence<KPerBlock, 1, NPerBlockN1>,
                                             BBlockTransferThreadSliceLengths_K_N0_N1,
                                             BBlockTransferThreadClusterLengths_K_N0_N1,
                                             BBlockTransferThreadClusterArrangeOrder,
                                             FloatAB,
                                             FloatAB,
                                             decltype(b_k_n0_n1_grid_desc),
                                             decltype(b_k_n0_n1_block_desc),
                                             BBlockTransferSrcAccessOrder,
                                             Sequence<0, 1, 2>,
                                             BBlockTransferSrcVectorDim,
                                             2,
                                             BBlockTransferSrcScalarPerVector,
                                             BBlockTransferDstScalarPerVector_N1,
                                             1,
                                             1,
                                             BThreadTransferSrcResetCoordinateAfterRun,
                                             true>(b_k_n0_n1_grid_desc,
                                                   make_multi_index(0, in0, 0),
                                                   b_k_n0_n1_block_desc,
                                                   make_multi_index(0, 0, 0));

         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
...
@@ -447,9 +446,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
             decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();

-        constexpr auto c_m10_m11_n10_n11_thread_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(
-                sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
+        constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
+            sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_aligned_space_size =
...
@@ -465,9 +463,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
             c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());

-        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
-                                           decltype(c_m10_m11_n10_n11_thread_desc),
-                                           decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
+        ThreadwiseTensorSliceSet_v1<FloatAcc,
+                                    decltype(c_m10_m11_n10_n11_thread_desc),
+                                    decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
             .Run(c_m10_m11_n10_n11_thread_desc,
                  make_tuple(I0, I0, I0, I0),
                  c_thread_buf,
...
@@ -477,15 +475,15 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);

         // hack to control index calculation when iterating over A and B matrix for threadwise copy
-        constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{};
-        constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{};
+        constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{};
+        constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{};

         // hack to control index calculation when move slice window for A and B matrix for
         // threadwise copy
-        constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack =
-            AGridMoveSliceWindowIteratorHacks{};
-        constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack =
-            BGridMoveSliceWindowIteratorHacks{};
+        constexpr auto a_k_m0_m1_global_move_slice_window_step_hack =
+            AGridMoveSliceWindowStepHacks{};
+        constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
+            BGridMoveSliceWindowStepHacks{};

         auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
             p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
...
@@ -502,9 +500,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(
-                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
-            b_blockwise_copy.RunRead(
-                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
+            a_blockwise_copy.RunRead(
+                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
+            b_blockwise_copy.RunRead(
+                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);

             a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
             b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
...
@@ -519,22 +517,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         do
         {
             // even iteration
             a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
                                                 a_block_slice_copy_step,
-                                                a_k_m0_m1_global_move_slice_window_iterator_hack);
-            b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
-                                                b_block_slice_copy_step,
-                                                b_k_n0_n1_global_move_slice_window_iterator_hack);
+                                                a_k_m0_m1_global_move_slice_window_step_hack);
+            b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
+                                                b_block_slice_copy_step,
+                                                b_k_n0_n1_global_move_slice_window_step_hack);

             __syncthreads();

             // LDS doubel buffer: load next data from device mem
-            a_blockwise_copy.RunRead(
-                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
-            b_blockwise_copy.RunRead(
-                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
+            a_blockwise_copy.RunRead(
+                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
+            b_blockwise_copy.RunRead(
+                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
...
@@ -547,22 +543,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
             b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);

             // odd iteration
             a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
                                                 a_block_slice_copy_step,
-                                                a_k_m0_m1_global_move_slice_window_iterator_hack);
-            b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
-                                                b_block_slice_copy_step,
-                                                b_k_n0_n1_global_move_slice_window_iterator_hack);
+                                                a_k_m0_m1_global_move_slice_window_step_hack);
+            b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
+                                                b_block_slice_copy_step,
+                                                b_k_n0_n1_global_move_slice_window_step_hack);

             __syncthreads();

             // LDS doubel buffer: load next data from device mem
-            a_blockwise_copy.RunRead(
-                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
-            b_blockwise_copy.RunRead(
-                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
+            a_blockwise_copy.RunRead(
+                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
+            b_blockwise_copy.RunRead(
+                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(
...
@@ -581,18 +575,18 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         {
             a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
                                                 a_block_slice_copy_step,
-                                                a_k_m0_m1_global_move_slice_window_iterator_hack);
+                                                a_k_m0_m1_global_move_slice_window_step_hack);
             b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
                                                 b_block_slice_copy_step,
-                                                b_k_n0_n1_global_move_slice_window_iterator_hack);
+                                                b_k_n0_n1_global_move_slice_window_step_hack);

             __syncthreads();

             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(
-                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
-            b_blockwise_copy.RunRead(
-                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
+            a_blockwise_copy.RunRead(
+                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
+            b_blockwise_copy.RunRead(
+                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);

             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(
...
@@ -619,19 +613,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         // output: register to global memory
         {
-            constexpr index_t M11 =
-                M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
-            constexpr index_t N11 =
-                N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;
-
-            constexpr index_t M10 = MPerBlockM1 / M11;
-            constexpr index_t N10 = NPerBlockN1 / N11;
-
-            constexpr index_t M111 = M1PerThreadM111;
-            constexpr index_t N111 = N1PerThreadN111;
-
             constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
-                make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+                make_naive_tensor_descriptor_packed(make_tuple(
                     I1,
                     Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
                     Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
...
@@ -642,7 +625,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
             const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
                 blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());

-            ThreadwiseDynamicTensorSliceTransfer_v1r3<
+            ThreadwiseTensorSliceTransfer_v1r3<
                 FloatAcc,
                 FloatC,
                 decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
...
@@ -670,7 +653,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
                      c_thread_buf,
                      c_m0_m10_m11_n0_n10_n11_grid_desc,
                      c_grid_buf,
-                     CGridIteratorHacks{});
+                     CGridStepHacks{});
         }
     }
 };
...
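Several of the hunks above only rename transform_dynamic_tensor_descriptor to transform_tensor_descriptor; the transforms themselves (pass-through, merge, unmerge) are unchanged. As a reminder of what an unmerge such as make_unmerge_transform(make_tuple(M0, M10, M11)) describes, the sketch below shows the underlying index arithmetic in standalone C++. It is an illustration of the math only, not CK code, and the lengths M0, M10, M11 are arbitrary example values.

#include <array>
#include <cassert>
#include <cstdio>

// An unmerge transform splits one "merged" index m into coordinates
// (m0, m10, m11) with lengths (M0, M10, M11), where
//   m = (m0 * M10 + m10) * M11 + m11.
// Lengths below are examples, not values taken from the kernels above.
constexpr int M0 = 4, M10 = 2, M11 = 8;

std::array<int, 3> unmerge(int m)
{
    const int m11 = m % M11;
    const int m10 = (m / M11) % M10;
    const int m0  = m / (M11 * M10);
    return {m0, m10, m11};
}

// Inverse mapping: what a merge transform computes from the coordinates.
int merge(const std::array<int, 3>& c)
{
    return (c[0] * M10 + c[1]) * M11 + c[2];
}

int main()
{
    for(int m = 0; m < M0 * M10 * M11; ++m)
        assert(merge(unmerge(m)) == m); // round trip over the whole index range

    const auto c = unmerge(37);
    std::printf("m = 37 -> (m0, m10, m11) = (%d, %d, %d)\n", c[0], c[1], c[2]);
}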
composable_kernel/include/tensor_operation/gridwise_
dynamic_
gemm_dlops_v1r3.hpp
→
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp
View file @
ccc4a1d3
#ifndef CK_GRIDWISE_
DYNAMIC_
GEMM_V1R3_HPP
#ifndef CK_GRIDWISE_GEMM_V1R3_HPP
#define CK_GRIDWISE_
DYNAMIC_
GEMM_V1R3_HPP
#define CK_GRIDWISE_GEMM_V1R3_HPP
#include "common_header.hpp"
#include "common_header.hpp"
#include "
dynamic_
multi_index_transform_helper.hpp"
#include "multi_index_transform_helper.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_
dynamic_
tensor_slice_transfer_v2.hpp"
#include "blockwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_
dynamic_
tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_
dynamic_
tensor_slice_set.hpp"
#include "threadwise_tensor_slice_set.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -26,7 +26,7 @@ __global__ void
...
@@ -26,7 +26,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_
dynamic_
gemm_dlops_v1r3
(
kernel_gemm_dlops_v1r3
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
...
@@ -68,28 +68,27 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_dynamic_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid,
+    kernel_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid,
                            const FloatAB* __restrict__ p_b_grid,
                            FloatC* __restrict__ p_c_grid,
                            const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc,
                            const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc,
                            const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
                            const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
 {
     // first cast void CONSTANT void* to void*
     // second cast void* to Desc*
     // the copy constructor of tensor descriptor doesn't take address_space(4)
-    const auto a_k0_m0_m1_k1_grid_desc = *reinterpret_cast<const AK0M0M1K1GridDesc*>(
-        (const void*)p_a_k0_m0_m1_k1_grid_desc);
+    const auto a_k0_m0_m1_k1_grid_desc = *reinterpret_cast<const AK0M0M1K1GridDesc*>(
+        cast_pointer_to_generic_address_space(p_a_k0_m0_m1_k1_grid_desc));
-    const auto b_k0_n0_n1_k1_grid_desc = *reinterpret_cast<const BK0N0N1K1GridDesc*>(
-        (const void*)p_b_k0_n0_n1_k1_grid_desc);
+    const auto b_k0_n0_n1_k1_grid_desc = *reinterpret_cast<const BK0N0N1K1GridDesc*>(
+        cast_pointer_to_generic_address_space(p_b_k0_n0_n1_k1_grid_desc));
-    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
-        *reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
-            (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
+    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
+        *reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
+            cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
-    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
-        *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
-            (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
+    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
+        *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
+            cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));

     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
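Note: the hunk above passes each tensor descriptor through an opaque `const void CONSTANT*` and recovers a by-value copy inside the kernel with a two-step cast, exactly as the in-code comments describe. A minimal host-style sketch of that two-step pattern follows; `DescExample` and `load_descriptor` are illustrative names only (the real types such as `AK0M0M1K1GridDesc` are template instantiations, and the real address-space handling lives in `cast_pointer_to_generic_address_space` / the `CONSTANT` macro from the CK headers).

    #include <cstdint>

    // Hypothetical stand-in for a tensor descriptor type.
    struct DescExample { long lengths[4]; long strides[4]; };

    // Step 1: the kernel sees only an opaque generic pointer (the real code first strips
    // the constant-address-space qualification). Step 2: reinterpret it as the concrete
    // descriptor type and copy it by value so later index math stays in registers.
    inline DescExample load_descriptor(const void* p_opaque)
    {
        const auto* p_desc = reinterpret_cast<const DescExample*>(p_opaque);
        return *p_desc;
    }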
...
@@ -142,12 +141,12 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector,
-          typename AGridIteratorHacks,
+          typename AGridStepHacks,
-          typename BGridIteratorHacks,
+          typename BGridStepHacks,
-          typename CGridIteratorHacks,
+          typename CGridStepHacks,
-          typename AGridMoveSliceWindowIteratorHacks,
+          typename AGridMoveSliceWindowStepHacks,
-          typename BGridMoveSliceWindowIteratorHacks>
+          typename BGridMoveSliceWindowStepHacks>
-struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
+struct GridwiseGemmDlops_km_kn_mn_v1r3
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
@@ -164,12 +163,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
     // TODO: check alignment
     // A matrix in LDS memory, dst of blockwise copy
-    constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+    constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned_v2(
         make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);

     // TODO: check alignment
     // B matrix in LDS memory, dst of blockwise copy
-    constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+    constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned_v2(
         make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);

     // TODO: check alignment
...
@@ -191,12 +190,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
         const auto M  = a_k0_m_k1_grid_desc.GetLength(I1);
         const auto N  = b_k0_n_k1_grid_desc.GetLength(I1);
         const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
+        const auto K1 = a_k0_m_k1_grid_desc.GetLength(I2);

         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
                 K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
-                K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
                 K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
                (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
    }
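Note: the check above only verifies shape consistency (A is [K0, M, K1], B is [K0, N, K1], C is [M, N]) and divisibility by the block tile. A standalone host-side sketch of the same rules, written for plain integers (all parameter names here are assumptions, not the library's API):

    #include <cstdint>

    // Same consistency rules as CheckValidity() above, for plain integer dimensions.
    bool gemm_problem_is_valid(int64_t M, int64_t N, int64_t K0, int64_t K1,
                               const int64_t a_dims[3], const int64_t b_dims[3],
                               const int64_t c_dims[2],
                               int64_t MPerBlock, int64_t NPerBlock, int64_t K0PerBlock)
    {
        const bool shapes_agree =
            a_dims[0] == K0 && a_dims[1] == M && a_dims[2] == K1 && // A is [K0, M, K1]
            b_dims[0] == K0 && b_dims[1] == N && b_dims[2] == K1 && // B is [K0, N, K1]
            c_dims[0] == M && c_dims[1] == N;                        // C is [M, N]

        const bool tiles_divide =
            M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0;

        return shapes_agree && tiles_divide;
    }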
...
@@ -231,13 +230,13 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
         const auto M1 = Number<MPerBlockM1>{};
         const auto M0 = M / M1;

-        const auto a_k0_m0_m1_k1_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto a_k0_m0_m1_k1_grid_desc = transform_tensor_descriptor(
             a_k0_m_k1_grid_desc,
             make_tuple(make_pass_through_transform(K0),
                        make_unmerge_transform(make_tuple(M0, M1)),
                        make_pass_through_transform(K1)),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
             make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

         return a_k0_m0_m1_k1_grid_desc;
    }
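Note: the unmerge transform above splits the M dimension of the [K0, M, K1] view into (M0, M1) with M1 = MPerBlockM1; in index terms m = m0 * M1 + m1 and the underlying offset is unchanged. A compilable sketch of that arithmetic, assuming a packed row-major [K0, M, K1] layout (the helper names below are illustrative, not the library's):

    #include <cstdint>

    // Packed row-major [K0, M, K1] offset.
    int64_t offset_k0_m_k1(int64_t k0, int64_t m, int64_t k1, int64_t M, int64_t K1)
    {
        return (k0 * M + m) * K1 + k1;
    }

    // Unmerged [K0, M0, M1, K1] view of the same tensor: m = m0 * M1 + m1.
    int64_t offset_k0_m0_m1_k1(int64_t k0, int64_t m0, int64_t m1, int64_t k1,
                               int64_t M1, int64_t M, int64_t K1)
    {
        return offset_k0_m_k1(k0, m0 * M1 + m1, k1, M, K1);
    }

The B descriptor built in the next hunk applies the same split to N with N1 = NPerBlockN1.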
...
@@ -251,13 +250,13 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
         const auto N1 = Number<NPerBlockN1>{};
         const auto N0 = N / N1;

-        const auto b_k0_n0_n1_k1_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto b_k0_n0_n1_k1_grid_desc = transform_tensor_descriptor(
             b_k0_n_k1_grid_desc,
             make_tuple(make_pass_through_transform(K0),
                        make_unmerge_transform(make_tuple(N0, N1)),
                        make_pass_through_transform(K1)),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
             make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

         return b_k0_n0_n1_k1_grid_desc;
    }
...
@@ -284,7 +283,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
         constexpr auto M10 = M1 / M11;
         constexpr auto N10 = N1 / N11;

-        const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
             c_m_n_grid_desc,
             make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
                        make_unmerge_transform(make_tuple(N0, N10, N11))),
...
@@ -355,23 +354,23 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
         // TODO: check alignment
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_k0_m0_m1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align);

         // TODO: check alignment
         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_k0_n0_n1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align);

         // TODO: check alignment
         // A matrix in LDS memory, for blockwise GEMM
-        constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);

         // TODO: check alignment
         // B matrix in LDS memory, for blockwise GEMM
-        constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);

         static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
...
@@ -381,7 +380,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
                       "wrong!");

         // A matrix blockwise copy
-        auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
+        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
             BlockSize,
             InMemoryDataOperationEnum_t::Set,
             Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
...
@@ -405,7 +404,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
             make_multi_index(0, 0, 0, 0));

         // B matrix blockwise copy
-        auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
+        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
             BlockSize,
             InMemoryDataOperationEnum_t::Set,
             Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
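Note: the blockwise copies above partition the [KPerBlock, 1, MPerBlockM1, K1] (or NPerBlockN1) block slice across the thread block: per-thread slice lengths times the thread cluster lengths must cover every dimension, and the cluster must use exactly BlockSize threads. A small compilable check of that contract (the parameter names are assumptions mirroring the template arguments, not the library's API):

    #include <array>

    // ThreadSliceLengths[d] * ThreadClusterLengths[d] must equal the block slice length
    // in every dimension, and the cluster must account for exactly block_size threads.
    constexpr bool covers_block_slice(const std::array<int, 4>& block_slice_lengths,
                                      const std::array<int, 4>& thread_slice_lengths,
                                      const std::array<int, 4>& thread_cluster_lengths,
                                      int block_size)
    {
        int threads_needed = 1;
        for(int d = 0; d < 4; ++d)
        {
            if(thread_slice_lengths[d] * thread_cluster_lengths[d] != block_slice_lengths[d])
                return false;
            threads_needed *= thread_cluster_lengths[d];
        }
        return threads_needed == block_size;
    }

    // Example: a [8, 1, 128, 4] block slice copied by 256 threads, each moving [1, 1, 4, 4].
    static_assert(covers_block_slice({8, 1, 128, 4}, {1, 1, 4, 4}, {8, 1, 32, 1}, 256), "");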
...
@@ -453,9 +452,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
         constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
             decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();

         constexpr auto c_m10_m11_n10_n11_thread_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(
-                sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
+            make_naive_tensor_descriptor_packed(
+                sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
...
@@ -471,9 +469,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
         auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
             c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());

-        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
+        ThreadwiseTensorSliceSet_v1<FloatAcc,
                                     decltype(c_m10_m11_n10_n11_thread_desc),
                                     decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
             .Run(c_m10_m11_n10_n11_thread_desc,
                  make_tuple(I0, I0, I0, I0),
                  c_thread_buf,
...
@@ -496,8 +494,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
-            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
+            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
+            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});

             a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
             b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
...
@@ -516,18 +514,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
             // even iteration
             a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
                                                 a_block_slice_copy_step,
-                                                AGridMoveSliceWindowIteratorHacks{});
+                                                AGridMoveSliceWindowStepHacks{});
             b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
                                                 b_block_slice_copy_step,
-                                                BGridMoveSliceWindowIteratorHacks{});
+                                                BGridMoveSliceWindowStepHacks{});

             __syncthreads();

             // LDS doubel buffer: load next data from device mem
-            a_blockwise_copy.RunRead(
-                a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
-            b_blockwise_copy.RunRead(
-                b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
+            a_blockwise_copy.RunRead(
+                a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
+            b_blockwise_copy.RunRead(
+                b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
...
@@ -542,18 +538,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
             // odd iteration
             a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
                                                 a_block_slice_copy_step,
-                                                AGridMoveSliceWindowIteratorHacks{});
+                                                AGridMoveSliceWindowStepHacks{});
             b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
                                                 b_block_slice_copy_step,
-                                                BGridMoveSliceWindowIteratorHacks{});
+                                                BGridMoveSliceWindowStepHacks{});

             __syncthreads();

             // LDS doubel buffer: load next data from device mem
-            a_blockwise_copy.RunRead(
-                a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
-            b_blockwise_copy.RunRead(
-                b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
+            a_blockwise_copy.RunRead(
+                a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
+            b_blockwise_copy.RunRead(
+                b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(
...
@@ -570,18 +564,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
         // LDS double buffer: tail
         if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
         {
-            a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
-                                                a_block_slice_copy_step,
-                                                AGridMoveSliceWindowIteratorHacks{});
-            b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
-                                                b_block_slice_copy_step,
-                                                BGridMoveSliceWindowIteratorHacks{});
+            a_blockwise_copy.MoveSrcSliceWindow(
+                a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{});
+            b_blockwise_copy.MoveSrcSliceWindow(
+                b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});

             __syncthreads();

             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
-            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
+            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
+            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});

             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(
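Note: the even/odd/tail hunks above implement LDS double buffering: while the blockwise GEMM consumes the slice already resident in one buffer, the next K-slice is fetched from global memory into the other. A bare-bones single-threaded sketch of that schedule follows; `fetch` and `compute` are stand-ins for RunRead/RunWrite and blockwise_gemm.Run, and on a real GPU a __syncthreads() separates writing LDS from reading it, as the hunks show.

    #include <cstddef>
    #include <vector>

    // Minimal sketch of the even/odd double-buffer schedule, including the "double tail".
    template <typename Fetch, typename Compute>
    void double_buffered_k_loop(std::size_t num_k_blocks, Fetch fetch, Compute compute)
    {
        std::vector<float> buf_even, buf_odd;

        fetch(0, buf_even); // preload the first K-slice into the "even" buffer

        std::size_t k = 0;
        for(; k + 2 < num_k_blocks; k += 2)
        {
            fetch(k + 1, buf_odd);  // load the next slice while...
            compute(buf_even);      // ...consuming the current one
            fetch(k + 2, buf_even);
            compute(buf_odd);
        }

        if(k + 1 < num_k_blocks) // two iterations left
        {
            fetch(k + 1, buf_odd);
            compute(buf_even);
            compute(buf_odd);
        }
        else // single iteration left
        {
            compute(buf_even);
        }
    }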
...
@@ -608,21 +600,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
         // output: register to global memory
         {
-            constexpr auto M11 = Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies_v2{}, I1) * M1PerThreadM111>{};
-            constexpr auto N11 = Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies_v2{}, I1) * N1PerThreadN111>{};
-
-            constexpr index_t M10 = MPerBlockM1 / M11;
-            constexpr index_t N10 = NPerBlockN1 / N11;
-
-            constexpr index_t M111 = M1PerThreadM111;
-            constexpr index_t N111 = N1PerThreadN111;
-
             constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
-                make_dynamic_naive_tensor_descriptor_packed_v2(
+                make_naive_tensor_descriptor_packed(
                     make_tuple(I1,
                                Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
                                Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
...
@@ -634,7 +613,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
                 blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
                     get_thread_local_1d_id());

-            ThreadwiseDynamicTensorSliceTransfer_v1r3<
+            ThreadwiseTensorSliceTransfer_v1r3<
                 FloatAcc,
                 FloatC,
                 decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
...
@@ -662,7 +641,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
                     c_thread_buf,
                     c_m0_m10_m11_n0_n10_n11_grid_desc,
                     c_grid_buf,
-                    CGridIteratorHacks{});
+                    CGridStepHacks{});
        }
    }
};
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v2.hpp → composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp View file @ ccc4a1d3
-#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP
+#ifndef CK_GRIDWISE_GEMM_V2_HPP
-#define CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP
+#define CK_GRIDWISE_GEMM_V2_HPP
 #include "common_header.hpp"
-#include "dynamic_multi_index_transform_helper.hpp"
+#include "multi_index_transform_helper.hpp"
-#include "dynamic_tensor_descriptor.hpp"
+#include "tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "tensor_descriptor_helper.hpp"
-#include "blockwise_dynamic_tensor_slice_transfer.hpp"
+#include "blockwise_tensor_slice_transfer.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
 #include "blockwise_gemm_dlops_v3.hpp"

 namespace ck {

...
@@ -42,12 +42,12 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector,
-          typename AGlobalIteratorHacks,
+          typename AGlobalStepHacks,
-          typename BGlobalIteratorHacks,
+          typename BGlobalStepHacks,
-          typename CGlobalIteratorHacks,
+          typename CGlobalStepHacks,
-          typename AGlobalMoveSliceWindowIteratorHacks,
+          typename AGlobalMoveSliceWindowStepHacks,
-          typename BGlobalMoveSliceWindowIteratorHacks>
+          typename BGlobalMoveSliceWindowStepHacks>
-struct GridwiseDynamicGemmDlops_km_kn_mn_v3
+struct GridwiseGemmDlops_km_kn_mn_v3
 {
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
...
@@ -58,7 +58,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);

         // LDS allocation for A and B: be careful of alignment
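Note: "be careful of LDS alignment" in these descriptors means the tile's element-space size is rounded up so that rows start on vector-friendly boundaries, and the LDS budget is computed with math::integer_least_multiple. A small sketch of that round-up arithmetic follows; the per-row padding rule is only one plausible scheme, shown to make the rounding concrete, and the names are assumptions rather than the library's exact behavior.

    #include <cstddef>

    // Round x up to the next multiple of `align` (what math::integer_least_multiple does).
    constexpr std::size_t integer_least_multiple(std::size_t x, std::size_t align)
    {
        return ((x + align - 1) / align) * align;
    }

    // One plausible "aligned" tile size for a [KPerBlock, MPerBlock, K1] LDS tile:
    // pad each (MPerBlock * K1) row to the alignment, then multiply by KPerBlock.
    constexpr std::size_t aligned_tile_elements(std::size_t KPerBlock, std::size_t MPerBlock,
                                                std::size_t K1, std::size_t max_lds_align)
    {
        const std::size_t row = integer_least_multiple(MPerBlock * K1, max_lds_align);
        return KPerBlock * row;
    }

    // Example: KPerBlock=8, MPerBlock=128, K1=4, max_lds_align=8 -> 8 * 512 = 4096 elements.
    static_assert(aligned_tile_elements(8, 128, 4, 8) == 4096, "LDS tile size");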
...
@@ -102,7 +102,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
         // divide block work by [M, N]
 #if 0
-        const auto k_block_work_num = K / Number<KPerBlock>{};
         const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
         const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
         const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
...
@@ -114,7 +113,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
         const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
 #else
         // Hack: this force result into SGPR
-        const index_t k_block_work_num  = __builtin_amdgcn_readfirstlane(K / KPerBlock);
         const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
         const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
         const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
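Note: the "force result into SGPR" hack above relies on __builtin_amdgcn_readfirstlane, which broadcasts lane 0's value to the whole wavefront and thereby marks it wavefront-uniform, so the compiler can keep it in a scalar register instead of one vector register per lane. A minimal HIP sketch of the same trick (this kernel is an illustration, not part of the diff):

    #include <hip/hip_runtime.h>

    __global__ void tile_count_demo(const int* sizes, int* out)
    {
        // sizes[0] / sizes[1] is the same for every lane, but the compiler cannot prove it;
        // readfirstlane broadcasts lane 0's result and lets the quotient live in an SGPR.
        const int num_tiles = __builtin_amdgcn_readfirstlane(sizes[0] / sizes[1]);
        if(threadIdx.x == 0)
            out[blockIdx.x] = num_tiles;
    }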
...
@@ -134,23 +132,21 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_e_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);

-        constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto b_e_n_ho_wo_block_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+            make_naive_tensor_descriptor_packed(make_tuple(
                 Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));

         // c_thread_mtx definition: this is a mess
         // TODO:: more elegent way of defining c_thread_mtx
         constexpr auto c_k_n_ho_wo_thread_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+            make_naive_tensor_descriptor_packed(make_tuple(
                 Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));

         auto blockwise_gemm =
             BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
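Note: a "packed" descriptor such as c_k_n_ho_wo_thread_desc above stores only its lengths; strides follow from a contiguous row-major layout with the rightmost dimension fastest-varying. A compilable sketch of that stride rule (names are illustrative):

    #include <array>

    // Strides implied by a packed (contiguous) 4-D descriptor with lengths [K, N, Ho, Wo].
    constexpr std::array<long, 4> packed_strides(const std::array<long, 4>& lengths)
    {
        std::array<long, 4> strides{};
        long running = 1;
        for(int d = 3; d >= 0; --d)
        {
            strides[d] = running;
            running *= lengths[d];
        }
        return strides;
    }

    constexpr long packed_offset(const std::array<long, 4>& lengths,
                                 const std::array<long, 4>& idx)
    {
        const auto s = packed_strides(lengths);
        return idx[0] * s[0] + idx[1] * s[1] + idx[2] * s[2] + idx[3] * s[3];
    }

    // Example: a [KPerThread=4, 1, HoPerThread=2, WoPerThread=2] thread tile (C++17).
    static_assert(packed_offset({4, 1, 2, 2}, {1, 0, 1, 1}) == 7, "packed offset");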
...
@@ -184,47 +180,46 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
+            BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
                                             Sequence<E, KPerBlock>,
                                             ABlockTransferThreadSliceLengths_E_K,
                                             ABlockTransferThreadClusterLengths_E_K,
                                             ABlockTransferThreadClusterArrangeOrder,
                                             FloatAB,
                                             FloatAB,
                                             decltype(a_e_k_global_desc),
                                             decltype(a_e_k_desc),
                                             ABlockTransferSrcAccessOrder,
                                             Sequence<0, 1>,
                                             ABlockTransferSrcVectorDim,
                                             1,
                                             ABlockTransferSrcScalarPerVector,
                                             ABlockTransferDstScalarPerVector_K,
                                             1,
                                             1,
                                             AThreadTransferSrcResetCoordinateAfterRun,
                                             true>(
                a_e_k_global_desc,
                make_multi_index(0, k_block_data_on_global),
                a_e_k_desc,
                make_multi_index(0, 0));

-        constexpr auto b_e_n_ho_wo_thread_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
-                Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
+        constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
+            Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));

-        auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<FloatAB,
+        auto b_threadwise_transfer = ThreadwiseTensorSliceTransfer_v2<FloatAB,
                                                 FloatAB,
                                                 decltype(b_e_n_ho_wo_global_desc),
                                                 decltype(b_e_n_ho_wo_thread_desc),
                                                 Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
                                                 BBlockTransferSrcAccessOrder,
                                                 BBlockTransferSrcVectorDim,
                                                 BBlockTransferSrcScalarPerVector,
                                                 1,
                                                 true>(
            b_e_n_ho_wo_global_desc,
            make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));

         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
             p_shared_block, a_e_k_desc.GetElementSpaceSize());
...
@@ -232,44 +227,45 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
         // register allocation for output
         StaticBuffer<AddressSpaceEnum_t::Vgpr,
                      FloatAcc,
-                     c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
+                     c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
+                     true>
             c_thread_buf;

         // initialize output thread tensor
-        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
+        ThreadwiseTensorSliceSet_v1<FloatAcc,
                                     decltype(c_k_n_ho_wo_thread_desc),
                                     Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
            .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});

         constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);

         // hack to control index calculation when iterating over A and B matrix for threadwise copy
-        constexpr auto a_e_k_global_iterator_hacks       = AGlobalIteratorHacks{};
-        constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{};
+        constexpr auto a_e_k_global_step_hacks       = AGlobalStepHacks{};
+        constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};

         // hack to control index calculation when move slice window for A and B matrix for
         // threadwise copy
-        constexpr auto a_e_k_global_move_slice_window_iterator_hack =
-            AGlobalMoveSliceWindowIteratorHacks{};
-        constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
-            BGlobalMoveSliceWindowIteratorHacks{};
+        constexpr auto a_e_k_global_move_slice_window_step_hack =
+            AGlobalMoveSliceWindowStepHacks{};
+        constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
+            BGlobalMoveSliceWindowStepHacks{};

         // double regsiter buffer for b
         StaticBuffer<AddressSpaceEnum_t::Vgpr,
                      FloatAB,
-                     b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
+                     b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
+                     true>
             b_thread_even_buf, b_thread_odd_buf;

         // LDS double buffer: preload data
         {
-            a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);

             b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
                                       b_global_buf,
                                       b_e_n_ho_wo_thread_desc,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_even_buf,
-                                      b_e_n_ho_wo_global_iterator_hacks);
+                                      b_e_n_ho_wo_global_step_hacks);

             a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
         }
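Note: unlike the v1r3 kernel, this one stages only A through LDS and keeps B in two per-thread register buffers (b_thread_even_buf / b_thread_odd_buf), so the next EPerBlock slice of B can be fetched while the resident slice feeds the blockwise GEMM. A tiny sketch of that even/odd register ping-pong, written as a generic host-side loop with assumed `fetch`/`compute` callbacks:

    #include <array>
    #include <utility>

    // Two statically sized per-thread buffers; roles are swapped each iteration
    // instead of copying data between them.
    template <typename Fetch, typename Compute>
    void ping_pong_registers(int num_slices, Fetch fetch, Compute compute)
    {
        std::array<float, 16> buf_a{}, buf_b{}; // stand-ins for the even/odd thread buffers
        auto* current = &buf_a;
        auto* next    = &buf_b;

        fetch(0, *current);
        for(int s = 0; s < num_slices; ++s)
        {
            if(s + 1 < num_slices)
                fetch(s + 1, *next); // prefetch the next E-slice into the idle buffer
            compute(*current);       // consume the slice that is already resident
            std::swap(current, next);
        }
    }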
...
@@ -293,7 +289,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
                                           b_e_n_ho_wo_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           b_thread_odd_buf,
-                                          b_e_n_ho_wo_global_iterator_hacks);
+                                          b_e_n_ho_wo_global_step_hacks);

                 // LDS double buffer: GEMM on current data
                 // TODO: @Zhang Jing: blockwise gemm should be able to move slice window
...
@@ -309,7 +305,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
                                           b_e_n_ho_wo_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           b_thread_even_buf,
-                                          b_e_n_ho_wo_global_iterator_hacks);
+                                          b_e_n_ho_wo_global_step_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
...
@@ -332,7 +328,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
                                           b_e_n_ho_wo_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           b_thread_odd_buf,
-                                          b_e_n_ho_wo_global_iterator_hacks);
+                                          b_e_n_ho_wo_global_step_hacks);

             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
...
@@ -351,23 +347,22 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
         // output: register to global memory
         {
             // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
-            constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
+            constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};

             const index_t k_thread_data_on_global =
                 k_block_data_on_global + k_thread_id * KPerThread;

-            ThreadwiseDynamicTensorSliceTransfer_v1r3<
+            ThreadwiseTensorSliceTransfer_v1r3<
                 FloatAcc,
                 FloatC,
                 decltype(c_k_n_ho_wo_thread_desc),
                 decltype(c_k_n_ho_wo_global_desc),
                 Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
                 CThreadTransferSrcDstAccessOrder,
                 CThreadTransferSrcDstVectorDim,
                 CThreadTransferDstScalarPerVector,
                 CGlobalMemoryDataOperation,
                 1,
                 true>(
                 c_k_n_ho_wo_global_desc,
                 make_multi_index(
                     k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
...
@@ -376,7 +371,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
                      c_thread_buf,
                      c_k_n_ho_wo_global_desc,
                      c_global_buf,
-                     c_k_n_ho_wo_global_tensor_iterator_hacks);
+                     c_k_n_ho_wo_global_tensor_step_hacks);
        }
    }
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp → composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp View file @ ccc4a1d3
-#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R3_HPP
+#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
-#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R3_HPP
+#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
 #include "common_header.hpp"
-#include "dynamic_multi_index_transform_helper.hpp"
+#include "multi_index_transform_helper.hpp"
-#include "dynamic_tensor_descriptor.hpp"
+#include "tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_dynamic_tensor_slice_transfer.hpp"
+#include "blockwise_tensor_slice_transfer.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
-#include "threadwise_dynamic_tensor_slice_set.hpp"
+#include "threadwise_tensor_slice_set.hpp"

 namespace ck {

...
@@ -24,13 +24,13 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
+    kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
                             const FloatAB* __restrict__ p_b_grid,
                             FloatC* __restrict__ p_c_grid,
                             const AK0MK1GridDesc a_k0_m_k1_grid_desc,
                             const BK0NK1GridDesc b_k0_n_k1_grid_desc,
                             const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc,
                             const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...
@@ -58,25 +58,25 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
+    kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
                             const FloatAB* __restrict__ p_b_grid,
                             FloatC* __restrict__ p_c_grid,
                             const void CONSTANT* p_a_k0_m_k1_grid_desc,
                             const void CONSTANT* p_b_k0_n_k1_grid_desc,
                             const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
                             const void CONSTANT* p_c_block_cluster_adaptor)
 {
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

-    const auto a_k0_m_k1_grid_desc = *reinterpret_cast<const AK0MK1GridDesc*>(
-        (const void*)p_a_k0_m_k1_grid_desc);
-    const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>(
-        (const void*)p_b_k0_n_k1_grid_desc);
-    const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast<const CM0M1M2NGridDesc*>(
-        (const void*)p_c_m0_m1_m2_n_grid_desc);
-    const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
-        (const void*)p_c_block_cluster_adaptor);
+    const auto a_k0_m_k1_grid_desc = *reinterpret_cast<const AK0MK1GridDesc*>(
+        cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc));
+    const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>(
+        cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc));
+    const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast<const CM0M1M2NGridDesc*>(
+        cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc));
+    const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
+        cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));

     __shared__ FloatAB p_shared_block[shared_block_size];
...
@@ -126,13 +126,13 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector,
-          typename AGridIteratorHacks,
+          typename AGridStepHacks,
-          typename BGridIteratorHacks,
+          typename BGridStepHacks,
-          typename CGridIteratorHacks,
+          typename CGridStepHacks,
-          typename AGridMoveSliceWindowIteratorHacks,
+          typename AGridMoveSliceWindowStepHacks,
-          typename BGridMoveSliceWindowIteratorHacks,
+          typename BGridMoveSliceWindowStepHacks,
           bool CAccessOrderMRepeatNRepeat>
-struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
+struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
@@ -148,12 +148,12 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

         // LDS allocation for A and B: be careful of alignment
...
@@ -203,9 +203,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     __host__ __device__ static constexpr auto
     MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
     {
-        const auto M = c_m_n_grid_desc.GetLength(I0);
-        const auto N = c_m_n_grid_desc.GetLength(I1);
-
         constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};

         constexpr auto CLayout = xdlops_gemm.GetCLayout();
...
@@ -217,10 +214,9 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
         constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);

-        constexpr auto N0 = Number<CLayout.N1()>{};
         constexpr auto N1 = Number<CLayout.N0()>{};

-        const auto c_m0_m1_m2_n_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor(
             c_m_n_grid_desc,
             make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)),
                        make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))),
...
@@ -269,11 +265,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                                const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
                                const CBlockClusterAdaptor& c_block_cluster_adaptor)
     {
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
-
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
...
@@ -282,8 +273,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());

         const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
-        const auto M  = a_k0_m_k1_grid_desc.GetLength(I1);
-        const auto N  = b_k0_n_k1_grid_desc.GetLength(I1);

         // divide block work by [M, N]
         const auto block_work_idx =
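Note: block_work_idx is produced by the block-cluster adaptor, which turns the flat block index into a 2-D output-tile coordinate over the M x N grid. The simplest such mapping (an assumption for illustration, not necessarily the adaptor's exact ordering) is row-major over the tile grid:

    #include <utility>

    // Map a flat block id to a (m0, n0) output-tile coordinate, row-major over tiles.
    std::pair<int, int> block_id_to_tile_coord(int block_id, int N, int NPerBlock)
    {
        const int n_tiles_n = N / NPerBlock; // tiles along the N dimension
        const int m0        = block_id / n_tiles_n;
        const int n0        = block_id % n_tiles_n;
        return {m0, n0};
    }

The grid is then launched with (M / MPerBlock) * (N / NPerBlock) blocks, and the tile origin in elements is (m0 * MPerBlock, n0 * NPerBlock), which is what m_block_data_idx_on_grid / n_block_data_idx_on_grid hold below.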
...
@@ -301,67 +290,65 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
+            BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
                                             Sequence<KPerBlock, MPerBlock, K1>,
                                             ABlockTransferThreadSliceLengths_K0_M_K1,
                                             ABlockTransferThreadClusterLengths_K0_M_K1,
                                             ABlockTransferThreadClusterArrangeOrder,
                                             FloatAB,
                                             FloatAB,
                                             decltype(a_k0_m_k1_grid_desc),
                                             decltype(a_k0_m_k1_block_desc),
                                             ABlockTransferSrcAccessOrder,
                                             Sequence<1, 0, 2>,
                                             ABlockTransferSrcVectorDim,
                                             2,
                                             ABlockTransferSrcScalarPerVector,
                                             ABlockTransferDstScalarPerVector_K1,
                                             1,
                                             1,
                                             AThreadTransferSrcResetCoordinateAfterRun,
                                             true>(
                a_k0_m_k1_grid_desc,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_k0_m_k1_block_desc,
                make_multi_index(0, 0, 0));

         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
+            BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
                                             Sequence<KPerBlock, NPerBlock, K1>,
                                             BBlockTransferThreadSliceLengths_K0_N_K1,
                                             BBlockTransferThreadClusterLengths_K0_N_K1,
                                             BBlockTransferThreadClusterArrangeOrder,
                                             FloatAB,
                                             FloatAB,
                                             decltype(b_k0_n_k1_grid_desc),
                                             decltype(b_k0_n_k1_block_desc),
                                             BBlockTransferSrcAccessOrder,
                                             Sequence<1, 0, 2>,
                                             BBlockTransferSrcVectorDim,
                                             2,
                                             BBlockTransferSrcScalarPerVector,
                                             BBlockTransferDstScalarPerVector_K1,
                                             1,
                                             1,
                                             BThreadTransferSrcResetCoordinateAfterRun,
                                             true>(
                b_k0_n_k1_grid_desc,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_k0_n_k1_block_desc,
                make_multi_index(0, 0, 0));

         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
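Note: the comment above states the contraction this kernel computes: C[m][n] += sum over k of A[k][m] * B[k][n], with k stored split as (k0, k1) so that K1 consecutive elements are contiguous for vectorized loads. A scalar reference loop makes the indexing explicit; packed row-major layouts are assumed for the sketch, and this is a correctness reference, not the xdlops implementation.

    #include <vector>

    // Reference for c_mtx += transpose(a_mtx) * b_mtx with A as [K0, M, K1] and B as [K0, N, K1].
    void gemm_k0mk1_k0nk1_mn_reference(const std::vector<float>& a, // size K0 * M * K1
                                       const std::vector<float>& b, // size K0 * N * K1
                                       std::vector<float>& c,       // size M * N
                                       int K0, int M, int N, int K1)
    {
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k0 = 0; k0 < K0; ++k0)
                    for(int k1 = 0; k1 < K1; ++k1)
                        acc += a[(k0 * M + m) * K1 + k1] * b[(k0 * N + n) * K1 + k1];
                c[m * N + n] += acc;
            }
    }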
...
@@ -375,7 +362,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                           NPerBlock % (NPerWave * NRepeat) == 0,
                       "wrong!");

-        constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
+        constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor(
             a_k0_m_k1_block_desc,
             make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
                        make_unmerge_transform(
...
@@ -384,7 +371,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
             make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

-        constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
+        constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor(
             b_k0_n_k1_block_desc,
             make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
                        make_unmerge_transform(
...
@@ -410,21 +397,19 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");

-        constexpr auto c_mr_nr_blk_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
+        constexpr auto c_mr_nr_blk_desc =
+            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));

         StaticBuffer<AddressSpaceEnum_t::Vgpr,
                      vector_type<FloatAcc, BlkSize>,
-                     c_mr_nr_blk_desc.GetElementSpaceSize()>
+                     c_mr_nr_blk_desc.GetElementSpaceSize(),
+                     true>
             c_thread_buf;

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
             math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
         constexpr auto b_block_space_size =
             math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);

         FloatAB* p_a_block = p_shared_block;
         FloatAB* p_b_block = p_shared_block + a_block_space_size;
...
@@ -432,15 +417,13 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);

         // hack to control index calculation when iterating over A and B matrix for threadwise copy
-        constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{};
-        constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{};
+        constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
+        constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};

         // hack to control index calculation when move slice window for A and B matrix for
         // threadwise copy
-        constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack =
-            AGridMoveSliceWindowIteratorHacks{};
-        constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack =
-            BGridMoveSliceWindowIteratorHacks{};
+        constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack =
+            AGridMoveSliceWindowStepHacks{};
+        constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack =
+            BGridMoveSliceWindowStepHacks{};

         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
             p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
...
@@ -449,10 +432,8 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // preload data into LDS
         {
-            a_blockwise_copy.RunRead(
-                a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
-            b_blockwise_copy.RunRead(
-                b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
+            a_blockwise_copy.RunRead(
+                a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+            b_blockwise_copy.RunRead(
+                b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);

             a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
             b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
...
@@ -465,18 +446,16 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         {
             a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
                                                 a_block_slice_copy_step,
-                                                a_k0_m_k1_grid_move_slice_window_iterator_hack);
+                                                a_k0_m_k1_grid_move_slice_window_step_hack);
             b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
                                                 b_block_slice_copy_step,
-                                                b_k0_n_k1_grid_move_slice_window_iterator_hack);
+                                                b_k0_n_k1_grid_move_slice_window_step_hack);

-            a_blockwise_copy.RunRead(
-                a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
+            a_blockwise_copy.RunRead(
+                a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);

             block_sync_lds();

-            b_blockwise_copy.RunRead(
-                b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
+            b_blockwise_copy.RunRead(
+                b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);

             blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
...
@@ -506,7 +485,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -506,7 +485,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr index_t N1 = CLayout.N0();
constexpr index_t N1 = CLayout.N0();
constexpr auto c_m0_m1_m2_n_thread_desc =
constexpr auto c_m0_m1_m2_n_thread_desc =
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(make_tuple(Number<MRepeat>{},
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<NRepeat>{},
Number<1>{},
Number<1>{},
Number<1>{},
Number<1>{},
...
@@ -515,7 +494,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -515,7 +494,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
Number<M2>{},
Number<M2>{},
Number<1>{}));
Number<1>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()>
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()
, true
>
c_blk_buf_;
c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
...
@@ -542,12 +521,12 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -542,12 +521,12 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_thread_data_on_grid =
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_
i
te
rator
_hacks = CGrid
I
te
rator
Hacks{};
constexpr auto c_m0_m1_m2_n_grid_tensor_
s
te
p
_hacks = CGrid
S
te
p
Hacks{};
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
Threadwise
Dynamic
TensorSliceTransfer_v1r3<
ThreadwiseTensorSliceTransfer_v1r3<
FloatC,
FloatC,
FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_thread_desc),
...
@@ -573,7 +552,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -573,7 +552,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_blk_buf_,
c_blk_buf_,
c_m0_m1_m2_n_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_
i
te
rator
_hacks);
c_m0_m1_m2_n_grid_tensor_
s
te
p
_hacks);
}
}
#else
#else
{
{
...
@@ -581,11 +560,8 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -581,11 +560,8 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr
index_t
M1
=
CLayout
.
N1
();
constexpr
index_t
M1
=
CLayout
.
N1
();
constexpr
index_t
M2
=
CLayout
.
M0
();
constexpr
index_t
M2
=
CLayout
.
M0
();
constexpr
auto
c_m0_m1_m2_n_thread_desc
=
constexpr
auto
c_m0_m1_m2_n_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
make_tuple
(
I1
,
I1
,
I1
,
I1
,
Number
<
M0
>
{},
Number
<
1
>
{},
Number
<
M2
>
{},
Number
<
1
>
{}));
I1
,
I1
,
I1
,
I1
,
Number
<
M0
>
{},
Number
<
1
>
{},
Number
<
M2
>
{},
Number
<
1
>
{}));
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatC
,
BlkSize
>
c_blk_buf_
;
// calculate origin of thread output tensor on global memory
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
// blockwise GEMM c matrix starting index
...
@@ -598,20 +574,20 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -598,20 +574,20 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
index_t
n_thread_data_on_grid
=
const
index_t
n_thread_data_on_grid
=
n_block_data_idx_on_grid
+
c_thread_mtx_on_block
[
I1
];
n_block_data_idx_on_grid
+
c_thread_mtx_on_block
[
I1
];
constexpr
auto
c_m0_m1_m2_n_grid_tensor_
i
te
rator
_hacks
=
CGrid
I
te
rator
Hacks
{};
constexpr
auto
c_m0_m1_m2_n_grid_tensor_
s
te
p
_hacks
=
CGrid
S
te
p
Hacks
{};
auto
c_thread_copy
=
auto
c_thread_copy
=
Threadwise
Dynamic
TensorSliceTransfer_v1r3
<
FloatC
,
ThreadwiseTensorSliceTransfer_v1r3
<
FloatC
,
FloatC
,
FloatC
,
decltype
(
c_m0_m1_m2_n_thread_desc
),
decltype
(
c_m0_m1_m2_n_thread_desc
),
decltype
(
c_m0_m1_m2_n_grid_desc
),
decltype
(
c_m0_m1_m2_n_grid_desc
),
Sequence
<
1
,
1
,
1
,
1
,
M0
,
1
,
M2
,
1
>
,
Sequence
<
1
,
1
,
1
,
1
,
M0
,
1
,
M2
,
1
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
CGlobalMemoryDataOperation
,
CGlobalMemoryDataOperation
,
1
,
1
,
true
>
{
true
>
{
c_m0_m1_m2_n_grid_desc
,
c_m0_m1_m2_n_grid_desc
,
make_multi_index
(
0
,
make_multi_index
(
0
,
0
,
0
,
...
@@ -629,7 +605,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -629,7 +605,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_m0_m1_m2_n_grid_desc
,
c_m0_m1_m2_n_grid_desc
,
c_grid_buf
,
c_grid_buf
,
c_m0_m1_m2_n_grid_tensor_
i
te
rator
_hacks
);
c_m0_m1_m2_n_grid_tensor_
s
te
p
_hacks
);
return
c_thread_idx_
;
return
c_thread_idx_
;
};
};
...
@@ -644,7 +620,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -644,7 +620,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_m0_m1_m2_n_grid_desc
,
c_m0_m1_m2_n_grid_desc
,
c_grid_buf
,
c_grid_buf
,
c_m0_m1_m2_n_grid_tensor_
i
te
rator
_hacks
);
c_m0_m1_m2_n_grid_tensor_
s
te
p
_hacks
);
};
};
auto
nrepeat_plus_copy
=
[
&
](
auto
c_thread_idx_
)
{
auto
nrepeat_plus_copy
=
[
&
](
auto
c_thread_idx_
)
{
...
@@ -657,7 +633,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -657,7 +633,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_m0_m1_m2_n_grid_desc
,
c_m0_m1_m2_n_grid_desc
,
c_grid_buf
,
c_grid_buf
,
c_m0_m1_m2_n_grid_tensor_
i
te
rator
_hacks
);
c_m0_m1_m2_n_grid_tensor_
s
te
p
_hacks
);
};
};
auto
mrepeat_minus_copy
=
[
&
](
auto
c_thread_idx_
)
{
auto
mrepeat_minus_copy
=
[
&
](
auto
c_thread_idx_
)
{
...
@@ -670,7 +646,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -670,7 +646,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_m0_m1_m2_n_grid_desc
,
c_m0_m1_m2_n_grid_desc
,
c_grid_buf
,
c_grid_buf
,
c_m0_m1_m2_n_grid_tensor_
i
te
rator
_hacks
);
c_m0_m1_m2_n_grid_tensor_
s
te
p
_hacks
);
};
};
auto
nrepeat_minus_copy
=
[
&
](
auto
c_thread_idx_
)
{
auto
nrepeat_minus_copy
=
[
&
](
auto
c_thread_idx_
)
{
...
@@ -683,7 +659,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -683,7 +659,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_m0_m1_m2_n_grid_desc
,
c_m0_m1_m2_n_grid_desc
,
c_grid_buf
,
c_grid_buf
,
c_m0_m1_m2_n_grid_tensor_
i
te
rator
_hacks
);
c_m0_m1_m2_n_grid_tensor_
s
te
p
_hacks
);
};
};
static_assert
((
MRepeat
==
4
&&
NRepeat
==
4
)
or
(
MRepeat
==
4
&&
NRepeat
==
2
)
or
static_assert
((
MRepeat
==
4
&&
NRepeat
==
4
)
or
(
MRepeat
==
4
&&
NRepeat
==
2
)
or
...
...
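The hunks above only rename the index-calculation hack types (IteratorHacks to StepHacks) and add the `true` flag to the StaticBuffer accumulators, but they also show the shape of the GEMM main loop: both input tiles are preloaded into LDS, and each iteration moves the A/B slice windows by KPerBlock, reads the next K-slab from global memory, synchronizes, and runs the blockwise GEMM on the tile already resident in LDS. Below is a minimal host-side analogue of that pipeline, not from this commit; the tile sizes and buffers are made up for illustration only.

// Host-side sketch of the K-loop pipeline shown in the hunks above (illustrative only).
#include <cstdio>
#include <vector>

int main()
{
    constexpr int K = 8, KPerBlock = 2, M = 4, N = 4;
    std::vector<float> a(K * M, 1.0f), b(K * N, 1.0f), c(M * N, 0.0f);

    int k_window = 0; // analogue of the A/B slice-window origin

    // "preload": the first K-slab is considered resident before the loop starts
    for(int k_main = 0; k_main < K; k_main += KPerBlock)
    {
        // analogue of MoveSrcSliceWindow + RunRead for the *next* slab would go here;
        // on the GPU that copy overlaps with the math on the current slab

        // analogue of blockwise_gemm.Run on the slab [k_window, k_window + KPerBlock)
        for(int k = k_window; k < k_window + KPerBlock; ++k)
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                    c[m * N + n] += a[k * M + m] * b[k * N + n];

        k_window += KPerBlock; // the slice window advances by KPerBlock, as in the diff
    }

    std::printf("c[0] = %f (expected %d)\n", c[0], K);
    return 0;
}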
composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp
...
@@ -21,10 +21,10 @@ template <typename FloatA,
          typename TKLengths,
          typename TMLengths,
          typename TNLengths,
-         typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
-                                     BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
-                                     CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
-                                 bool>::type = false>
+         typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
+                                BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
+                                CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
+                            bool>::type = false>
struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
{
    __device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1()
...
@@ -97,10 +97,9 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
                    CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
                        c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));

-               amd_inner_product_dlop<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
-                                                              b_buf[Number<b_offset>{}],
-                                                              c_buf(Number<c_offset>{}));
+               inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
+                                                     b_buf[Number<b_offset>{}],
+                                                     c_buf(Number<c_offset>{}));
            });
        });
    });
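The only functional change in this hunk is the rename of amd_inner_product_dlop to inner_product; the call shape stays the same, with two operands and an accumulator that is updated in place. A scalar stand-in with that calling convention is sketched below; it is a hypothetical illustration of the signature seen in the diff, not the library's intrinsic-backed implementation.

// Hypothetical scalar stand-in for the inner_product<TA, TB, TC>(a, b, c) call shape
// used above: accumulate a*b into c. Illustration only, not the CK implementation.
#include <cstdio>

template <typename TA, typename TB, typename TC>
void inner_product_sketch(const TA& a, const TB& b, TC& c)
{
    c += static_cast<TC>(a) * static_cast<TC>(b); // c is the running accumulator
}

int main()
{
    float c = 0.0f;
    inner_product_sketch(2.0f, 3.0f, c); // c == 6
    inner_product_sketch(1.0f, 4.0f, c); // c == 10
    std::printf("c = %f\n", c);
    return 0;
}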
...
@@ -124,10 +123,10 @@ template <typename FloatA,
          typename TKLengths,
          typename TMLengths,
          typename TNLengths,
-         typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
-                                     BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
-                                     CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
-                                 bool>::type = false>
+         typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
+                                BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
+                                CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
+                            bool>::type = false>
struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{
    __device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
...
@@ -214,7 +213,7 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
                        CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
                            c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));

-                   amd_inner_product_dlop<a_vector_t, b_vector_t, FloatC>(
+                   inner_product<a_vector_t, b_vector_t, FloatC>(
                        a_vec.template AsType<a_vector_t>()[I0],
                        b_vec.template AsType<b_vector_t>()[I0],
                        c_buf(Number<c_offset>{}));
...
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
...
@@ -19,9 +19,9 @@ template <typename FloatA,
          typename CDesc,
          index_t H,
          index_t W,
-         typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
-                                     CDesc::IsKnownAtCompileTime(),
-                                 bool>::type = false>
+         typename enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
+                                CDesc::IsKnownAtCompileTime(),
+                            bool>::type = false>
struct ThreadwiseGemmDlops_km_kn_mn_v3
{
    template <typename ABuffer,
...
@@ -57,8 +57,6 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
-       constexpr auto I2 = Number<2>{};
-       constexpr auto I3 = Number<3>{};

        constexpr auto E = ADesc{}.GetLength(I0);
        constexpr auto K = ADesc{}.GetLength(I1);
...
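Across these headers the template constraints switch from std::enable_if to an unqualified enable_if. A plausible reading, which is an assumption and not shown anywhere in this diff, is that the ck namespace carries its own alias so the device headers no longer spell out the std:: qualifier. A minimal sketch of such an alias and of the constraint style used above follows; the names are hypothetical.

// Sketch of a namespace-local enable_if alias (assumed, not taken from this commit).
#include <type_traits>

namespace ck_sketch {
template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;
} // namespace ck_sketch

// Usage mirrors the "typename enable_if<..., bool>::type = false" constraint in the diff:
template <typename Desc,
          typename ck_sketch::enable_if<true /* stand-in for Desc::IsKnownAtCompileTime() */,
                                        bool>::type = false>
int constrained_fn(Desc)
{
    return 0;
}

int main() { return constrained_fn(42); }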
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp → composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp
-#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
-#define CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
+#ifndef CK_THREADWISE_TENSOR_SET_HPP
+#define CK_THREADWISE_TENSOR_SET_HPP

 #include "common_header.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"

 namespace ck {
...
@@ -11,12 +11,12 @@ namespace ck {
// 1. Desc is known at compile-time
// 2. Buffer is StaticBuffer
// 3. OriginIdx is known at compile-time
-// 4. use #-iterator
+// 4. use #-step
template <typename Data,
          typename Desc,
          typename SliceLengths,
-         typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
-struct ThreadwiseDynamicTensorSliceSet_v1
+         typename enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
+struct ThreadwiseTensorSliceSet_v1
{
    static constexpr index_t nDim = SliceLengths::Size();
...
@@ -40,7 +40,7 @@ struct ThreadwiseDynamicTensorSliceSet_v1
        constexpr auto origin_idx = to_multi_index(OriginIdx{});

        static_ford<SliceLengths>{}([&](auto access_idx) {
-           constexpr auto coord = make_dynamic_tensor_coordinate(desc, origin_idx + access_idx);
+           constexpr auto coord = make_tensor_coordinate(desc, origin_idx + access_idx);

            constexpr bool is_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);
...
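ThreadwiseTensorSliceSet_v1 (renamed above from its Dynamic-prefixed form) walks every index inside a compile-time slice, builds a coordinate for origin + index, checks that the coordinate maps to a valid offset, and writes the value there. A plain host-side analogue of that loop over a 2-D slice is sketched below; the packed "descriptor" and the buffer are simplified stand-ins, not the library's types.

// Host-side analogue of the "set a value over an N-d slice" loop above (illustrative only).
#include <array>
#include <cstdio>
#include <vector>

int main()
{
    constexpr int H = 4, W = 5;             // packed 2-D "descriptor": offset = h * W + w
    std::vector<float> buf(H * W, 0.0f);

    std::array<int, 2> origin{1, 2};        // analogue of OriginIdx
    std::array<int, 2> slice_lengths{2, 3}; // analogue of SliceLengths

    for(int i = 0; i < slice_lengths[0]; ++i)   // analogue of static_ford over SliceLengths
        for(int j = 0; j < slice_lengths[1]; ++j)
        {
            const int h = origin[0] + i, w = origin[1] + j;
            const bool is_valid = h >= 0 && h < H && w >= 0 && w < W; // offset-validity check
            if(is_valid)
                buf[h * W + w] = 1.0f;          // analogue of buf(coord) = value
        }

    std::printf("buf[1*W+2] = %f\n", buf[1 * W + 2]);
    return 0;
}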
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp → composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
-#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
-#define CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
+#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_HPP
+#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_HPP

 #include "common_header.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"

 namespace ck {
...
@@ -57,20 +57,20 @@ template <typename SrcData,
          InMemoryDataOperationEnum_t DstInMemOp,
          index_t DstScalarStrideInVector,
          bool DstResetCoordinateAfterRun,
-         typename std::enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
-struct ThreadwiseDynamicTensorSliceTransfer_v1r3
+         typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
+struct ThreadwiseTensorSliceTransfer_v1r3
{
    static constexpr index_t nDim = SliceLengths::Size();
    using Index = MultiIndex<nDim>;

-   using DstCoord         = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));
-   using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{}));
+   using DstCoord     = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
+   using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

-   __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
-                                                                  const Index& dst_slice_origin_idx)
-       : dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx))
+   __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
+                                                           const Index& dst_slice_origin_idx)
+       : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx))
    {
        static_assert(SrcDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc need to known at compile-time");
...
@@ -78,19 +78,19 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
    {
-       dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
+       dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }

    template <typename SrcSliceOriginIdx,
              typename SrcBuffer,
              typename DstBuffer,
-             typename DstIteratorHacks>
+             typename DstStepHacks>
    __device__ void Run(const SrcDesc&,
                        const SrcSliceOriginIdx&,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf,
-                       const DstIteratorHacks& dst_iterator_hacks)
+                       const DstStepHacks& dst_step_hacks)
    {
        static_assert(SrcDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc need to known at compile-time");
...
@@ -127,31 +127,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

-       // make forward iterators
-       const auto dst_forward_iterators = generate_tuple(
+       // make forward steps
+       const auto dst_forward_steps = generate_tuple(
            [&](auto i) {
-               Index forward_step;
+               Index forward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
-                   forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
+                   forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
                });

-               return make_dynamic_tensor_coordinate_iterator(
-                   dst_desc, forward_step, dst_iterator_hacks[I0][i]);
+               return make_tensor_coordinate_step(
+                   dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
            },
            Number<nDim>{});

-       // make backward iterators
-       const auto dst_backward_iterators = generate_tuple(
+       // make backward steps
+       const auto dst_backward_steps = generate_tuple(
            [&](auto i) {
-               Index backward_step;
+               Index backward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
-                   backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
+                   backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
                });

-               return make_dynamic_tensor_coordinate_iterator(
-                   dst_desc, backward_step, dst_iterator_hacks[I1][i]);
+               return make_tensor_coordinate_step(
+                   dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
            },
            Number<nDim>{});
...
@@ -235,13 +235,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
                {
                    if constexpr(forward_sweep[i])
                    {
-                       move_dynamic_tensor_coordinate(
-                           dst_desc, dst_coord_, dst_forward_iterators[dim_access_order[i]]);
+                       move_tensor_coordinate(
+                           dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
                    }
                    else
                    {
-                       move_dynamic_tensor_coordinate(
-                           dst_desc, dst_coord_, dst_backward_iterators[dim_access_order[i]]);
+                       move_tensor_coordinate(
+                           dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
                    }
                }
            });
...
@@ -250,10 +250,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
        // move dst coordinate back to slice origin (or not)
        if constexpr(DstResetCoordinateAfterRun)
        {
-           const auto dst_reset_iterator =
-               make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
+           const auto dst_reset_step =
+               make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());

-           move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
+           move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
        }
    }
...
@@ -268,11 +268,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
        constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};

-       constexpr auto dst_iterator_hacks =
+       constexpr auto dst_step_hacks =
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

-       Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_iterator_hacks);
+       Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_step_hacks);
    }

    __device__ static constexpr auto GetDstCoordinateResetStep()
...
@@ -345,10 +345,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
                : dst_slice_origin_step_idx + GetDstCoordinateResetStep();

        // is it OK to construct a new step every time?
-       const auto adjusted_step =
-           make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
+       const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);

-       move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
+       move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
    }

    private:
...
@@ -374,20 +373,20 @@ template <typename SrcData,
          index_t SrcScalarPerVector,
          index_t SrcScalarStrideInVector,
          bool SrcResetCoordinateAfterRun,
-         typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
-struct ThreadwiseDynamicTensorSliceTransfer_v2
+         typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
+struct ThreadwiseTensorSliceTransfer_v2
{
    static constexpr index_t nDim = SliceLengths::Size();
    using Index = MultiIndex<nDim>;

-   using SrcCoord         = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
-   using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
+   using SrcCoord     = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
+   using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));

-   __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2(const SrcDesc& src_desc,
-                                                                const Index& src_slice_origin_idx)
-       : src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx))
+   __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
+                                                         const Index& src_slice_origin_idx)
+       : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx))
    {
        static_assert(DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc need to known at compile-time");
...
@@ -395,19 +394,19 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
    __device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {
-       src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
+       src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
    }

    template <typename SrcBuffer,
              typename DstBuffer,
              typename DstSliceOriginIdx,
-             typename SrcIteratorHacks>
+             typename SrcStepHacks>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstSliceOriginIdx&,
                        DstBuffer& dst_buf,
-                       const SrcIteratorHacks& src_iterator_hacks)
+                       const SrcStepHacks& src_step_hacks)
    {
        static_assert(DstDesc::IsKnownAtCompileTime(),
                      "wrong! DstDesc need to known at compile-time");
...
@@ -442,31 +441,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

-       // make forward iterators
-       const auto src_forward_iterators = generate_tuple(
+       // make forward steps
+       const auto src_forward_steps = generate_tuple(
            [&](auto i) {
-               Index forward_step;
+               Index forward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
-                   forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
+                   forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
                });

-               return make_dynamic_tensor_coordinate_iterator(
-                   src_desc, forward_step, src_iterator_hacks[I0][i]);
+               return make_tensor_coordinate_step(
+                   src_desc, forward_step_idx, src_step_hacks[I0][i]);
            },
            Number<nDim>{});

-       // make backward iterators
-       const auto src_backward_iterators = generate_tuple(
+       // make backward steps
+       const auto src_backward_steps = generate_tuple(
            [&](auto i) {
-               Index backward_step;
+               Index backward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
-                   backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
+                   backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
                });

-               return make_dynamic_tensor_coordinate_iterator(
-                   src_desc, backward_step, src_iterator_hacks[I1][i]);
+               return make_tensor_coordinate_step(
+                   src_desc, backward_step_idx, src_step_hacks[I1][i]);
            },
            Number<nDim>{});
...
@@ -548,13 +547,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
                {
                    if constexpr(forward_sweep[i])
                    {
-                       move_dynamic_tensor_coordinate(
-                           src_desc, src_coord_, src_forward_iterators[dim_access_order[i]]);
+                       move_tensor_coordinate(
+                           src_desc, src_coord_, src_forward_steps[dim_access_order[i]]);
                    }
                    else
                    {
-                       move_dynamic_tensor_coordinate(
-                           src_desc, src_coord_, src_backward_iterators[dim_access_order[i]]);
+                       move_tensor_coordinate(
+                           src_desc, src_coord_, src_backward_steps[dim_access_order[i]]);
                    }
                }
            });
...
@@ -563,10 +562,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
        // move src coordinate back to slice origin (or not)
        if constexpr(SrcResetCoordinateAfterRun)
        {
-           const auto src_reset_iterator =
-               make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
+           const auto src_reset_step =
+               make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());

-           move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
+           move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
        }
    }
...
@@ -581,11 +580,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
        constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};

-       constexpr auto src_iterator_hacks =
+       constexpr auto src_step_hacks =
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

-       Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks);
+       Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_step_hacks);
    }

    __device__ static constexpr auto GetSrcCoordinateResetStep()
...
@@ -658,10 +657,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
                : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
-       const auto adjusted_step =
-           make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
+       const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);

-       move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
+       move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    private:
...
@@ -693,23 +691,23 @@ template <typename SliceLengths,
          bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
                                           // RunWrite(), will be fused with MoveDstSliceWindow to
                                           // save addr computation
-struct ThreadwiseDynamicTensorSliceTransfer_v3
+struct ThreadwiseTensorSliceTransfer_v3
{
    static constexpr index_t nDim = SliceLengths::Size();
    using Index = MultiIndex<nDim>;

-   using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
-   using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));
+   using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
+   using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

-   using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
-   using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{}));
+   using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
+   using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

-   __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3(const SrcDesc& src_desc,
+   __device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc,
                                                          const Index& src_slice_origin,
                                                          const DstDesc& dst_desc,
                                                          const Index& dst_slice_origin)
-       : src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
-         dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
+       : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
+         dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
    {
        // TODO: fix this
        static_assert(is_same<SrcData, DstData>::value,
...
@@ -718,18 +716,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {
-       src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
+       src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
    }

    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
    {
-       dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
+       dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }

-   template <typename SrcBuffer, typename SrcIteratorHacks>
-   __device__ void RunRead(const SrcDesc& src_desc,
-                           const SrcBuffer& src_buf,
-                           const SrcIteratorHacks& src_iterator_hacks)
+   template <typename SrcBuffer, typename SrcStepHacks>
+   __device__ void
+   RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
    {
        static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
                          SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
...
@@ -757,31 +754,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        constexpr auto ordered_src_access_lengths =
            container_reorder_given_new2old(src_access_lengths, src_dim_access_order);

-       // make forward iterators
-       const auto src_forward_iterators = generate_tuple(
+       // make forward steps
+       const auto src_forward_steps = generate_tuple(
            [&](auto i) {
-               Index forward_step;
+               Index forward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
-                   forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
+                   forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
                });

-               return make_dynamic_tensor_coordinate_iterator(
-                   src_desc, forward_step, src_iterator_hacks[I0][i]);
+               return make_tensor_coordinate_step(
+                   src_desc, forward_step_idx, src_step_hacks[I0][i]);
            },
            Number<nDim>{});

-       // make backward iterators
-       const auto src_backward_iterators = generate_tuple(
+       // make backward steps
+       const auto src_backward_steps = generate_tuple(
            [&](auto i) {
-               Index backward_step;
+               Index backward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
-                   backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
+                   backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
                });

-               return make_dynamic_tensor_coordinate_iterator(
-                   src_desc, backward_step, src_iterator_hacks[I1][i]);
+               return make_tensor_coordinate_step(
+                   src_desc, backward_step_idx, src_step_hacks[I1][i]);
            },
            Number<nDim>{});
...
@@ -862,13 +859,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                {
                    if constexpr(forward_sweep[i])
                    {
-                       move_dynamic_tensor_coordinate(
-                           src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]);
+                       move_tensor_coordinate(
+                           src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
                    }
                    else
                    {
-                       move_dynamic_tensor_coordinate(
-                           src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]);
+                       move_tensor_coordinate(
+                           src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
                    }
                }
            });
...
@@ -877,17 +874,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        // move src coordinate back to slice origin (or not)
        if constexpr(SrcResetCoordinateAfterRun)
        {
-           const auto src_reset_iterator =
-               make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
+           const auto src_reset_step =
+               make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());

-           move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
+           move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
        }
    }

-   template <typename DstBuffer, typename DstIteratorHacks>
-   __device__ void RunWrite(const DstDesc& dst_desc,
-                            DstBuffer& dst_buf,
-                            const DstIteratorHacks& dst_iterator_hacks)
+   template <typename DstBuffer, typename DstStepHacks>
+   __device__ void
+   RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
    {
        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
                          DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
...
@@ -915,35 +911,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        constexpr auto ordered_dst_access_lengths =
            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);

-       // make forward iterators
-       const auto dst_forward_iterators = generate_tuple(
+       // make forward steps
+       const auto dst_forward_steps = generate_tuple(
            [&](auto i) {
-               Index forward_step;
+               Index forward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
-                   forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
+                   forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
                });

-               const auto forward_iterator = make_dynamic_tensor_coordinate_iterator(
-                   dst_desc, forward_step, dst_iterator_hacks[I0][i]);
-
-               return forward_iterator;
+               return make_tensor_coordinate_step(
+                   dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
            },
            Number<nDim>{});

-       // make backward iterators
-       const auto dst_backward_iterators = generate_tuple(
+       // make backward steps
+       const auto dst_backward_steps = generate_tuple(
            [&](auto i) {
-               Index backward_step;
+               Index backward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
-                   backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
+                   backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
                });

-               const auto backward_iterator = make_dynamic_tensor_coordinate_iterator(
-                   dst_desc, backward_step, dst_iterator_hacks[I1][i]);
-
-               return backward_iterator;
+               return make_tensor_coordinate_step(
+                   dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
            },
            Number<nDim>{});
...
@@ -1026,13 +1018,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                {
                    if constexpr(forward_sweep[i])
                    {
-                       move_dynamic_tensor_coordinate(
-                           dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]);
+                       move_tensor_coordinate(
+                           dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
                    }
                    else
                    {
-                       move_dynamic_tensor_coordinate(
-                           dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]);
+                       move_tensor_coordinate(
+                           dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
                    }
                }
            });
...
@@ -1041,10 +1033,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        // move dst coordinate back to slice origin (or not)
        if constexpr(DstResetCoordinateAfterRun)
        {
-           const auto dst_reset_iterator =
-               make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
+           const auto dst_reset_step =
+               make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());

-           move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
+           move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
        }
    }
...
@@ -1055,11 +1047,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};

-       constexpr auto src_iterator_hacks =
+       constexpr auto src_step_hacks =
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

-       RunRead(src_desc, src_buf, src_iterator_hacks);
+       RunRead(src_desc, src_buf, src_step_hacks);
    }

    template <typename DstBuffer>
...
@@ -1069,11 +1061,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};

-       constexpr auto dst_iterator_hacks =
+       constexpr auto dst_step_hacks =
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

-       RunWrite(dst_desc, dst_buf, dst_iterator_hacks);
+       RunWrite(dst_desc, dst_buf, dst_step_hacks);
    }

    __device__ static constexpr auto GetSrcCoordinateResetStep()
...
@@ -1206,18 +1198,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
-       const auto adjusted_step =
-           make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
+       const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);

-       move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
+       move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
-   template <typename SrcMoveSliceWindowIteratorHack>
+   template <typename SrcMoveSliceWindowStepHack>
    __device__ void
    MoveSrcSliceWindow(const SrcDesc& src_desc,
                       const Index& src_slice_origin_step_idx,
-                      const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
+                      const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
    {
        // if src coord was not reset by RunRead(), then need to adjust the step here
        const auto adjusted_step_idx =
...
@@ -1225,10 +1216,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
-       const auto adjusted_step = make_dynamic_tensor_coordinate_iterator(
-           src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack);
+       const auto adjusted_step = make_tensor_coordinate_step(
+           src_desc, adjusted_step_idx, src_move_slice_window_step_hack);

-       move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
+       move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
...
@@ -1240,19 +1231,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                : dst_slice_origin_step_idx + GetDstCoordinateResetStep();

        // is it OK to construct a new step every time?
-       const auto adjusted_step =
-           make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
+       const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);

-       move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
+       move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
    }

    private:
    static constexpr auto buffer_desc_ =
-       make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{}));
+       make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{}));

    static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();

-   StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
+   StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;

    SrcCoord src_coord_;
    DstCoord dst_coord_;
...
@@ -1264,37 +1254,36 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
//   2. SrcBuffer is DynamicBuffer
//   3. src_ref_idx is known at run-time
//   4. SrcRefToOriginDisplacement is known at compile-time
-//   5. use #-iterator
+//   5. use #-step
// 2. dst:
//   1. DstDesc is known at compile-time
//   2. DstBuffer is StaticBuffer
//   3. DstOriginIdx is known at compile-time
//   4. use direct address calculation
// 3. vector access on src
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t SrcVectorDim,
          index_t SrcScalarPerVector,
          index_t SrcScalarStrideInVector,
-         typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
-                                 bool>::type = false>
-struct ThreadwiseDynamicTensorSliceTransfer_v4
+         typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                            bool>::type = false>
+struct ThreadwiseTensorSliceTransfer_v4
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

-   using SrcCoord         = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
-   using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
+   using SrcCoord     = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
+   using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));

-   __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4(const Index& src_ref_idx)
-       : src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx))
+   __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
+       : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to known at compile-time");
...
@@ -1390,13 +1379,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
            constexpr auto src_ref_to_data_disp_idx =
                src_ref_to_origin_disp_idx + data_to_origin_disp_idx;

-           constexpr auto src_ref_to_data_disp_coord_iterator =
-               make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx);
+           constexpr auto src_ref_to_data_disp_coord_step =
+               make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);

            auto src_data_coord = src_ref_coord_;

-           move_dynamic_tensor_coordinate(
-               src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
+           move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);

            vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
...
@@ -1435,10 +1423,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
    {
        constexpr auto src_desc = SrcDesc{};

-       const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator(
-           src_desc, to_multi_index(src_slice_move_step_idx));
+       const auto src_slice_move_step_iter =
+           make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));

-       move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
+       move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
    }

    private:
...
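All of the transfer classes above build one forward and one backward step per dimension (step_idx(j) = ±scalar_per_access[i] when i == j, otherwise 0) and then sweep the slice in a zig-zag order, so that moving to the next position is always a single per-dimension step instead of a recomputed coordinate. A small host-side analogue of that zig-zag (forward_sweep) traversal over a 2-D slice is sketched below; the plain int indices are stand-ins for the library's compile-time coordinates and steps.

// Host-side analogue of the per-dimension forward/backward steps and zig-zag sweep
// used by the transfer classes above (illustrative only).
#include <array>
#include <cstdio>

int main()
{
    constexpr int L0 = 3, L1 = 4;                  // access lengths of a 2-D slice
    std::array<int, 2> idx{0, 0};                  // current position inside the slice

    // per-dimension steps: exactly one non-zero component each, as in the diff
    const std::array<int, 2> fwd1{0, +1}, bwd1{0, -1}, fwd0{+1, 0};

    for(int i0 = 0; i0 < L0; ++i0)
    {
        const bool forward_sweep1 = (i0 % 2 == 0); // analogue of forward_sweep[i]
        for(int i1 = 0; i1 < L1; ++i1)
        {
            std::printf("visit (%d, %d)\n", idx[0], idx[1]);
            if(i1 + 1 < L1)                        // move along dim 1; direction alternates
            {
                const auto& s = forward_sweep1 ? fwd1 : bwd1;
                idx[0] += s[0];
                idx[1] += s[1];
            }
        }
        if(i0 + 1 < L0)                            // move along dim 0 once per row
        {
            idx[0] += fwd0[0];
            idx[1] += fwd0[1];
        }
    }
    return 0;
}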
composable_kernel/include/tensor_operation/threadwise_
dynamic_
tensor_slice_transfer_v2.hpp
→
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
View file @
ccc4a1d3
-#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
-#define CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
+#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP
+#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#include "common_header.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
namespace ck {
...
@@ -30,7 +30,7 @@ template <typename SliceLengths,
          bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
                                           // RunWrite(), will be fused with MoveDstSliceWindow to
                                           // save addr computation
-struct ThreadwiseDynamicTensorSliceTransfer_v3r1
+struct ThreadwiseTensorSliceTransfer_v3r1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
...
@@ -38,18 +38,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
    static constexpr index_t nDim = SliceLengths::Size();
    using Index = MultiIndex<nDim>;
-    using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
-    using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));
-    using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
-    using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{}));
+    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
+    using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
+    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
+    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
-    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
+    __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
                                                            const Index& src_slice_origin,
                                                            const DstDesc& dst_desc,
                                                            const Index& dst_slice_origin)
-        : src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
-          dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
+        : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
+          dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
    {
        // TODO: fix this
        static_assert(is_same<SrcData, DstData>::value,
...
@@ -64,18 +64,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {
-        src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
+        src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
    }
    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
    {
-        dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
+        dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }
-    template <typename SrcBuffer, typename SrcIteratorHacks>
-    __device__ void RunRead(const SrcDesc& src_desc,
-                            const SrcBuffer& src_buf,
-                            const SrcIteratorHacks& src_iterator_hacks)
+    template <typename SrcBuffer, typename SrcStepHacks>
+    __device__ void
+    RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
    {
        static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
                          SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
...
@@ -96,9 +95,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
                I1),
            SrcVectorTensorContiguousDimOrder{});
-        constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2(
-            sequence_to_tuple_of_number(src_vector_tensor_lengths),
+        constexpr auto src_vector_desc =
+            make_naive_tensor_descriptor_v2(sequence_to_tuple_of_number(src_vector_tensor_lengths),
            sequence_to_tuple_of_number(src_vector_tensor_strides));
        // access order and lengths
        constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths;
...
@@ -108,31 +107,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
        constexpr auto ordered_src_access_lengths =
            container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
-        // make forward iterators
-        const auto src_forward_iterators = generate_tuple(
+        // make forward steps
+        const auto src_forward_steps = generate_tuple(
            [&](auto i) {
-                Index forward_step;
+                Index forward_step_idx;
                static_for<0, nDim, 1>{}([&](auto j) {
-                    forward_step(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
+                    forward_step_idx(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
                });
-                return make_dynamic_tensor_coordinate_iterator(
-                    src_desc, forward_step, src_iterator_hacks[I0][i]);
+                return make_tensor_coordinate_step(
+                    src_desc, forward_step_idx, src_step_hacks[I0][i]);
            },
            Number<nDim>{});
-        // make backward iterators
-        const auto src_backward_iterators = generate_tuple(
+        // make backward steps
+        const auto src_backward_steps = generate_tuple(
            [&](auto i) {
-                Index backward_step;
+                Index backward_step_idx;
                static_for<0, nDim, 1>{}([&](auto j) {
-                    backward_step(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
+                    backward_step_idx(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
                });
-                return make_dynamic_tensor_coordinate_iterator(
-                    src_desc, backward_step, src_iterator_hacks[I1][i]);
+                return make_tensor_coordinate_step(
+                    src_desc, backward_step_idx, src_step_hacks[I1][i]);
            },
            Number<nDim>{});
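The forward/backward step construction above follows a simple pattern: for access dimension i, the step multi-index is zero everywhere except at position i, where it equals the vector length of that dimension, and the backward variant negates it (used when the access order sweeps back, as the forward_sweep check in a later hunk shows). A small standalone sketch of that pattern follows; the array-based Idx type and the length values are illustrative, not CK types.

#include <array>
#include <cstdio>

// Illustrative stand-in for CK's MultiIndex; not the real type.
template <int NDim>
using Idx = std::array<int, NDim>;

template <int NDim>
Idx<NDim> make_forward_step_idx(int dim, const Idx<NDim>& vector_lengths)
{
    Idx<NDim> step{};                // zero-initialized
    step[dim] = vector_lengths[dim]; // only the swept dimension moves
    return step;
}

template <int NDim>
Idx<NDim> make_backward_step_idx(int dim, const Idx<NDim>& vector_lengths)
{
    Idx<NDim> step{};
    step[dim] = -vector_lengths[dim]; // backward sweep negates the move
    return step;
}

int main()
{
    constexpr int nDim = 3;
    Idx<nDim> vector_lengths{1, 2, 4};

    for(int d = 0; d < nDim; ++d)
    {
        auto f = make_forward_step_idx<nDim>(d, vector_lengths);
        auto b = make_backward_step_idx<nDim>(d, vector_lengths);
        std::printf("dim %d: forward {%d,%d,%d}  backward {%d,%d,%d}\n",
                    d, f[0], f[1], f[2], b[0], b[1], b[2]);
    }
}

In the real code each of these per-dimension multi-indices is then turned into a descriptor-aware step via make_tensor_coordinate_step, so the expensive offset bookkeeping happens once per dimension rather than once per access.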
...
@@ -219,13 +218,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
                {
                    if constexpr(forward_sweep[i])
                    {
-                        move_dynamic_tensor_coordinate(
-                            src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]);
+                        move_tensor_coordinate(
+                            src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
                    }
                    else
                    {
-                        move_dynamic_tensor_coordinate(
-                            src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]);
+                        move_tensor_coordinate(
+                            src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
                    }
                }
            });
...
@@ -234,17 +233,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
        // move src coordinate back to slice origin (or not)
        if constexpr(SrcResetCoordinateAfterRun)
        {
-            const auto src_reset_iterator =
-                make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
-            move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
+            const auto src_reset_step =
+                make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
+            move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
        }
    }
-    template <typename DstBuffer, typename DstIteratorHacks>
-    __device__ void RunWrite(const DstDesc& dst_desc,
-                             DstBuffer& dst_buf,
-                             const DstIteratorHacks& dst_iterator_hacks)
+    template <typename DstBuffer, typename DstStepHacks>
+    __device__ void
+    RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
    {
        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
                          DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
...
@@ -265,9 +263,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
                I1),
            DstVectorTensorContiguousDimOrder{});
-        constexpr auto dst_vector_desc = make_dynamic_naive_tensor_descriptor_v2(
-            sequence_to_tuple_of_number(dst_vector_tensor_lengths),
+        constexpr auto dst_vector_desc =
+            make_naive_tensor_descriptor_v2(sequence_to_tuple_of_number(dst_vector_tensor_lengths),
            sequence_to_tuple_of_number(dst_vector_tensor_strides));
        // dst access order and lengths
        constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths;
...
@@ -277,35 +275,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
        constexpr auto ordered_dst_access_lengths =
            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
-        // make forward iterators
-        const auto dst_forward_iterators = generate_tuple(
+        // make forward steps
+        const auto dst_forward_steps = generate_tuple(
            [&](auto i) {
-                Index forward_step;
+                Index forward_step_idx;
                static_for<0, nDim, 1>{}([&](auto j) {
-                    forward_step(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
+                    forward_step_idx(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
                });
-                const auto forward_iterator = make_dynamic_tensor_coordinate_iterator(
-                    dst_desc, forward_step, dst_iterator_hacks[I0][i]);
-                return forward_iterator;
+                return make_tensor_coordinate_step(
+                    dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
            },
            Number<nDim>{});
-        // make backward iterators
-        const auto dst_backward_iterators = generate_tuple(
+        // make backward steps
+        const auto dst_backward_steps = generate_tuple(
            [&](auto i) {
-                Index backward_step;
+                Index backward_step_idx;
                static_for<0, nDim, 1>{}([&](auto j) {
-                    backward_step(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
+                    backward_step_idx(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
                });
-                const auto backward_iterator = make_dynamic_tensor_coordinate_iterator(
-                    dst_desc, backward_step, dst_iterator_hacks[I1][i]);
-                return backward_iterator;
+                return make_tensor_coordinate_step(
+                    dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
            },
            Number<nDim>{});
...
@@ -394,13 +388,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
                {
                    if constexpr(forward_sweep[i])
                    {
-                        move_dynamic_tensor_coordinate(
-                            dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]);
+                        move_tensor_coordinate(
+                            dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
                    }
                    else
                    {
-                        move_dynamic_tensor_coordinate(
-                            dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]);
+                        move_tensor_coordinate(
+                            dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
                    }
                }
            });
...
@@ -409,10 +403,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
        // move dst coordinate back to slice origin (or not)
        if constexpr(DstResetCoordinateAfterRun)
        {
-            const auto dst_reset_iterator =
-                make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
-            move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
+            const auto dst_reset_step =
+                make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
+            move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
        }
    }
...
@@ -423,11 +417,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
        constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
-        constexpr auto src_iterator_hacks =
+        constexpr auto src_step_hacks =
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
-        RunRead(src_desc, src_buf, src_iterator_hacks);
+        RunRead(src_desc, src_buf, src_step_hacks);
    }
    template <typename DstBuffer>
...
@@ -437,11 +431,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
        constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
-        constexpr auto dst_iterator_hacks =
+        constexpr auto dst_step_hacks =
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
-        RunWrite(dst_desc, dst_buf, dst_iterator_hacks);
+        RunWrite(dst_desc, dst_buf, dst_step_hacks);
    }
    __device__ static constexpr auto GetSrcCoordinateResetStep()
...
@@ -564,18 +558,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
                : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
        // is it OK to construct a new step every time?
-        const auto adjusted_step =
-            make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
-        move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
+        const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
+        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }
    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
-    template <typename SrcMoveSliceWindowIteratorHack>
+    template <typename SrcMoveSliceWindowStepHack>
    __device__ void
    MoveSrcSliceWindow(const SrcDesc& src_desc,
                       const Index& src_slice_origin_step_idx,
-                       const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
+                       const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
    {
        // if src coord was not reset by RunRead(), then need to adjust the step here
        const auto adjusted_step_idx =
...
@@ -583,10 +576,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
                : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_dynamic_tensor_coordinate_iterator(
-            src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack);
-        move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
+        const auto adjusted_step = make_tensor_coordinate_step(
+            src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
+        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }
    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
...
@@ -598,19 +591,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
                : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
        // is it OK to construct a new step every time?
-        const auto adjusted_step =
-            make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
-        move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
+        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
+        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
    }
    private:
    static constexpr auto buffer_desc_ =
-        make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{}));
+        make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{}));
    static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
-    StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
+    StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
    SrcCoord src_coord_;
    DstCoord dst_coord_;
...
@@ -622,25 +614,24 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1
//     2. SrcBuffer is DynamicBuffer
//     3. src_ref_idx is known at run-time
//     4. SrcRefToOriginDisplacement is known at compile-time
-//     5. use #-iterator
+//     5. use #-step
//   2. dst:
//     1. DstDesc is known at compile-time
//     2. DstBuffer is StaticBuffer
//     3. DstOriginIdx is known at compile-time
//     4. use direct address calculation
//   3. vector access on src
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SliceLengths,
          typename DimAccessOrder,
          typename SrcVectorTensorLengths,
          typename SrcVectorTensorContiguousDimOrder,
-          typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
-                                  bool>::type = false>
-struct ThreadwiseDynamicTensorSliceTransfer_v4r1
+          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                             bool>::type = false>
+struct ThreadwiseTensorSliceTransfer_v4r1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
...
@@ -649,12 +640,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1
    using Index = MultiIndex<nDim>;
-    using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
-    using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
+    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
+    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
-    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4r1(const Index& src_ref_idx)
-        : src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx))
+    __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx)
+        : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to known at compile-time");
...
@@ -712,9 +703,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1
                I1),
            SrcVectorTensorContiguousDimOrder{});
-        constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2(
-            sequence_to_tuple_of_number(src_vector_tensor_lengths),
+        constexpr auto src_vector_desc =
+            make_naive_tensor_descriptor_v2(sequence_to_tuple_of_number(src_vector_tensor_lengths),
            sequence_to_tuple_of_number(src_vector_tensor_strides));
        // access order and lengths
        constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths;
...
@@ -734,13 +725,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1
            constexpr auto src_ref_to_data_disp_idx =
                src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
-            constexpr auto src_ref_to_data_disp_coord_iterator =
-                make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx);
+            constexpr auto src_ref_to_data_disp_coord_step =
+                make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
            auto src_data_coord = src_ref_coord_;
-            move_dynamic_tensor_coordinate(
-                src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
+            move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
            vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
...
@@ -775,10 +765,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1
        {
            constexpr auto src_desc = SrcDesc{};
-            const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator(
-                src_desc, to_multi_index(src_slice_move_step_idx));
-            move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
+            const auto src_slice_move_step_iter =
+                make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));
+            move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
        }
    private:
...
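One detail worth keeping in mind from the MoveSrcSliceWindow / MoveDstSliceWindow hunks above: when SrcResetCoordinateAfterRun (or its dst counterpart) is false, RunRead()/RunWrite() leave the coordinate displaced by their sweep, so the window move folds the reset step into the requested step (adjusted_step_idx = step_idx + GetSrcCoordinateResetStep()). A tiny integer sketch of that bookkeeping follows; the numbers are made up purely for illustration.

#include <cstdio>

int main()
{
    // Hypothetical 1-D example: RunRead() swept the coordinate forward by 7 elements.
    const int sweep_displacement    = 7;
    const int reset_step            = -sweep_displacement; // what the reset step would undo here
    const int requested_window_move = 16;                  // src_slice_origin_step_idx

    const bool reset_after_run = false; // SrcResetCoordinateAfterRun

    // If the coordinate was not reset by RunRead(), fold the reset into the move.
    const int adjusted_step = reset_after_run ? requested_window_move
                                              : requested_window_move + reset_step;

    std::printf("coordinate is moved by %d instead of %d\n", adjusted_step, requested_window_move);
}

Fusing the reset into the slice-window move is what the comment at the top of the struct refers to when it says the reset "will be fused with MoveDstSliceWindow to save addr computation".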
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @ ccc4a1d3
...
@@ -350,8 +350,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
              class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
-        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-        const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
        return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
            p_a, p_b, reg_c);
...
@@ -384,8 +384,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
              class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
-        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-        const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
        return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
    }
...
@@ -417,8 +417,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
              class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
-        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-        const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
        return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
    }
...
@@ -450,8 +450,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
              class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
-        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-        const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
        return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
    }
...
@@ -483,8 +483,8 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
              class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
-        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-        const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
        return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
    }
...
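The xdlops_gemm.hpp hunks swap reinterpret_cast for a c_style_pointer_cast helper whose definition is not part of this page. The sketch below is only a guess at its general shape (a template that routes the conversion through a plain C-style cast), not the repository's actual implementation, and in device code it would presumably carry __host__ __device__ qualifiers.

#include <cstdio>

// Sketch only: hypothetical shape of a C-style pointer cast helper.
template <typename PY, typename PX>
PY c_style_pointer_cast(PX p_x)
{
    return (PY)p_x; // deliberately a C-style cast, as the name suggests
}

int main()
{
    float x = 1.0f;
    // Usage pattern similar to the ushort2_t reinterpretation above.
    const auto* bytes = c_style_pointer_cast<const unsigned char*>(&x);
    std::printf("first byte of 1.0f: %u\n", static_cast<unsigned>(bytes[0]));
}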