gaoqiong / composable_kernel / Commits / ccc4a1d3

Unverified commit ccc4a1d3, authored Aug 16, 2021 by Chao Liu, committed by GitHub on Aug 16, 2021

Merge pull request #8 from ROCmSoftwarePlatform/miopen_downstream_init_integration

Parents: 3b866461, 16effa76
Changes: 144 — showing 20 changed files with 1046 additions and 1159 deletions (+1046, −1159)
composable_kernel/include/tensor_description/tensor_adaptor.hpp  +3 −5
composable_kernel/include/tensor_description/tensor_descriptor.hpp  +53 −52
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp  +22 −23
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp  +45 −47
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp  +10 −10
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp  +13 −18
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp  +55 −59
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp  +36 −37
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp  +34 −36
composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp  +58 −63
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp  +119 −136
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp  +76 −97
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp  +90 −95
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp  +127 −151
composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp  +12 −13
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp  +3 −5
composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp  +8 −8
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp  +165 −177
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp  +107 −117
composable_kernel/include/tensor_operation/xdlops_gemm.hpp  +10 −10
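Nearly all of the hunks below implement one mechanical change: the dynamic_ / Dynamic prefix is dropped from the tensor-descriptor and slice-transfer APIs, and the coordinate "iterator" becomes a coordinate "step". A non-exhaustive old → new mapping, reconstructed only from the hunks shown on this page:

DynamicTensorDescriptor                              -> TensorDescriptor
DynamicTensorCoordinate                              -> TensorCoordinate
DynamicTensorCoordinateIterator                      -> TensorCoordinateStep
transform_dynamic_tensor_descriptor                  -> transform_tensor_descriptor
make_dynamic_tensor_coordinate                       -> make_tensor_coordinate
make_dynamic_tensor_coordinate_iterator              -> make_tensor_coordinate_step
move_dynamic_tensor_coordinate                       -> move_tensor_coordinate
make_dynamic_naive_tensor_descriptor_v2              -> make_naive_tensor_descriptor_v2
make_dynamic_naive_tensor_descriptor_packed_v2       -> make_naive_tensor_descriptor_packed
make_dynamic_naive_tensor_descriptor_aligned_v2      -> make_naive_tensor_descriptor_aligned_v2
ThreadwiseDynamicTensorSliceTransfer_v3/v3r1/v4/v4r1 -> ThreadwiseTensorSliceTransfer_v3/v3r1/v4/v4r1
BlockwiseDynamicTensorSliceTransfer_v4/v4r1          -> BlockwiseTensorSliceTransfer_v4/v4r1
std::enable_if (in template constraints)             -> enable_if
*IteratorHack / *_iterator_hack parameters           -> *StepHack / *_step_hack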
composable_kernel/include/tensor_description/tensor_adaptor.hpp

@@ -2,8 +2,8 @@
 #define CK_TENSOR_ADAPTOR_HPP

 #include "common_header.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"

 namespace ck {

@@ -454,9 +454,7 @@ __host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transf
         remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
 }

 template <typename X,
           typename... Xs,
-          typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
+          typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
 __host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
 {
     return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp → composable_kernel/include/tensor_description/tensor_descriptor.hpp

-#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP
-#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP
+#ifndef CK_TENSOR_DESCRIPTOR_HPP
+#define CK_TENSOR_DESCRIPTOR_HPP

 #include "common_header.hpp"
-#include "dynamic_multi_index_transform.hpp"
+#include "multi_index_transform.hpp"

 namespace ck {

 template <index_t NDimHidden, typename VisibleDimensionIds>
-struct DynamicTensorCoordinate;
+struct TensorCoordinate;

 template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
-struct DynamicTensorCoordinateIterator;
+struct TensorCoordinateStep;

 // Transforms: Tuple<transforms...>
 // LowerDimensionIdss : Tuple<Sequence<...>, ...>

@@ -21,7 +21,7 @@ template <typename Transforms,
           typename UpperDimensionIdss,
           typename VisibleDimensionIds,
           typename ElementSpaceSize>
-struct DynamicTensorDescriptor
+struct TensorDescriptor
 {
     // TODO make these private
     __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }

@@ -105,15 +105,15 @@ struct DynamicTensorDescriptor
     using VisibleIndex = MultiIndex<ndim_visible_>;
     using HiddenIndex  = MultiIndex<ndim_hidden_>;
-    using Coordinate   = DynamicTensorCoordinate<ndim_hidden_, VisibleDimensionIds>;
+    using Coordinate   = TensorCoordinate<ndim_hidden_, VisibleDimensionIds>;

     // may be index_t or Number<>
     using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;

     public:
-    __host__ __device__ constexpr DynamicTensorDescriptor() = default;
+    __host__ __device__ constexpr TensorDescriptor() = default;

-    __host__ __device__ constexpr DynamicTensorDescriptor(const Transforms& transforms,
-                                                          ElementSpaceSize element_space_size)
+    __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
+                                                   ElementSpaceSize element_space_size)
         : transforms_{transforms},
           element_size_{InitializeElementSize(transforms)},

@@ -159,7 +159,7 @@ struct DynamicTensorDescriptor
     {
         static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension");

-        return make_dynamic_tensor_coordinate(*this, idx).GetOffset();
+        return make_tensor_coordinate(*this, idx).GetOffset();
     }

     // TODO make these private

@@ -196,7 +196,7 @@ struct DynamicTensorDescriptor
     __host__ __device__ void Print() const
     {
         printf("{");
-        printf("DynamicTensorDescriptor, ");
+        printf("TensorDescriptor, ");
         static_for<0, ntransform_, 1>{}([&](auto i) {
             printf("transforms: ");
             transforms_[i].Print();

@@ -217,7 +217,7 @@ struct DynamicTensorDescriptor
 };

 template <index_t NDimHidden, typename VisibleDimensionIds>
-struct DynamicTensorCoordinate
+struct TensorCoordinate
 {
     // TODO make these private
     static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size();

@@ -226,9 +226,9 @@ struct DynamicTensorCoordinate
     using VisibleIndex = MultiIndex<ndim_visible_>;

     public:
-    __host__ __device__ constexpr DynamicTensorCoordinate() = default;
+    __host__ __device__ constexpr TensorCoordinate() = default;

-    __host__ __device__ constexpr DynamicTensorCoordinate(const HiddenIndex& idx_hidden)
+    __host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden)
         : idx_hidden_{idx_hidden}
     {
     }

@@ -252,16 +252,16 @@ struct DynamicTensorCoordinate
 };

 template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
-struct DynamicTensorCoordinateIterator
+struct TensorCoordinateStep
 {
     // TODO make these private
     using VisibleIndex = MultiIndex<NDimVisible>;

     public:
-    __host__ __device__ constexpr DynamicTensorCoordinateIterator() = default;
+    __host__ __device__ constexpr TensorCoordinateStep() = default;

-    __host__ __device__ constexpr DynamicTensorCoordinateIterator(
+    __host__ __device__ constexpr TensorCoordinateStep(
         const VisibleIndex& idx_diff_visible,
         const MultiIndex<NTransform>& do_transforms)
         : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
     {
     }

@@ -283,7 +283,7 @@ struct DynamicTensorCoordinateIterator
 // TODO: How to fix this? It uses an struct instead of lambda because lambda
 // doesn't have constructor, and to put it outside the scope where it is used
-// (transform_dynamic_tensor_descriptor) because template cannot be defined inside a function
+// (transform_tensor_descriptor) because template cannot be defined inside a function
 // template
 template <typename NewTransforms>
 struct lambda_get_up_dim_num

@@ -301,7 +301,7 @@ template <typename OldTensorDescriptor,
           typename NewLowerDimensionOldVisibleIdss,
           typename NewUpperDimensionNewVisibleIdss>
 __host__ __device__ constexpr auto
-transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
+transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
                             const NewTransforms& new_transforms,
                             NewLowerDimensionOldVisibleIdss,
                             NewUpperDimensionNewVisibleIdss)

@@ -376,7 +376,7 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
     const auto element_space_size = old_tensor_desc.GetElementSpaceSize();

-    return DynamicTensorDescriptor<remove_cv_t<decltype(all_transforms)>,
+    return TensorDescriptor<remove_cv_t<decltype(all_transforms)>,
                             remove_cv_t<decltype(all_low_dim_hidden_idss)>,
                             remove_cv_t<decltype(all_up_dim_hidden_idss)>,
                             remove_cv_t<decltype(new_visible_dim_hidden_ids)>,

@@ -385,7 +385,7 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
 }

 template <typename TensorDesc, typename VisibleIndex>
-__host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDesc& tensor_desc,
+__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
                                                            const VisibleIndex& idx_visible)
 {
     static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),

@@ -416,14 +416,15 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe
         set_container_subset(idx_hidden, dims_low, idx_low);
     });

-    return DynamicTensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
+    return TensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
 }

 // UpdateLowerIndexHack: Sequence<...>
 // HACK: control UpdateLowerIndex
 template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
-__host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator(
+__host__ __device__ constexpr auto make_tensor_coordinate_step(
     const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack)
 {
     static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
                   "wrong! # of dimension inconsistent");

@@ -470,23 +471,24 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator(
         set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
     });

-    return DynamicTensorCoordinateIterator<ntransform, ndim_visible, UpdateLowerIndexHack>{
+    return TensorCoordinateStep<ntransform, ndim_visible, UpdateLowerIndexHack>{
         idx_diff_visible,
         do_transforms};
 }

 template <typename TensorDesc, typename VisibleIndex>
 __host__ __device__ constexpr auto
-make_dynamic_tensor_coordinate_iterator(const TensorDesc&, const VisibleIndex& idx_diff_visible)
+make_tensor_coordinate_step(const TensorDesc&, const VisibleIndex& idx_diff_visible)
 {
     constexpr index_t ntransform = TensorDesc::GetNumOfTransform();

-    return make_dynamic_tensor_coordinate_iterator(
+    return make_tensor_coordinate_step(
         TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
 }

-template <typename TensorDesc, typename TensorCoord, typename TensorCoordIterator>
-__host__ __device__ constexpr void move_dynamic_tensor_coordinate(
-    const TensorDesc& tensor_desc, TensorCoord& coord, const TensorCoordIterator& coord_iterator)
+template <typename TensorDesc, typename TensorCoord, typename TensorCoordStep>
+__host__ __device__ constexpr void move_tensor_coordinate(
+    const TensorDesc& tensor_desc, TensorCoord& coord, const TensorCoordStep& coord_step)
 {
     constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
     constexpr index_t ntransform  = TensorDesc::GetNumOfTransform();

@@ -495,9 +497,8 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
     auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();

     // initialize visible index diff
-    set_container_subset(idx_diff_hidden,
-                         TensorDesc::GetVisibleDimensionIds(),
-                         coord_iterator.GetVisibleIndexDiff());
+    set_container_subset(
+        idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());

     // this is what needs to be updated
     auto& idx_hidden = coord.GetHiddenIndex();

@@ -506,13 +507,13 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
     auto idx_hidden_pick_visible =
         get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());

-    idx_hidden_pick_visible += coord_iterator.GetIndexDiff();
+    idx_hidden_pick_visible += coord_step.GetIndexDiff();

     set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);

     // update rest of hidden index
     static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
-        if(coord_iterator.do_transforms_[itran])
+        if(coord_step.do_transforms_[itran])
         {
             const auto& tran        = tensor_desc.GetTransforms().At(itran);
             constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);

@@ -524,8 +525,8 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
             MultiIndex<dims_low.Size()> idx_diff_low;

-            // HACK: control UpdateLowerIndex for DynamicMerge using hack
-            constexpr index_t Hack = decltype(coord_iterator.update_lower_index_hack_)::At(itran);
+            // HACK: control UpdateLowerIndex for Merge using hack
+            constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran);

             tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});

@@ -585,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
 }

 template <typename TensorDesc>
-using DynamicTensorCoordinate_t = decltype(make_dynamic_tensor_coordinate(
+using TensorCoordinate_t = decltype(make_tensor_coordinate(
     TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));

 template <typename TensorDesc>
-using DynamicTensorCoordinateIterator_t = decltype(make_dynamic_tensor_coordinate_iterator(
+using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
     TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));

 } // namespace ck
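For orientation, here is a minimal usage sketch of the coordinate API as renamed above (make_tensor_coordinate, make_tensor_coordinate_step, move_tensor_coordinate, plus make_naive_tensor_descriptor_packed from the helper header). It only restates what the signatures in this file imply; the 4x8 lengths, the make_multi_index helper, and the surrounding device function are illustrative assumptions, not part of this diff.

#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

using namespace ck;

__device__ void copy_row(const float* p_src, float* p_dst)
{
    // Packed 4x8 descriptor; lengths are compile-time Number<>s (illustrative sizes).
    constexpr auto desc =
        make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}));

    // Coordinate at [1, 0] and a precomputed step of [0, 1] along the last dimension.
    auto coord        = make_tensor_coordinate(desc, make_multi_index(1, 0));
    const auto step_w = make_tensor_coordinate_step(desc, make_multi_index(0, 1));

    // Walk the row, advancing the coordinate instead of recomputing the offset each time.
    for(index_t w = 0; w < 8; ++w)
    {
        p_dst[w] = p_src[coord.GetOffset()];
        move_tensor_coordinate(desc, coord, step_w);
    }
}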
composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp → composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp

-#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP
-#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP
+#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
+#define CK_TENSOR_DESCRIPTOR_HELPER_HPP

 #include "common_header.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_multi_index_transform_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "multi_index_transform_helper.hpp"

 namespace ck {

@@ -37,9 +37,8 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
 template <typename... Lengths,
           typename... Strides,
-          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
+          typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
 __host__ __device__ constexpr auto
-make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
-                                        const Tuple<Strides...>& strides)
+make_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
                                 const Tuple<Strides...>& strides)
 {
     constexpr index_t N = sizeof...(Lengths);

@@ -75,7 +74,7 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
         calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
 #endif

-    return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>,
+    return TensorDescriptor<remove_cv_t<decltype(transforms)>,
                             remove_cv_t<decltype(low_dim_hidden_idss)>,
                             remove_cv_t<decltype(up_dim_hidden_idss)>,
                             remove_cv_t<decltype(visible_dim_hidden_ids)>,

@@ -88,7 +87,7 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
 // 2) Number<>, which is known at compile-time
 template <typename... Lengths>
 __host__ __device__ constexpr auto
-make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)
+make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
 {
     constexpr index_t N = sizeof...(Lengths);

@@ -103,7 +102,7 @@ make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)
     const auto element_space_size = container_reduce(lengths, math::multiplies_v2{}, Number<1>{});

-    return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>,
+    return TensorDescriptor<remove_cv_t<decltype(transforms)>,
                             remove_cv_t<decltype(low_dim_hidden_idss)>,
                             remove_cv_t<decltype(up_dim_hidden_idss)>,
                             remove_cv_t<decltype(visible_dim_hidden_ids)>,

@@ -113,7 +112,7 @@ make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)
 template <typename... Lengths, typename Align>
 __host__ __device__ constexpr auto
-make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths, Align align)
+make_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths, Align align)
 {
     constexpr auto I1 = Number<1>{};

@@ -143,7 +142,7 @@ make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths
         },
         Number<N>{});

-    return make_dynamic_naive_tensor_descriptor_v2(lengths, strides);
+    return make_naive_tensor_descriptor_v2(lengths, strides);
 }

 } // namespace ck
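As a quick illustration of the renamed factory helpers above, the three variants can be called like this (a sketch; the 2x3 lengths, the stride of 4, and the alignment of 4 are made-up values):

using namespace ck;

// Packed (contiguous) 2x3 descriptor: strides are derived internally.
constexpr auto desc_packed =
    make_naive_tensor_descriptor_packed(make_tuple(Number<2>{}, Number<3>{}));

// Explicit strides: row stride 4, column stride 1 (a padded row layout).
constexpr auto desc_strided = make_naive_tensor_descriptor_v2(
    make_tuple(Number<2>{}, Number<3>{}), make_tuple(Number<4>{}, Number<1>{}));

// Aligned variant: lengths plus an Align value, as in the signature above.
constexpr auto desc_aligned =
    make_naive_tensor_descriptor_aligned_v2(make_tuple(Number<2>{}, Number<3>{}), Number<4>{});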
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp

@@ -3,7 +3,7 @@
 #include "common_header.hpp"
 #include "tensor_adaptor.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
 #include "threadwise_contraction_dlops.hpp"

 namespace ck {

@@ -22,7 +22,8 @@ namespace ck {
 // 2. CThreadBuffer is StaticBuffer
 // Also assume:
 //   M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
 template <index_t BlockSize,
           typename FloatA,
           typename FloatB,
           typename FloatC,

@@ -37,8 +38,7 @@ template <index_t BlockSize,
           index_t M1N1ThreadClusterN101,
           index_t AThreadCopyScalarPerVector_M11,
           index_t BThreadCopyScalarPerVector_N11,
-          typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() &&
+          typename enable_if<AKMBlockDesc::IsKnownAtCompileTime() &&
                                  BKNBlockDesc::IsKnownAtCompileTime(),
                              bool>::type = false>
 struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
 {

@@ -71,9 +71,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
     static constexpr index_t N0 = N / N1;

     __host__ __device__ static constexpr auto
-    MakeAKM0M1BlockDescriptor(const AKMBlockDesc& a_k_m_block_desc)
+    MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */)
     {
-        const auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
+        const auto a_k_m0_m1_block_desc = transform_tensor_descriptor(
             AKMBlockDesc{},
             make_tuple(make_pass_through_transform(Number<K>{}),
                        make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),

@@ -84,9 +84,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
     }

     __host__ __device__ static constexpr auto
-    MakeBKN0N1BlockDescriptor(const BKNBlockDesc& b_k_n_block_desc)
+    MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */)
     {
-        const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
+        const auto b_k_n0_n1_block_desc = transform_tensor_descriptor(
             BKNBlockDesc{},
             make_tuple(make_pass_through_transform(Number<K>{}),
                        make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))),

@@ -194,7 +194,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
               typename ABlockBuffer,
               typename BBlockBuffer,
               typename CThreadBuffer>
-    __device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc,
+    __device__ void Run(const CM0M1N0N1ThreadDesc& /* c_m0_m1_n0_n1_thread_desc */,
                         const ABlockBuffer& a_block_buf,
                         const BBlockBuffer& b_block_buf,
                         CThreadBuffer& c_thread_buf) const

@@ -357,15 +357,14 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
     private:
     // A[K, M0, M1]
-    static constexpr auto a_k_m0_m1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+    static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed(
         make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}));

     // B[K, N0, N1]
-    static constexpr auto b_k_n0_n1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+    static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed(
         make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}));

     using AThreadCopy =
-        ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
+        ThreadwiseTensorSliceTransfer_v4<FloatA,
                                          FloatA,
                                          decltype(a_k_m0_m1_block_desc_),
                                          decltype(a_k_m0_m1_thread_desc_),

@@ -375,8 +374,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
                                          AThreadCopyScalarPerVector_M11,
                                          1>;

     using BThreadCopy =
-        ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
+        ThreadwiseTensorSliceTransfer_v4<FloatB,
                                          FloatB,
                                          decltype(b_k_n0_n1_block_desc_),
                                          decltype(b_k_n0_n1_thread_desc_),
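The unmerge pattern in MakeAKM0M1BlockDescriptor above is the general shape of a transform_tensor_descriptor call after the rename: old descriptor, a tuple of transforms, then the lower and upper dimension-id tuples. Below is a stand-alone sketch splitting a [K, M] descriptor into [K, M0, M1]; the concrete sizes and the Sequence<> dimension ids are assumptions chosen for illustration, not taken from this diff.

using namespace ck;

constexpr index_t K = 8, M0 = 2, M1 = 16; // illustrative sizes, M = M0 * M1

constexpr auto a_k_m_desc =
    make_naive_tensor_descriptor_packed(make_tuple(Number<K>{}, Number<M0 * M1>{}));

// [K, M] -> [K, M0, M1]: pass K through, unmerge M into (M0, M1).
constexpr auto a_k_m0_m1_desc = transform_tensor_descriptor(
    a_k_m_desc,
    make_tuple(make_pass_through_transform(Number<K>{}),
               make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
    make_tuple(Sequence<0>{}, Sequence<1>{}),     // old dims consumed by each transform
    make_tuple(Sequence<0>{}, Sequence<1, 2>{})); // new dims produced by each transform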
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp

@@ -3,7 +3,7 @@
 #include "common_header.hpp"
 #include "tensor_adaptor.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
+#include "threadwise_tensor_slice_transfer_v2.hpp"
 #include "threadwise_contraction_dlops.hpp"

 namespace ck {

@@ -38,7 +38,7 @@ template <index_t BlockSize,
           //     BM10BN10ThreadClusterBN101, ...>
           index_t AThreadCopyScalarPerVector_BM11,
           index_t BThreadCopyScalarPerVector_BN11,
-          typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
+          typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
                                  BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
                              bool>::type = false>
 struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2

@@ -75,7 +75,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
     __host__ __device__ static constexpr auto
     MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
     {
-        const auto a_block_bk0_bm0_bm1_bk1 = transform_dynamic_tensor_descriptor(
+        const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
             a_block_desc_bk0_bm_bk1,
             make_tuple(make_pass_through_transform(Number<BK0>{}),
                        make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),

@@ -89,7 +89,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
     __host__ __device__ static constexpr auto
     MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
     {
-        const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_dynamic_tensor_descriptor(
+        const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
             b_block_desc_bk0_bn_bk1,
             make_tuple(make_pass_through_transform(Number<BK0>{}),
                        make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),

@@ -372,15 +372,15 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
     private:
     // A[BK0, BM0, BM1, BK1]
     static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
-        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+        make_naive_tensor_descriptor_packed(make_tuple(
             Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{}));

     // B[BK0, BN0, BN1, BK1]
     static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
-        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+        make_naive_tensor_descriptor_packed(make_tuple(
             Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{}));

-    using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
         FloatA,
         FloatA,
         decltype(a_block_desc_bk0_bm0_bm1_bk1_),

@@ -390,7 +390,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
                              Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths
                              Sequence<0, 1, 2, 3>>;                 // SrcVectorTensorContiguousDimOrder

-    using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
         FloatB,
         FloatB,
         decltype(b_block_desc_bk0_bn0_bn1_bk1_),
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp

@@ -31,17 +31,16 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
     // HACK: fix this @Jing Zhang
     static constexpr index_t KPerThreadSubC = 4;

-    static constexpr auto a_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+    static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
         make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));

-    static constexpr auto b_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+    static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
         Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));

-    static constexpr auto c_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+    static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
         Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));

     using AThreadCopy =
-        ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
+        ThreadwiseTensorSliceTransfer_v4<FloatA,
                                          FloatA,
                                          BlockMatrixA,
                                          decltype(a_thread_mtx_),

@@ -69,7 +68,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
                   "wrong! K dimension not consistent\n");

     constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed
-    constexpr index_t N = BlockMatrixB{}.GetLength(I1);
     constexpr index_t H = BlockMatrixB{}.GetLength(I2);
     constexpr index_t W = BlockMatrixB{}.GetLength(I3);

@@ -121,9 +119,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
                   "wrong! inconsistent type");

     constexpr auto I0 = Number<0>{};
-    constexpr auto I1 = Number<1>{};
-    constexpr auto I2 = Number<2>{};
-    constexpr auto I3 = Number<3>{};

     constexpr auto a_block_mtx = BlockMatrixA{};

@@ -138,7 +133,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
     static_assert(WPerThread % WoPerThreadSubC == 0, "");

     // thread A buffer for GEMM
-    StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()>
+    StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
         a_thread_buf;

     constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp

@@ -2,7 +2,7 @@
 #define CK_BLOCKWISE_GEMM_XDLOPS_HPP

 #include "common_header.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
 #include "xdlops_gemm.hpp"

 namespace ck {

@@ -52,7 +52,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
         const index_t waveId   = thread_id / WaveSize;
         const index_t laneId   = thread_id % WaveSize;
         const index_t waveId_m = waveId / NWaves;
-        const index_t waveId_n = waveId % NWaves;

         if constexpr(xdlops_gemm.IsKReduction)
         {

@@ -73,7 +72,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
         const index_t thread_id = get_thread_local_1d_id();
         const index_t waveId    = thread_id / WaveSize;
         const index_t laneId    = thread_id % WaveSize;
-        const index_t waveId_m  = waveId / NWaves;
         const index_t waveId_n  = waveId % NWaves;

         if constexpr(xdlops_gemm.IsKReduction)

@@ -193,17 +191,17 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
     private:
     // A[K, M]
-    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
+    static constexpr auto a_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));

     // B[K, N]
-    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
+    static constexpr auto b_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));

-    static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
+    static constexpr auto c_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));

-    using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                          FloatAB,
                                                          ABlockDesc,
                                                          decltype(a_thread_desc_),

@@ -213,7 +211,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
                                                          K1,
                                                          1>;

-    using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                          FloatAB,
                                                          BBlockDesc,
                                                          decltype(b_thread_desc_),

@@ -272,7 +270,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
         const index_t waveId   = thread_id / WaveSize;
         const index_t laneId   = thread_id % WaveSize;
         const index_t waveId_m = waveId / NWaves;
-        const index_t waveId_n = waveId % NWaves;

         if constexpr(xdlops_gemm.IsKReduction)
         {

@@ -293,7 +290,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
         const index_t thread_id = get_thread_local_1d_id();
         const index_t waveId    = thread_id / WaveSize;
         const index_t laneId    = thread_id % WaveSize;
-        const index_t waveId_m  = waveId / NWaves;
         const index_t waveId_n  = waveId % NWaves;

         if constexpr(xdlops_gemm.IsKReduction)

@@ -490,17 +486,17 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
     private:
     // A[K, M]
-    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
+    static constexpr auto a_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));

     // B[K, N]
-    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
+    static constexpr auto b_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));

-    static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
+    static constexpr auto c_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));

-    using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                          FloatAB,
                                                          ABlockDesc,
                                                          decltype(a_thread_desc_),

@@ -510,7 +506,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
                                                          1, // K1,
                                                          1>;

-    using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                          FloatAB,
                                                          BBlockDesc,
                                                          decltype(b_thread_desc_),
composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp → composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp

-#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
-#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
+#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
+#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP

 #include "common_header.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
 #include "cluster_descriptor.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"

 namespace ck {

 // this version does following things to avoid scratch memory issue
 // 1. Use StaticallyIndexedArray instead of C array for thread buffer
-// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
-// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
+// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
+// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
 template <index_t BlockSize,
           InMemoryDataOperationEnum_t DstInMemOp,
           typename BlockSliceLengths,

@@ -33,13 +33,13 @@ template <index_t BlockSize,
           index_t DstScalarStrideInVector,
           bool ThreadTransferSrcResetCoordinateAfterRun,
           bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseDynamicTensorSliceTransfer_v4
+struct BlockwiseTensorSliceTransfer_v4
 {
     static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

     using Index = MultiIndex<nDim>;

-    __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4(const SrcDesc& src_desc,
+    __device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc,
                                                          const Index& src_block_slice_origin,
                                                          const DstDesc& dst_desc,
                                                          const Index& dst_block_slice_origin)

@@ -77,15 +77,14 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
         }
     }

-    template <typename SrcBuffer, typename SrcIteratorHacks>
-    __device__ void RunRead(const SrcDesc& src_desc,
-                            const SrcBuffer& src_buf,
-                            const SrcIteratorHacks& src_iterator_hacks)
+    template <typename SrcBuffer, typename SrcStepHacks>
+    __device__ void
+    RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
+            threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
         }
     }

@@ -118,18 +117,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
         }
     }

-    // SrcMoveSliceWindowIteratorHack to control index calculation move slice window
-    template <typename SrcMoveSliceWindowIteratorHack>
+    // SrcMoveSliceWindowStepHack to control index calculation move slice window
+    template <typename SrcMoveSliceWindowStepHack>
     __device__ void
     MoveSrcSliceWindow(const SrcDesc& src_desc,
                        const Index& step,
-                       const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
+                       const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveSrcSliceWindow(
-                src_desc, step, src_move_slice_window_iterator_hack);
+                src_desc, step, src_move_slice_window_step_hack);
         }
     }

@@ -147,7 +146,7 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
         make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

     using ThreadwiseTransfer =
-        ThreadwiseDynamicTensorSliceTransfer_v3<ThreadSliceLengths,
+        ThreadwiseTensorSliceTransfer_v3<ThreadSliceLengths,
                                          DstInMemOp,
                                          SrcData,
                                          DstData,
composable_kernel/include/tensor_operation/blockwise_
dynamic_
tensor_slice_transfer_v2.hpp
→
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp
View file @
ccc4a1d3
#ifndef CK_BLOCKWISE_
DYNAMIC_
TENSOR_SLICE_TRANSFER_V2_HPP
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_BLOCKWISE_
DYNAMIC_
TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#include "common_header.hpp"
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_
dynamic_
tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
namespace
ck
{
namespace
ck
{
// this version does following things to avoid scratch memory issue
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. Threadwise
Dynamic
TensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. Threadwise
Dynamic
TensorSliceTransfer_v3::Run() does not construct new tensor coordinate
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
 template <index_t BlockSize,
           InMemoryDataOperationEnum_t DstInMemOp,
           typename BlockSliceLengths,
...
@@ -31,14 +31,13 @@ template <index_t BlockSize,
           typename DstVectorTensorContiguousDimOrder,
           bool ThreadTransferSrcResetCoordinateAfterRun,
           bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseDynamicTensorSliceTransfer_v4r1
+struct BlockwiseTensorSliceTransfer_v4r1
 {
     static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

     using Index = MultiIndex<nDim>;

-    __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4r1(const SrcDesc& src_desc,
+    __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(const SrcDesc& src_desc,
                                                            const Index& src_block_slice_origin,
                                                            const DstDesc& dst_desc,
                                                            const Index& dst_block_slice_origin)
...
@@ -76,15 +75,14 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
         }
     }

-    template <typename SrcBuffer, typename SrcIteratorHacks>
+    template <typename SrcBuffer, typename SrcStepHacks>
     __device__ void RunRead(const SrcDesc& src_desc,
                             const SrcBuffer& src_buf,
-                            const SrcIteratorHacks& src_iterator_hacks)
+                            const SrcStepHacks& src_step_hacks)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
+            threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
         }
     }
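Aside (not part of the commit): the guard inside RunRead lets a blockwise copy use fewer threads than the launch provides. A hedged host-side analogue, with the function thread_participates and the sizes below being illustrative only: a thread joins the transfer when the whole block is needed or its id falls inside the thread cluster.

#include <cstdio>

// hypothetical analogue of the participation guard used by the blockwise copy
bool thread_participates(int block_size, int cluster_size, int thread_id)
{
    return block_size == cluster_size || thread_id < cluster_size;
}

int main()
{
    const int block_size   = 256;
    const int cluster_size = 128; // e.g. ThreadClusterLengths only needs half the block

    int active = 0;
    for(int tid = 0; tid < block_size; ++tid)
        active += thread_participates(block_size, cluster_size, tid) ? 1 : 0;

    std::printf("%d of %d threads execute the copy\n", active, block_size);
}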
...
@@ -107,18 +105,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
         }
     }

-    // SrcMoveSliceWindowIteratorHack to control index calculation move slice window
-    template <typename SrcMoveSliceWindowIteratorHack>
+    // SrcMoveSliceWindowStepHack to control index calculation move slice window
+    template <typename SrcMoveSliceWindowStepHack>
     __device__ void
     MoveSrcSliceWindow(const SrcDesc& src_desc,
                        const Index& step,
-                       const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
+                       const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveSrcSliceWindow(
-                src_desc, step, src_move_slice_window_iterator_hack);
+                src_desc, step, src_move_slice_window_step_hack);
         }
     }
...
@@ -136,7 +134,7 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
             make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

     using ThreadwiseTransfer =
-        ThreadwiseDynamicTensorSliceTransfer_v3r1<ThreadSliceLengths,
+        ThreadwiseTensorSliceTransfer_v3r1<ThreadSliceLengths,
                                            DstInMemOp,
                                            SrcData,
                                            DstData,
...
composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_dlops_v1r2.hpp → composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp

-#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
-#define CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
+#ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
+#define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP

 #include "common_header.hpp"
-#include "dynamic_multi_index_transform_helper.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "multi_index_transform_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_dlops_v2r3.hpp"
-#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer.hpp"
-#include "threadwise_dynamic_tensor_slice_set.hpp"
+#include "blockwise_tensor_slice_transfer_v2.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_set.hpp"

 namespace ck {
...
@@ -25,7 +25,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_dynamic_contraction_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
+    kernel_contraction_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
                                   const FloatAB* __restrict__ p_b_grid,
                                   FloatC* __restrict__ p_c_grid,
...
@@ -84,12 +84,12 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector,
-          typename AGridIteratorHacks,
-          typename BGridIteratorHacks,
-          typename CGridIteratorHacks,
-          typename AGridMoveSliceWindowIteratorHacks,
-          typename BGridMoveSliceWindowIteratorHacks>
-struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
+          typename AGridStepHacks,
+          typename BGridStepHacks,
+          typename CGridStepHacks,
+          typename AGridMoveSliceWindowStepHacks,
+          typename BGridMoveSliceWindowStepHacks>
+struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
@@ -110,15 +110,13 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 =
-            make_dynamic_naive_tensor_descriptor_aligned_v2(
-                make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
-                max_lds_align);
+        constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
+            max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 =
-            make_dynamic_naive_tensor_descriptor_aligned_v2(
-                make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
-                max_lds_align);
+        constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
+            max_lds_align);
...
@@ -201,7 +199,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         const auto GM11 = Number<GM1PerBlockGM11>{};
         const auto GM10 = GM1 / GM11;

-        const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_dynamic_tensor_descriptor(
+        const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_tensor_descriptor(
             a_grid_desc_gk0_gm0_gm1_gk1,
             make_tuple(make_pass_through_transform(GK0),
                        make_pass_through_transform(GM0),
...
@@ -222,7 +220,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         const auto GN11 = Number<GN1PerBlockGN11>{};
         const auto GN10 = GN1 / GN11;

-        const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_dynamic_tensor_descriptor(
+        const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_tensor_descriptor(
             b_grid_desc_gk0_gn0_gn1_gk1,
             make_tuple(make_pass_through_transform(GK0),
                        make_pass_through_transform(GN0),
...
@@ -259,7 +257,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         constexpr auto BM0 = BM / BM1;
         constexpr auto BN0 = BN / BN1;

-        const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_tensor_descriptor(
             c_grid_desc_gm0_gm1_gn0_gn1,
             make_tuple(make_pass_through_transform(GM0),
                        make_unmerge_transform(make_tuple(GM10, GM11)),
...
@@ -268,7 +266,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));

-        const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto c_gm10_bm_gn10_bn_grid_desc = transform_tensor_descriptor(
             c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
             make_tuple(make_pass_through_transform(GM10),
                        make_merge_transform(make_tuple(GM0, GM11)),
...
@@ -277,7 +275,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
             make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

-        const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_dynamic_tensor_descriptor(
+        const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_tensor_descriptor(
             c_gm10_bm_gn10_bn_grid_desc,
             make_tuple(make_pass_through_transform(GM10),
                        make_unmerge_transform(make_tuple(BM0, BM1)),
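Aside (not part of the commit): the descriptor transforms above only reshuffle indices; no data is moved. A hedged arithmetic sketch of what an unmerge and a merge do to a coordinate, using illustrative helper names (unmerge, merge) rather than the CK implementation:

#include <cassert>
#include <cstdio>

// unmerge: split a single index over a dimension of length L10 * L11 into (i10, i11)
void unmerge(int i, int l11, int& i10, int& i11)
{
    i10 = i / l11;
    i11 = i % l11;
}

// merge: combine (i0, i11) over lengths (L0, L11) back into one index
int merge(int i0, int i11, int l11) { return i0 * l11 + i11; }

int main()
{
    const int GM11 = 32; // e.g. GM1PerBlockGM11
    const int gm1  = 75; // some index along the GM1 dimension

    int gm10 = 0, gm11 = 0;
    unmerge(gm1, GM11, gm10, gm11);         // GM1 -> (GM10, GM11)
    assert(merge(gm10, gm11, GM11) == gm1); // merging undoes the split

    std::printf("gm1=%d -> gm10=%d, gm11=%d\n", gm1, gm10, gm11);
}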
...
@@ -356,26 +354,24 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 =
-            make_dynamic_naive_tensor_descriptor_aligned_v2(
-                make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
-                max_lds_align);
+        constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
+            max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 =
-            make_dynamic_naive_tensor_descriptor_aligned_v2(
-                make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
-                max_lds_align);
+        constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
+            max_lds_align);

         // A matrix in LDS memory for blockwise GEMM
         // be careful of LDS alignment
-        constexpr auto a_block_desc_gk0_bm_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);

         // B matrix in LDS memory for blockwise GEMM
         // be careful of LDS alignment
-        constexpr auto b_block_desc_gk0_bn_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);

         static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() ==
...
@@ -385,7 +381,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
                       "wrong!");

         // A matrix blockwise copy
-        auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
+        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
             BlockSize,
             InMemoryDataOperationEnum_t::Set,
             Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
...
@@ -409,7 +405,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
             make_multi_index(0, 0, 0, 0, 0));

         // B matrix blockwise copy
-        auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
+        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
             BlockSize,
             InMemoryDataOperationEnum_t::Set,
             Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
...
@@ -457,8 +453,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 =
             decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();

-        constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 =
-            make_dynamic_naive_tensor_descriptor_packed_v2(
-                sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));
+        constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed(
+            sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));

         // LDS allocation for A and B: be careful of alignment
...
@@ -475,7 +470,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
             c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());

-        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
+        ThreadwiseTensorSliceSet_v1<FloatAcc,
                                     decltype(c_thread_desc_bm0_bm1_bn0_bn1),
                                     decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
             .Run(c_thread_desc_bm0_bm1_bn0_bn1,
...
@@ -501,9 +496,9 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         // LDS double buffer: preload data into LDS
         {
             a_blockwise_copy.RunRead(
-                a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
+                a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
             b_blockwise_copy.RunRead(
-                b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
+                b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});

             a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
             b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
...
@@ -520,18 +515,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
                 // even iteration
                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
                                                     a_block_slice_copy_step,
-                                                    AGridMoveSliceWindowIteratorHacks{});
+                                                    AGridMoveSliceWindowStepHacks{});
                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
                                                     b_block_slice_copy_step,
-                                                    BGridMoveSliceWindowIteratorHacks{});
+                                                    BGridMoveSliceWindowStepHacks{});

                 __syncthreads();

                 // LDS doubel buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
-                    a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
+                    a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
                 b_blockwise_copy.RunRead(
-                    b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
+                    b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
...
@@ -546,18 +541,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
                 // odd iteration
                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
                                                     a_block_slice_copy_step,
-                                                    AGridMoveSliceWindowIteratorHacks{});
+                                                    AGridMoveSliceWindowStepHacks{});
                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
                                                     b_block_slice_copy_step,
-                                                    BGridMoveSliceWindowIteratorHacks{});
+                                                    BGridMoveSliceWindowStepHacks{});

                 __syncthreads();

                 // LDS doubel buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
-                    a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
+                    a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
                 b_blockwise_copy.RunRead(
-                    b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
+                    b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(
...
@@ -576,18 +571,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
             {
                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
                                                     a_block_slice_copy_step,
-                                                    AGridMoveSliceWindowIteratorHacks{});
+                                                    AGridMoveSliceWindowStepHacks{});
                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
                                                     b_block_slice_copy_step,
-                                                    BGridMoveSliceWindowIteratorHacks{});
+                                                    BGridMoveSliceWindowStepHacks{});

                 __syncthreads();

                 // LDS double buffer: load last data from device mem
                 a_blockwise_copy.RunRead(
-                    a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
+                    a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
                 b_blockwise_copy.RunRead(
-                    b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
+                    b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});

                 // LDS double buffer: GEMM on 2nd-last data
                 blockwise_gemm.Run(
...
@@ -615,7 +610,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
         // output: register to global memory
         {
             constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
-                make_dynamic_naive_tensor_descriptor_packed_v2(
+                make_naive_tensor_descriptor_packed(
                     make_tuple(I1,
                                Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{},
                                Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{},
...
@@ -627,7 +622,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
                 blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
                     get_thread_local_1d_id());

-            ThreadwiseDynamicTensorSliceTransfer_v1r3<
+            ThreadwiseTensorSliceTransfer_v1r3<
                 FloatAcc,
                 FloatC,
                 decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1),
...
@@ -655,7 +650,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
                        c_thread_buf,
                        c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
                        c_grid_buf,
-                       CGridIteratorHacks{});
+                       CGridStepHacks{});
         }
     }
};
...
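Aside (not part of the commit): the main loop above keeps an even and an odd LDS buffer so that the next K-slice is fetched from global memory while the GEMM runs on the slice already resident in LDS. A stripped-down, hedged sketch of that double-buffered pipeline in plain host C++, with load_tile and gemm_on_tile as made-up stand-ins for the blockwise copy and blockwise GEMM:

#include <cstddef>
#include <cstdio>
#include <vector>

// stand-in for the blockwise copy: global memory -> one LDS tile
void load_tile(const std::vector<float>& src, std::size_t offset, std::vector<float>& lds_buf)
{
    for(std::size_t i = 0; i < lds_buf.size(); ++i)
        lds_buf[i] = src[offset + i];
}

// stand-in for the blockwise GEMM on one resident tile
void gemm_on_tile(const std::vector<float>& lds_buf, float& acc)
{
    for(float v : lds_buf)
        acc += v;
}

int main()
{
    const std::size_t tile = 4, num_tiles = 6;
    std::vector<float> a(tile * num_tiles, 1.0f);

    std::vector<float> buf_even(tile), buf_odd(tile);
    float acc = 0.0f;

    load_tile(a, 0, buf_even); // preload the first tile into the "even" buffer

    for(std::size_t t = 0; t + 1 < num_tiles; ++t)
    {
        auto& compute_buf  = (t % 2 == 0) ? buf_even : buf_odd;
        auto& prefetch_buf = (t % 2 == 0) ? buf_odd : buf_even;

        load_tile(a, (t + 1) * tile, prefetch_buf); // fetch the next tile
        gemm_on_tile(compute_buf, acc);             // compute on the current tile
    }
    gemm_on_tile((num_tiles % 2 == 1) ? buf_even : buf_odd, acc); // tail tile

    std::printf("acc = %f\n", acc);
}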
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r2.hpp → composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp

-#ifndef CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP
-#define CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP
+#ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
+#define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP

 #include "common_header.hpp"
-#include "dynamic_multi_index_transform_helper.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "multi_index_transform_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_dlops_v2r2.hpp"
-#include "blockwise_dynamic_tensor_slice_transfer.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer.hpp"
-#include "threadwise_dynamic_tensor_slice_set.hpp"
+#include "blockwise_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_set.hpp"

 namespace ck {
...
@@ -26,7 +26,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_dynamic_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
+    kernel_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
                            const FloatAB* __restrict__ p_b_grid,
                            FloatC* __restrict__ p_c_grid,
...
@@ -68,8 +68,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_dynamic_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
+    kernel_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
                            const FloatAB* __restrict__ p_b_grid,
                            FloatC* __restrict__ p_c_grid,
                            const void CONSTANT* p_a_k_m0_m1_grid_desc,
...
@@ -80,16 +79,16 @@ __global__ void
     // first cast void CONSTANT void* to void*
     // second cast void* to Desc*
     // the copy constructor of tensor descriptor doesn't take address_space(4)
     const auto a_k_m0_m1_grid_desc =
-        *reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc);
+        *reinterpret_cast<const AKM0M1GridDesc*>(
+            cast_pointer_to_generic_address_space(p_a_k_m0_m1_grid_desc));
     const auto b_k_n0_n1_grid_desc =
-        *reinterpret_cast<const BKN0N1GridDesc*>((const void*)p_b_k_n0_n1_grid_desc);
+        *reinterpret_cast<const BKN0N1GridDesc*>(
+            cast_pointer_to_generic_address_space(p_b_k_n0_n1_grid_desc));
     const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
         *reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
-            (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
+            cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
     const auto c_blockid_to_m0_n0_block_cluster_adaptor =
         *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
-            (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
+            cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));

     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
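Aside (not part of the commit): the hunk above changes how the kernel recovers its descriptor arguments; instead of a bare (const void*) cast, the CONSTANT-address-space pointer is first routed through cast_pointer_to_generic_address_space and then reinterpreted as the concrete descriptor type. A hedged host-side illustration of the underlying pattern of passing a descriptor by type-erased pointer, where GridDesc and kernel_like are made-up names and the address-space cast itself exists only in the GPU code:

#include <cstdio>

// hypothetical descriptor type, standing in for e.g. AKM0M1GridDesc
struct GridDesc
{
    int lengths[3];
};

// stand-in for a kernel that receives a type-erased descriptor pointer;
// on the device the pointer would additionally carry a CONSTANT address space
void kernel_like(const void* p_desc)
{
    // cast the void* back to the concrete descriptor type and copy it by value
    const auto desc = *reinterpret_cast<const GridDesc*>(p_desc);
    std::printf("%d %d %d\n", desc.lengths[0], desc.lengths[1], desc.lengths[2]);
}

int main()
{
    GridDesc desc{{8, 16, 4}}; // built outside the kernel, passed in by pointer
    kernel_like(&desc);
}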
...
@@ -146,12 +145,12 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector,
-          typename AGridIteratorHacks,
-          typename BGridIteratorHacks,
-          typename CGridIteratorHacks,
-          typename AGridMoveSliceWindowIteratorHacks,
-          typename BGridMoveSliceWindowIteratorHacks>
-struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
+          typename AGridStepHacks,
+          typename BGridStepHacks,
+          typename CGridStepHacks,
+          typename AGridMoveSliceWindowStepHacks,
+          typename BGridMoveSliceWindowStepHacks>
+struct GridwiseGemmDlops_km_kn_mn_v1r2
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
@@ -167,12 +166,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);

         // LDS allocation for A and B: be careful of alignment
...
@@ -230,7 +229,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         const auto M1 = Number<MPerBlockM1>{};
         const auto M0 = M / M1;

-        const auto a_k_m0_m1_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto a_k_m0_m1_grid_desc = transform_tensor_descriptor(
             a_k_m_grid_desc,
             make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
...
@@ -248,7 +247,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         const auto N1 = Number<NPerBlockN1>{};
         const auto N0 = N / N1;

-        const auto b_k_n0_n1_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto b_k_n0_n1_grid_desc = transform_tensor_descriptor(
             b_k_n_grid_desc,
             make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
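Aside (not part of the commit): the *_block_desc constructions above pad the LDS tile so it starts on max_lds_align. The padding is just an integer round-up; a small hedged sketch, where round_up and the example sizes are illustrative rather than CK code:

#include <cstdio>

// round `len` up to the next multiple of `align` (align > 0)
constexpr int round_up(int len, int align) { return ((len + align - 1) / align) * align; }

int main()
{
    constexpr int MPerBlockM1   = 100; // hypothetical row length in elements
    constexpr int max_lds_align = 8;   // e.g. dictated by the widest LDS access

    constexpr int padded = round_up(MPerBlockM1, max_lds_align);
    static_assert(padded == 104, "100 rounded up to a multiple of 8 is 104");

    std::printf("row stride padded from %d to %d elements\n", MPerBlockM1, padded);
}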
...
@@ -277,7 +276,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         constexpr auto M10 = M1 / M11;
         constexpr auto N10 = N1 / N11;

-        const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
             c_m_n_grid_desc,
             make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
                        make_unmerge_transform(make_tuple(N0, N10, N11))),
...
@@ -352,27 +351,27 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);

         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_k_m0_m1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_k_n0_n1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align);

         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
+            BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
                                             Sequence<KPerBlock, 1, MPerBlockM1>,
                                             ABlockTransferThreadSliceLengths_K_M0_M1,
...
@@ -398,7 +397,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
+            BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
                                             Sequence<KPerBlock, 1, NPerBlockN1>,
                                             BBlockTransferThreadSliceLengths_K_N0_N1,
...
@@ -447,8 +446,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
             decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();

-        constexpr auto c_m10_m11_n10_n11_thread_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(
-                sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
+        constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
+            sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));

         // LDS allocation for A and B: be careful of alignment
...
@@ -465,7 +463,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
             c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());

-        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
+        ThreadwiseTensorSliceSet_v1<FloatAcc,
                                     decltype(c_m10_m11_n10_n11_thread_desc),
                                     decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
             .Run(c_m10_m11_n10_n11_thread_desc,
...
@@ -477,15 +475,15 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);

         // hack to control index calculation when iterating over A and B matrix for threadwise copy
-        constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{};
-        constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{};
+        constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{};
+        constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{};

         // hack to control index calculation when move slice window for A and B matrix for
         // threadwise copy
-        constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack =
-            AGridMoveSliceWindowIteratorHacks{};
-        constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack =
-            BGridMoveSliceWindowIteratorHacks{};
+        constexpr auto a_k_m0_m1_global_move_slice_window_step_hack =
+            AGridMoveSliceWindowStepHacks{};
+        constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
+            BGridMoveSliceWindowStepHacks{};

         auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
             p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
...
@@ -502,9 +500,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         // LDS double buffer: preload data into LDS
         {
             a_blockwise_copy.RunRead(
-                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
+                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
             b_blockwise_copy.RunRead(
-                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
+                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);

             a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
             b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
...
@@ -519,22 +517,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         do
         {
             // even iteration
             a_blockwise_copy.MoveSrcSliceWindow(
                 a_k_m0_m1_grid_desc,
                 a_block_slice_copy_step,
-                a_k_m0_m1_global_move_slice_window_iterator_hack);
+                a_k_m0_m1_global_move_slice_window_step_hack);
             b_blockwise_copy.MoveSrcSliceWindow(
                 b_k_n0_n1_grid_desc,
                 b_block_slice_copy_step,
-                b_k_n0_n1_global_move_slice_window_iterator_hack);
+                b_k_n0_n1_global_move_slice_window_step_hack);

             __syncthreads();

             // LDS doubel buffer: load next data from device mem
             a_blockwise_copy.RunRead(
-                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
+                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
             b_blockwise_copy.RunRead(
-                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
+                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
...
@@ -547,22 +543,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
             b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);

             // odd iteration
             a_blockwise_copy.MoveSrcSliceWindow(
                 a_k_m0_m1_grid_desc,
                 a_block_slice_copy_step,
-                a_k_m0_m1_global_move_slice_window_iterator_hack);
+                a_k_m0_m1_global_move_slice_window_step_hack);
             b_blockwise_copy.MoveSrcSliceWindow(
                 b_k_n0_n1_grid_desc,
                 b_block_slice_copy_step,
-                b_k_n0_n1_global_move_slice_window_iterator_hack);
+                b_k_n0_n1_global_move_slice_window_step_hack);

             __syncthreads();

             // LDS doubel buffer: load next data from device mem
             a_blockwise_copy.RunRead(
-                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
+                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
             b_blockwise_copy.RunRead(
-                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
+                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(
...
@@ -581,18 +575,18 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         {
             a_blockwise_copy.MoveSrcSliceWindow(
                 a_k_m0_m1_grid_desc,
                 a_block_slice_copy_step,
-                a_k_m0_m1_global_move_slice_window_iterator_hack);
+                a_k_m0_m1_global_move_slice_window_step_hack);
             b_blockwise_copy.MoveSrcSliceWindow(
                 b_k_n0_n1_grid_desc,
                 b_block_slice_copy_step,
-                b_k_n0_n1_global_move_slice_window_iterator_hack);
+                b_k_n0_n1_global_move_slice_window_step_hack);

             __syncthreads();

             // LDS double buffer: load last data from device mem
             a_blockwise_copy.RunRead(
-                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
+                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
             b_blockwise_copy.RunRead(
-                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
+                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);

             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(
...
@@ -619,19 +613,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
         // output: register to global memory
         {
-            constexpr index_t M11 =
-                M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
-            constexpr index_t N11 =
-                N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;
-
-            constexpr index_t M10 = MPerBlockM1 / M11;
-            constexpr index_t N10 = NPerBlockN1 / N11;
-
-            constexpr index_t M111 = M1PerThreadM111;
-            constexpr index_t N111 = N1PerThreadN111;
-
             constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
-                make_dynamic_naive_tensor_descriptor_packed_v2(
+                make_naive_tensor_descriptor_packed(
                     make_tuple(I1,
                                Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
                                Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
...
@@ -642,7 +625,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
             const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
                 blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());

-            ThreadwiseDynamicTensorSliceTransfer_v1r3<
+            ThreadwiseTensorSliceTransfer_v1r3<
                 FloatAcc,
                 FloatC,
                 decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
...
@@ -670,7 +653,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
                        c_thread_buf,
                        c_m0_m10_m11_n0_n10_n11_grid_desc,
                        c_grid_buf,
-                       CGridIteratorHacks{});
+                       CGridStepHacks{});
         }
     }
};
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r3.hpp → composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp

-#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP
-#define CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP
+#ifndef CK_GRIDWISE_GEMM_V1R3_HPP
+#define CK_GRIDWISE_GEMM_V1R3_HPP

 #include "common_header.hpp"
-#include "dynamic_multi_index_transform_helper.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "multi_index_transform_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_dlops_v2r3.hpp"
-#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
-#include "threadwise_dynamic_tensor_slice_set.hpp"
+#include "blockwise_tensor_slice_transfer_v2.hpp"
+#include "threadwise_tensor_slice_transfer_v2.hpp"
+#include "threadwise_tensor_slice_set.hpp"

 namespace ck {
...
@@ -26,7 +26,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_dynamic_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid,
+    kernel_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid,
                            const FloatAB* __restrict__ p_b_grid,
                            FloatC* __restrict__ p_c_grid,
...
@@ -68,8 +68,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_dynamic_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid,
+    kernel_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid,
                            const FloatAB* __restrict__ p_b_grid,
                            FloatC* __restrict__ p_c_grid,
                            const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc,
...
@@ -80,16 +79,16 @@ __global__ void
     // first cast void CONSTANT void* to void*
     // second cast void* to Desc*
    // the copy constructor of tensor descriptor doesn't take address_space(4)
     const auto a_k0_m0_m1_k1_grid_desc =
-        *reinterpret_cast<const AK0M0M1K1GridDesc*>((const void*)p_a_k0_m0_m1_k1_grid_desc);
+        *reinterpret_cast<const AK0M0M1K1GridDesc*>(
+            cast_pointer_to_generic_address_space(p_a_k0_m0_m1_k1_grid_desc));
     const auto b_k0_n0_n1_k1_grid_desc =
-        *reinterpret_cast<const BK0N0N1K1GridDesc*>((const void*)p_b_k0_n0_n1_k1_grid_desc);
+        *reinterpret_cast<const BK0N0N1K1GridDesc*>(
+            cast_pointer_to_generic_address_space(p_b_k0_n0_n1_k1_grid_desc));
     const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
         *reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
-            (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
+            cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
     const auto c_blockid_to_m0_n0_block_cluster_adaptor =
         *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
-            (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
+            cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));

     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...
@@ -142,12 +141,12 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector,
-          typename AGridIteratorHacks,
-          typename BGridIteratorHacks,
-          typename CGridIteratorHacks,
-          typename AGridMoveSliceWindowIteratorHacks,
-          typename BGridMoveSliceWindowIteratorHacks>
-struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
+          typename AGridStepHacks,
+          typename BGridStepHacks,
+          typename CGridStepHacks,
+          typename AGridMoveSliceWindowStepHacks,
+          typename BGridMoveSliceWindowStepHacks>
+struct GridwiseGemmDlops_km_kn_mn_v1r3
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
@@ -164,12 +163,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
         // TODO: check alignment
         // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);

         // TODO: check alignment
         // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);

         // TODO: check alignment
...
@@ -191,12 +190,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
         const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
         const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
         const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
+        const auto K1 = a_k0_m_k1_grid_desc.GetLength(I2);

         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
                 K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
-                K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
                 K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
                (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
     }
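Aside (not part of the commit): the validity check above rejects problem sizes that the chosen tile sizes cannot cover exactly. A hedged standalone version of the same divisibility test, with sizes_are_tileable and the example numbers being illustrative only:

#include <cstdio>

// the grid-wise GEMM only supports sizes that are exact multiples of the tile
bool sizes_are_tileable(int M, int N, int K0, int MPerBlock, int NPerBlock, int KPerBlock)
{
    return M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0;
}

int main()
{
    std::printf("%d\n", sizes_are_tileable(256, 512, 32, 128, 128, 8)); // 1: tiles fit exactly
    std::printf("%d\n", sizes_are_tileable(250, 512, 32, 128, 128, 8)); // 0: M is not a multiple
}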
...
@@ -231,8 +230,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
         const auto M1 = Number<MPerBlockM1>{};
         const auto M0 = M / M1;

-        const auto a_k0_m0_m1_k1_grid_desc = transform_dynamic_tensor_descriptor(
-            a_k0_m_k1_grid_desc,
+        const auto a_k0_m0_m1_k1_grid_desc =
+            transform_tensor_descriptor(a_k0_m_k1_grid_desc,
             make_tuple(make_pass_through_transform(K0),
                        make_unmerge_transform(make_tuple(M0, M1)),
                        make_pass_through_transform(K1)),
...
@@ -251,8 +250,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
         const auto N1 = Number<NPerBlockN1>{};
         const auto N0 = N / N1;

-        const auto b_k0_n0_n1_k1_grid_desc = transform_dynamic_tensor_descriptor(
-            b_k0_n_k1_grid_desc,
+        const auto b_k0_n0_n1_k1_grid_desc =
+            transform_tensor_descriptor(b_k0_n_k1_grid_desc,
             make_tuple(make_pass_through_transform(K0),
                        make_unmerge_transform(make_tuple(N0, N1)),
                        make_pass_through_transform(K1)),
...
@@ -284,7 +283,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
         constexpr auto M10 = M1 / M11;
         constexpr auto N10 = N1 / N11;

-        const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
             c_m_n_grid_desc,
             make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
                        make_unmerge_transform(make_tuple(N0, N10, N11))),
...
@@ -355,23 +354,23 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
        // TODO: check alignment
        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
-        constexpr auto a_k0_m0_m1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
-        constexpr auto b_k0_n0_n1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align);

        // TODO: check alignment
        // A matrix in LDS memory, for blockwise GEMM
-        constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, for blockwise GEMM
-        constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);

        static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
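Note: the aligned LDS descriptors pad strides up to max_lds_align so that every row of a tile starts on an aligned LDS address. A hedged sketch of the padding arithmetic this implies, with a hypothetical round_up helper and made-up tile sizes:

#include <cstdint>

// Round x up to the next multiple of align.
constexpr int64_t round_up(int64_t x, int64_t align) { return (x + align - 1) / align * align; }

int main()
{
    // Hypothetical tile: lengths [KPerBlock = 16, MPerBlock = 100], alignment = 8 elements.
    constexpr int64_t KPerBlock = 16, MPerBlock = 100, align = 8;

    constexpr int64_t row_stride = round_up(MPerBlock, align); // 104, not 100
    constexpr int64_t lds_elems  = KPerBlock * row_stride;     // element-space size incl. padding

    static_assert(row_stride == 104 && lds_elems == 1664, "4 padding elements per row");
    return 0;
}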
...
@@ -381,7 +380,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
                      "wrong!");

        // A matrix blockwise copy
-        auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
+        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
            BlockSize,
            InMemoryDataOperationEnum_t::Set,
            Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
...
@@ -405,7 +404,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
            make_multi_index(0, 0, 0, 0));

        // B matrix blockwise copy
-        auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
+        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
            BlockSize,
            InMemoryDataOperationEnum_t::Set,
            Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
...
@@ -453,8 +452,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
        constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
            decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();

-        constexpr auto c_m10_m11_n10_n11_thread_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(
-                sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
+        constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
+            sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));

        // LDS allocation for A and B: be careful of alignment
...
@@ -471,7 +469,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
        auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
            c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());

-        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
+        ThreadwiseTensorSliceSet_v1<FloatAcc,
                                    decltype(c_m10_m11_n10_n11_thread_desc),
                                    decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
            .Run(c_m10_m11_n10_n11_thread_desc,
...
@@ -496,8 +494,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
        // LDS double buffer: preload data into LDS
        {
-            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
-            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
+            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
+            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});

            a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
            b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
...
@@ -516,18 +514,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
            // even iteration
            a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
                                                a_block_slice_copy_step,
-                                                AGridMoveSliceWindowIteratorHacks{});
+                                                AGridMoveSliceWindowStepHacks{});
            b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
                                                b_block_slice_copy_step,
-                                                BGridMoveSliceWindowIteratorHacks{});
+                                                BGridMoveSliceWindowStepHacks{});

            __syncthreads();

            // LDS doubel buffer: load next data from device mem
-            a_blockwise_copy.RunRead(
-                a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
-            b_blockwise_copy.RunRead(
-                b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
+            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
+            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
...
@@ -542,18 +538,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
            // odd iteration
            a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
                                                a_block_slice_copy_step,
-                                                AGridMoveSliceWindowIteratorHacks{});
+                                                AGridMoveSliceWindowStepHacks{});
            b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
                                                b_block_slice_copy_step,
-                                                BGridMoveSliceWindowIteratorHacks{});
+                                                BGridMoveSliceWindowStepHacks{});

            __syncthreads();

            // LDS doubel buffer: load next data from device mem
-            a_blockwise_copy.RunRead(
-                a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
-            b_blockwise_copy.RunRead(
-                b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
+            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
+            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(
...
@@ -570,18 +564,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
        {
-            a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
-                                                a_block_slice_copy_step,
-                                                AGridMoveSliceWindowIteratorHacks{});
-            b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
-                                                b_block_slice_copy_step,
-                                                BGridMoveSliceWindowIteratorHacks{});
+            a_blockwise_copy.MoveSrcSliceWindow(
+                a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{});
+            b_blockwise_copy.MoveSrcSliceWindow(
+                b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});

            __syncthreads();

            // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
-            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
+            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
+            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});

            // LDS double buffer: GEMM on 2nd-last data
            blockwise_gemm.Run(
...
@@ -608,21 +600,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
        // output: register to global memory
        {
-            constexpr auto M11 = Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies_v2{}, I1) * M1PerThreadM111>{};
-            constexpr auto N11 = Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies_v2{}, I1) * N1PerThreadN111>{};
-
-            constexpr index_t M10 = MPerBlockM1 / M11;
-            constexpr index_t N10 = NPerBlockN1 / N11;
-
-            constexpr index_t M111 = M1PerThreadM111;
-            constexpr index_t N111 = N1PerThreadN111;
-
            constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
-                make_dynamic_naive_tensor_descriptor_packed_v2(
+                make_naive_tensor_descriptor_packed(
                    make_tuple(I1,
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
...
@@ -634,7 +613,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
                blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
                    get_thread_local_1d_id());

-            ThreadwiseDynamicTensorSliceTransfer_v1r3<
+            ThreadwiseTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
...
@@ -662,7 +641,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
                    c_thread_buf,
                    c_m0_m10_m11_n0_n10_n11_grid_desc,
                    c_grid_buf,
-                    CGridIteratorHacks{});
+                    CGridStepHacks{});
        }
    }
};
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v2.hpp → composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp
-#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP
-#define CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP
+#ifndef CK_GRIDWISE_GEMM_V2_HPP
+#define CK_GRIDWISE_GEMM_V2_HPP

 #include "common_header.hpp"
-#include "dynamic_multi_index_transform_helper.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
-#include "blockwise_dynamic_tensor_slice_transfer.hpp"
-#include "threadwise_dynamic_tensor_slice_transfer.hpp"
+#include "multi_index_transform_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
+#include "blockwise_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
 #include "blockwise_gemm_dlops_v3.hpp"

 namespace ck {
...
@@ -42,12 +42,12 @@ template <index_t BlockSize,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
-          typename AGlobalIteratorHacks,
-          typename BGlobalIteratorHacks,
-          typename CGlobalIteratorHacks,
-          typename AGlobalMoveSliceWindowIteratorHacks,
-          typename BGlobalMoveSliceWindowIteratorHacks>
-struct GridwiseDynamicGemmDlops_km_kn_mn_v3
+          typename AGlobalStepHacks,
+          typename BGlobalStepHacks,
+          typename CGlobalStepHacks,
+          typename AGlobalMoveSliceWindowStepHacks,
+          typename BGlobalMoveSliceWindowStepHacks>
+struct GridwiseGemmDlops_km_kn_mn_v3
{
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
...
@@ -58,7 +58,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
-        constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);

        // LDS allocation for A and B: be careful of alignment
...
@@ -102,7 +102,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
        // divide block work by [M, N]
#if 0
-        const auto k_block_work_num   = K / Number<KPerBlock>{};
        const auto ho_block_work_num  = Ho / Number<HoPerBlock>{};
        const auto wo_block_work_num  = Wo / Number<WoPerBlock>{};
        const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
...
@@ -114,7 +113,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
        const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#else
        // Hack: this force result into SGPR
-        const index_t k_block_work_num   = __builtin_amdgcn_readfirstlane(K / KPerBlock);
        const index_t ho_block_work_num  = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
        const index_t wo_block_work_num  = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
        const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
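Note: __builtin_amdgcn_readfirstlane broadcasts lane 0's value to the whole wavefront, which lets the compiler keep a value that is already wave-uniform in a scalar register (SGPR) instead of one vector register per lane; the quotients here are uniform, so this is purely a register-pressure hint. A minimal HIP sketch of the same trick (illustrative kernel, not part of CK):

#include <hip/hip_runtime.h>

__global__ void block_work_ids(int Ho, int HoPerBlock, int Wo, int WoPerBlock, int* out)
{
    // These quotients are uniform across the wavefront; readfirstlane tells the compiler so,
    // nudging the values into SGPRs instead of one VGPR copy per lane.
    const int ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
    const int wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);

    if(blockIdx.x == 0 && threadIdx.x == 0)
        *out = ho_block_work_num * wo_block_work_num;
}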
...
@@ -134,22 +132,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
-        constexpr auto a_e_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);

-        constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+        constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_e_n_ho_wo_block_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+            make_naive_tensor_descriptor_packed(make_tuple(
                Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));

        // c_thread_mtx definition: this is a mess
        // TODO:: more elegent way of defining c_thread_mtx
        constexpr auto c_k_n_ho_wo_thread_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+            make_naive_tensor_descriptor_packed(make_tuple(
                Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));

        auto blockwise_gemm =
...
@@ -184,7 +180,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
        // A matrix blockwise copy
        auto a_blockwise_copy =
-            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
+            BlockwiseTensorSliceTransfer_v4<BlockSize,
                                            InMemoryDataOperationEnum_t::Set,
                                            Sequence<E, KPerBlock>,
                                            ABlockTransferThreadSliceLengths_E_K,
...
@@ -203,18 +199,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
                                            1,
                                            AThreadTransferSrcResetCoordinateAfterRun,
                                            true>(
                a_e_k_global_desc,
                make_multi_index(0, k_block_data_on_global),
                a_e_k_desc,
                make_multi_index(0, 0));

        constexpr auto b_e_n_ho_wo_thread_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+            make_naive_tensor_descriptor_packed(make_tuple(
                Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));

-        auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<FloatAB,
-                                                                             FloatAB,
+        auto b_threadwise_transfer = ThreadwiseTensorSliceTransfer_v2<FloatAB,
+                                                                      FloatAB,
                                                                      decltype(b_e_n_ho_wo_global_desc),
                                                                      decltype(b_e_n_ho_wo_thread_desc),
...
@@ -223,7 +217,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
                                                                      BBlockTransferSrcVectorDim,
                                                                      BBlockTransferSrcScalarPerVector,
                                                                      1,
                                                                      true>(b_e_n_ho_wo_global_desc,
            make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
...
@@ -232,11 +227,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
        // register allocation for output
        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     FloatAcc,
-                     c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
+                     c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
+                     true>
            c_thread_buf;

        // initialize output thread tensor
-        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
+        ThreadwiseTensorSliceSet_v1<FloatAcc,
                                    decltype(c_k_n_ho_wo_thread_desc),
                                    Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
            .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
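Note: the ThreadwiseTensorSliceSet_v1 call above is a fully unrolled, compile-time fill of the per-thread C accumulator with FloatAcc{0} before the K loop. Functionally it amounts to the following plain C++ sketch (sizes are illustrative):

#include <array>
#include <cstddef>

// What the slice-set amounts to functionally: fill the per-thread C accumulator with a constant.
// Shape is illustrative (KPerThread x 1 x HoPerThread x WoPerThread).
template <std::size_t KPerThread, std::size_t HoPerThread, std::size_t WoPerThread>
void clear_accumulator(std::array<float, KPerThread * 1 * HoPerThread * WoPerThread>& c_thread_buf)
{
    for(auto& v : c_thread_buf)
        v = 0.0f; // FloatAcc{0}
}

int main()
{
    std::array<float, 8 * 1 * 2 * 2> acc;
    clear_accumulator<8, 2, 2>(acc);
    return acc[0] == 0.0f ? 0 : 1;
}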
...
@@ -244,32 +240,32 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
        constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);

        // hack to control index calculation when iterating over A and B matrix for threadwise copy
-        constexpr auto a_e_k_global_iterator_hacks       = AGlobalIteratorHacks{};
-        constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{};
+        constexpr auto a_e_k_global_step_hacks       = AGlobalStepHacks{};
+        constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};

        // hack to control index calculation when move slice window for A and B matrix for
        // threadwise copy
-        constexpr auto a_e_k_global_move_slice_window_iterator_hack =
-            AGlobalMoveSliceWindowIteratorHacks{};
-        constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
-            BGlobalMoveSliceWindowIteratorHacks{};
+        constexpr auto a_e_k_global_move_slice_window_step_hack =
+            AGlobalMoveSliceWindowStepHacks{};
+        constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
+            BGlobalMoveSliceWindowStepHacks{};

        // double regsiter buffer for b
        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     FloatAB,
-                     b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
+                     b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
+                     true>
            b_thread_even_buf, b_thread_odd_buf;

        // LDS double buffer: preload data
        {
-            a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);

            b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
                                      b_global_buf,
                                      b_e_n_ho_wo_thread_desc,
                                      make_tuple(I0, I0, I0, I0),
                                      b_thread_even_buf,
-                                      b_e_n_ho_wo_global_iterator_hacks);
+                                      b_e_n_ho_wo_global_step_hacks);

            a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
        }
...
@@ -293,7 +289,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
                                          b_e_n_ho_wo_thread_desc,
                                          make_tuple(I0, I0, I0, I0),
                                          b_thread_odd_buf,
-                                          b_e_n_ho_wo_global_iterator_hacks);
+                                          b_e_n_ho_wo_global_step_hacks);

                // LDS double buffer: GEMM on current data
                // TODO: @Zhang Jing: blockwise gemm should be able to move slice window
...
@@ -309,7 +305,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
                                          b_e_n_ho_wo_thread_desc,
                                          make_tuple(I0, I0, I0, I0),
                                          b_thread_even_buf,
-                                          b_e_n_ho_wo_global_iterator_hacks);
+                                          b_e_n_ho_wo_global_step_hacks);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
...
@@ -332,7 +328,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
                                      b_e_n_ho_wo_thread_desc,
                                      make_tuple(I0, I0, I0, I0),
                                      b_thread_odd_buf,
-                                      b_e_n_ho_wo_global_iterator_hacks);
+                                      b_e_n_ho_wo_global_step_hacks);

            // LDS double buffer: GEMM on 2nd-last data
            blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
...
@@ -351,13 +347,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
        // output: register to global memory
        {
            // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
-            constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
+            constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};

            const index_t k_thread_data_on_global =
                k_block_data_on_global + k_thread_id * KPerThread;

-            ThreadwiseDynamicTensorSliceTransfer_v1r3<
+            ThreadwiseTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(c_k_n_ho_wo_thread_desc),
                decltype(c_k_n_ho_wo_global_desc),
...
@@ -376,7 +371,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
                    c_thread_buf,
                    c_k_n_ho_wo_global_desc,
                    c_global_buf,
-                    c_k_n_ho_wo_global_tensor_iterator_hacks);
+                    c_k_n_ho_wo_global_tensor_step_hacks);
        }
    }
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp → composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
This diff is collapsed.
composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp
...
@@ -21,7 +21,7 @@ template <typename FloatA,
          typename TKLengths,
          typename TMLengths,
          typename TNLengths,
-          typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
-                                      BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
-                                      CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
-                                  bool>::type = false>
+          typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
+                                 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
+                                 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
+                             bool>::type = false>
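Note: this hunk swaps std::enable_if for ck's own enable_if alias; the SFINAE constraint itself is unchanged and only admits instantiations whose thread descriptors are known at compile time. A generic, self-contained illustration of the same gating pattern using std::enable_if (ck::enable_if is assumed to behave the same way):

#include <type_traits>

// Generic illustration of the same compile-time gating (std::enable_if; ck::enable_if is analogous).
struct StaticDesc  { static constexpr bool IsKnownAtCompileTime() { return true;  } };
struct DynamicDesc { static constexpr bool IsKnownAtCompileTime() { return false; } };

template <typename Desc,
          typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
constexpr int only_for_static_descriptors() { return 1; }

int main()
{
    constexpr int ok = only_for_static_descriptors<StaticDesc>(); // compiles
    // only_for_static_descriptors<DynamicDesc>();                // would be rejected by SFINAE
    return ok - 1;
}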
...
@@ -97,8 +97,7 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
                    CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
                        c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));

-                amd_inner_product_dlop<FloatA, FloatB, FloatC>(
-                    a_buf[Number<a_offset>{}],
-                    b_buf[Number<b_offset>{}],
-                    c_buf(Number<c_offset>{}));
+                inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
+                                                      b_buf[Number<b_offset>{}],
+                                                      c_buf(Number<c_offset>{}));
            });
...
@@ -124,7 +123,7 @@ template <typename FloatA,
          typename TKLengths,
          typename TMLengths,
          typename TNLengths,
-          typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
-                                      BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
-                                      CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
-                                  bool>::type = false>
+          typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
+                                 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
+                                 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
+                             bool>::type = false>
...
@@ -214,7 +213,7 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
                    CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
                        c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));

-                    amd_inner_product_dlop<a_vector_t, b_vector_t, FloatC>(
+                    inner_product<a_vector_t, b_vector_t, FloatC>(
                        a_vec.template AsType<a_vector_t>()[I0],
                        b_vec.template AsType<b_vector_t>()[I0],
                        c_buf(Number<c_offset>{}));
...
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
...
@@ -19,7 +19,7 @@ template <typename FloatA,
          typename CDesc,
          index_t H,
          index_t W,
-          typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
-                                      CDesc::IsKnownAtCompileTime(),
-                                  bool>::type = false>
+          typename enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
+                                 CDesc::IsKnownAtCompileTime(),
+                             bool>::type = false>
struct ThreadwiseGemmDlops_km_kn_mn_v3
...
@@ -57,8 +57,6 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};

        constexpr auto E = ADesc{}.GetLength(I0);
        constexpr auto K = ADesc{}.GetLength(I1);
...
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp → composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp
-#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
-#define CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
+#ifndef CK_THREADWISE_TENSOR_SET_HPP
+#define CK_THREADWISE_TENSOR_SET_HPP

 #include "common_header.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"

 namespace ck {
...
@@ -11,12 +11,12 @@ namespace ck {
 // 1. Desc is known at compile-time
 // 2. Buffer is StaticBuffer
 // 3. OriginIdx is known at compile-time
-// 4. use #-iterator
+// 4. use #-step
 template <typename Data,
           typename Desc,
           typename SliceLengths,
-          typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
-struct ThreadwiseDynamicTensorSliceSet_v1
+          typename enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
+struct ThreadwiseTensorSliceSet_v1
 {
     static constexpr index_t nDim = SliceLengths::Size();
...
@@ -40,7 +40,7 @@ struct ThreadwiseDynamicTensorSliceSet_v1
        constexpr auto origin_idx = to_multi_index(OriginIdx{});

        static_ford<SliceLengths>{}([&](auto access_idx) {
-            constexpr auto coord = make_dynamic_tensor_coordinate(desc, origin_idx + access_idx);
+            constexpr auto coord = make_tensor_coordinate(desc, origin_idx + access_idx);

            constexpr bool is_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);
...
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp → composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
This diff is collapsed.
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer_v2.hpp → composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
This diff is collapsed.
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
...
@@ -350,8 +350,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
              class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
-        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-        const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);

        return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
            p_a, p_b, reg_c);
...
@@ -384,8 +384,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
              class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
-        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-        const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);

        return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
    }
...
@@ -417,8 +417,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
              class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
-        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-        const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);

        return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
    }
...
@@ -450,8 +450,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
              class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
-        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-        const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);

        return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
    }
...
@@ -483,8 +483,8 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
              class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
-        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-        const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);

        return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
    }
...
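Note: the xdlops hunks above replace reinterpret_cast with a project-provided c_style_pointer_cast when reinterpreting the packed bf16 operand pointers; the motivation is not stated in this diff, and the helper's real definition lives elsewhere in CK. A hedged, minimal sketch of what such a helper can look like, with a hypothetical name so it is not mistaken for the CK function:

#include <hip/hip_runtime.h>

// Minimal sketch of a C-style pointer-cast helper; CK's real c_style_pointer_cast may differ.
template <typename PointerTo, typename PointerFrom>
__host__ __device__ PointerTo c_style_pointer_cast_sketch(PointerFrom p)
{
    return (PointerTo)p; // deliberately a C-style cast rather than reinterpret_cast
}

__device__ float2 load_as_float2(const float* p)
{
    // usage mirroring the mfma wrappers above, with float2 standing in for ushort2_t
    const auto p2 = c_style_pointer_cast_sketch<const float2*>(p);
    return *p2;
}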
gaoqiong @gaoqiong mentioned in commit dfb80c4e39ec7b304c3ebc88bab2a204bc4906b9 · Dec 05, 2023