Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
4809badf
Commit
4809badf
authored
Sep 08, 2022
by
Po-Yen, Chen
Browse files
Separate template parameters
parent
9a06e83e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
45 deletions
+21
-45
include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp
include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp
+21
-45
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp
View file @
4809badf
...
@@ -15,14 +15,13 @@
...
@@ -15,14 +15,13 @@
namespace
ck
{
namespace
ck
{
namespace
detail
{
namespace
detail
{
template
<
typename
TileDims
,
typename
GridDesc
riptor
>
template
<
index_t
HPerBlock
,
index_t
WPerBlock
,
typename
GridDesc
>
struct
Block2TileMap
struct
Block2TileMap
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
NumDim
=
Number
<
GridDesc
::
GetNumOfDimension
()
>
{};
static_assert
(
2
<=
NumDim
);
static
constexpr
index_t
NumDim
=
TileDims
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static_assert
(
NumDim
==
2
);
static_assert
(
NumDim
<=
GridDescriptor
::
GetNumOfDimension
());
Block2TileMap
()
=
delete
;
Block2TileMap
()
=
delete
;
Block2TileMap
(
const
Block2TileMap
&
)
=
default
;
Block2TileMap
(
const
Block2TileMap
&
)
=
default
;
...
@@ -33,22 +32,16 @@ struct Block2TileMap
...
@@ -33,22 +32,16 @@ struct Block2TileMap
Block2TileMap
&
operator
=
(
const
Block2TileMap
&
)
=
delete
;
Block2TileMap
&
operator
=
(
const
Block2TileMap
&
)
=
delete
;
Block2TileMap
&
operator
=
(
Block2TileMap
&&
)
=
delete
;
Block2TileMap
&
operator
=
(
Block2TileMap
&&
)
=
delete
;
explicit
Block2TileMap
(
const
GridDesc
riptor
&
desc
)
:
desc_
(
desc
)
{}
explicit
Block2TileMap
(
const
GridDesc
&
desc
)
:
desc_
(
desc
)
{}
__host__
constexpr
index_t
CalculateGridSize
(
const
GridDesc
riptor
&
desc
)
const
__host__
constexpr
index_t
CalculateGridSize
(
const
GridDesc
&
desc
)
const
{
{
return
[
&
]()
{
const
auto
H0
=
math
::
integer_divide_ceil
(
desc
.
GetLength
(
NumDim
-
Number
<
2
>
{}),
HPerBlock
);
std
::
array
<
index_t
,
2
>
num_tiles_per_axis
;
const
auto
W0
=
math
::
integer_divide_ceil
(
desc
.
GetLength
(
NumDim
-
Number
<
1
>
{}),
WPerBlock
);
static_for
<
NumDim
-
2
,
NumDim
,
1
>
{}([
&
](
auto
I
)
{
num_tiles_per_axis
[
I
-
(
NumDim
-
2
)]
=
const
index_t
grid_size
=
H0
*
W0
;
math
::
integer_divide_ceil
(
desc
.
GetLength
(
I
),
TileDims
::
At
(
I
-
(
NumDim
-
2
)));
});
return
grid_size
;
return
std
::
accumulate
(
begin
(
num_tiles_per_axis
),
end
(
num_tiles_per_axis
),
index_t
{
1
},
std
::
multiplies
<
index_t
>
{});
}();
}
}
template
<
typename
TopIdx
>
template
<
typename
TopIdx
>
...
@@ -58,34 +51,17 @@ struct Block2TileMap
...
@@ -58,34 +51,17 @@ struct Block2TileMap
auto
block_1d_id
=
idx_top
[
I0
];
auto
block_1d_id
=
idx_top
[
I0
];
std
::
array
<
index_t
,
2
>
num_tiles_per_axis
;
const
auto
H0
=
math
::
integer_divide_ceil
(
desc_
.
GetLength
(
NumDim
-
Number
<
2
>
{}),
HPerBlock
);
static_for
<
NumDim
-
2
,
NumDim
,
1
>
{}([
&
](
auto
I
)
{
const
auto
W0
=
math
::
integer_divide_ceil
(
desc_
.
GetLength
(
NumDim
-
Number
<
1
>
{}),
WPerBlock
);
num_tiles_per_axis
[
I
-
(
NumDim
-
2
)]
=
math
::
integer_divide_ceil
(
desc_
.
GetLength
(
I
),
TileDims
::
At
(
I
-
(
NumDim
-
2
)));
index_t
idx_H0
=
block_1d_id
/
W0
;
});
index_t
idx_W0
=
block_1d_id
%
W0
;
std
::
array
<
index_t
,
2
>
divisors
;
return
make_tuple
(
idx_H0
,
idx_W0
);
index_t
product
=
1
;
auto
divisor
=
rbegin
(
divisors
);
for
(
auto
num_tiles
=
rbegin
(
num_tiles_per_axis
);
num_tiles
!=
rend
(
num_tiles_per_axis
);
++
num_tiles
)
{
product
*=
(
*
num_tiles
);
*
(
divisor
++
)
=
product
;
}
const
index_t
grid_size
=
divisors
.
front
();
block_1d_id
=
block_1d_id
%
grid_size
;
// swallow batch index
return
generate_tuple
(
[
&
](
auto
I
)
{
return
(
block_1d_id
%
divisors
[
I
])
/
(
divisors
[
I
]
/
num_tiles_per_axis
[
I
]);
},
Number
<
2
>
{});
}
}
private:
private:
const
GridDesc
riptor
desc_
;
const
GridDesc
desc_
;
};
};
}
// namespace detail
}
// namespace detail
...
@@ -137,7 +113,7 @@ struct GridwisePermute
...
@@ -137,7 +113,7 @@ struct GridwisePermute
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
DefaultBlock2TileMap
=
detail
::
Block2TileMap
<
Sequence
<
HPerBlock
,
WPerBlock
>
,
InGridDesc
>
;
using
DefaultBlock2TileMap
=
detail
::
Block2TileMap
<
HPerBlock
,
WPerBlock
,
InGridDesc
>
;
__host__
__device__
static
constexpr
auto
GetInBlockDescriptor
()
__host__
__device__
static
constexpr
auto
GetInBlockDescriptor
()
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment