Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ed794598
Commit
ed794598
authored
Sep 07, 2022
by
Po-Yen, Chen
Browse files
Add 'BlockToTileMap' for 'GridwiseCopy'
parent
52c99c1a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
57 additions
and
8 deletions
+57
-8
include/ck/tensor_operation/gpu/grid/gridwise_copy.hpp
include/ck/tensor_operation/gpu/grid/gridwise_copy.hpp
+57
-8
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_copy.hpp
View file @
ed794598
...
@@ -3,6 +3,10 @@
...
@@ -3,6 +3,10 @@
#pragma once
#pragma once
#include <functional>
#include <numeric>
#include <iterator>
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
...
@@ -10,6 +14,48 @@
...
@@ -10,6 +14,48 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
namespace
ck
{
namespace
detail
{
/// Maps a linear (1-D) block id onto a multi-dimensional tile index.
///
/// The grid described by `GridDescriptor` is cut into tiles whose per-dimension
/// extents are given by `TileDims`. The number of tiles along each axis is the
/// axis length divided by the tile extent, rounded up. Tiles are enumerated with
/// the last dimension varying fastest.
///
/// @tparam TileDims       compile-time sequence of tile extents; must provide
///                        Size() and At(I)
/// @tparam GridDescriptor descriptor type; must provide GetNumOfDimension() and
///                        GetLength(I), with as many dimensions as `TileDims`
template <typename TileDims, typename GridDescriptor>
struct BlockToTileMap
{
    static constexpr auto I0 = Number<0>{};

    static constexpr index_t NumDim = TileDims::Size();
    static_assert(NumDim == GridDescriptor::GetNumOfDimension());

    // Static-only utility: construction and destruction are both disabled.
    BlockToTileMap()  = delete;
    ~BlockToTileMap() = delete;

    /// Converts a 1-D block id into a tuple of `NumDim` tile indices.
    ///
    /// @param desc    grid descriptor whose per-dimension lengths are tiled
    /// @param idx_top single-element index holding the linear block id
    /// @return tuple with one tile index per dimension (built by generate_tuple)
    ///
    /// Precondition: every axis must contain at least one tile, otherwise the
    /// modulo/division below divides by zero.
    template <typename TopIdx>
    __host__ __device__ static constexpr auto CalculateBottomIndex(const GridDescriptor& desc,
                                                                   const TopIdx& idx_top)
    {
        static_assert(TopIdx::Size() == 1);

        using TileCounts = std::array<index_t, NumDim>;

        auto block_1d_id = idx_top[I0];

        // Number of tiles along each axis (partial tiles count as a full tile).
        TileCounts num_tiles_per_axis{};
        static_for<0, NumDim, 1>{}([&](auto I) {
            num_tiles_per_axis[I] = math::integer_divide_ceil(desc.GetLength(I), TileDims::At(I));
        });

        // divisors[d] = num_tiles_per_axis[d] * ... * num_tiles_per_axis[NumDim - 1].
        // Written as an explicit loop instead of std::partial_sum: that algorithm is a
        // plain host library function and is only constexpr from C++20 on, so it is
        // not usable from this __host__ __device__ constexpr function.
        TileCounts divisors{};
        index_t running_product = 1;
        for(typename TileCounts::size_type remaining = divisors.size(); remaining > 0; --remaining)
        {
            const auto dim = remaining - 1;
            running_product *= num_tiles_per_axis[dim];
            divisors[dim] = running_product;
        }

        // Total number of tiles in the grid.
        const index_t grid_size = divisors.front();

        // swallow batch index: block ids beyond one full grid wrap around
        block_1d_id = block_1d_id % grid_size;

        return generate_tuple(
            [&](auto I) {
                // divisors[I] / num_tiles_per_axis[I] is the number of tiles spanned by
                // all dimensions after I, i.e. the linear stride of dimension I.
                return (block_1d_id % divisors[I]) / (divisors[I] / num_tiles_per_axis[I]);
            },
            Number<NumDim>{});
    }
};
}
// namespace detail
template
<
typename
GridwiseCopyFunctor
,
template
<
typename
GridwiseCopyFunctor
,
typename
InGrid1dDesc
,
typename
InGrid1dDesc
,
...
@@ -54,6 +100,9 @@ struct GridwiseCopy
...
@@ -54,6 +100,9 @@ struct GridwiseCopy
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
BlockToTileMap
=
detail
::
BlockToTileMap
<
Sequence
<
NPerBlock
,
HPerBlock
,
WPerBlock
>
,
InGrid1dDesc
>
;
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
{
constexpr
index_t
ABlockLdsExtraM
=
0
;
constexpr
index_t
ABlockLdsExtraM
=
0
;
...
@@ -94,7 +143,7 @@ struct GridwiseCopy
...
@@ -94,7 +143,7 @@ struct GridwiseCopy
const
index_t
loop_step
=
blockPerGrid
*
blockSize
*
MPerThread
;
const
index_t
loop_step
=
blockPerGrid
*
blockSize
*
MPerThread
;
const
auto
loop_step_index
=
make_multi_index
(
loop_step
);
const
auto
loop_step_index
=
make_multi_index
(
loop_step
);
#if
1
#if
0
auto in_global_load =
auto in_global_load =
ThreadwiseTensorSliceTransfer_v2<InDataType,
ThreadwiseTensorSliceTransfer_v2<InDataType,
InDataType,
InDataType,
...
@@ -107,8 +156,8 @@ struct GridwiseCopy
...
@@ -107,8 +156,8 @@ struct GridwiseCopy
1, // SrcScalarStrideInVector
1, // SrcScalarStrideInVector
false>{in_grid_1d_desc, thread_global_offset};
false>{in_grid_1d_desc, thread_global_offset};
#else
#else
//
const auto block_work_idx =
const
auto
block_work_idx
=
BlockToTileMap
::
CalculateBottomIndex
(
//
block_2_etile_map.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
in_grid_1d_desc
,
make_multi_index
(
get_block_1d_id
()));
// HACK: this forces m/n_block_data_idx_on_grid into SGPR
// HACK: this forces m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
const
index_t
m_block_data_idx_on_grid
=
...
@@ -126,7 +175,7 @@ struct GridwiseCopy
...
@@ -126,7 +175,7 @@ struct GridwiseCopy
using
SliceLengths
=
Sequence
<
NPerBlock
,
HPerBlock
,
WPerBlock
>
;
using
SliceLengths
=
Sequence
<
NPerBlock
,
HPerBlock
,
WPerBlock
>
;
using
ABlockTransferThreadClusterLengths_AK0_M_AK1
=
Sequence
<
4
,
64
,
1
>
;
using
ABlockTransferThreadClusterLengths_AK0_M_AK1
=
Sequence
<
4
,
64
,
1
>
;
using
ABlockTransferThreadClusterArrangeOrder
=
Sequence
<
1
,
0
,
2
>
;
using
ABlockTransferThreadClusterArrangeOrder
=
Sequence
<
1
,
0
,
2
>
;
using
ABlockTransferSrcAccessOrder
=
int
;
using
ABlockTransferSrcAccessOrder
=
Sequence
<
1
,
0
,
2
>
;
constexpr
index_t
ABlockTransferSrcVectorDim
=
2
;
constexpr
index_t
ABlockTransferSrcVectorDim
=
2
;
constexpr
index_t
ABlockTransferSrcScalarPerVector
=
1
;
constexpr
index_t
ABlockTransferSrcScalarPerVector
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector
=
1
;
...
@@ -155,7 +204,7 @@ struct GridwiseCopy
...
@@ -155,7 +204,7 @@ struct GridwiseCopy
true
>
(
true
>
(
in_grid_1d_desc
,
in_grid_1d_desc
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
element_op
,
element
wise
_op
,
a_block_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
make_multi_index
(
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
...
@@ -166,8 +215,8 @@ struct GridwiseCopy
...
@@ -166,8 +215,8 @@ struct GridwiseCopy
decltype
(
thread_buffer_desc_m
),
decltype
(
thread_buffer_desc_m
),
decltype
(
out_grid_1d_desc
),
decltype
(
out_grid_1d_desc
),
PassThroughOp
,
PassThroughOp
,
S
equence
<
MPerThread
>
,
// SliceLengths
S
liceLengths
,
// SliceLengths
Sequence
<
0
>
,
// DimAccessOrder
Sequence
<
1
,
0
,
2
>
,
// DimAccessOrder
0
,
// SrcVectorDim
0
,
// SrcVectorDim
OutScalarPerVector
,
OutScalarPerVector
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment