gaoqiong / composable_kernel / Commits / 2a32ec48
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "751708509f30711ce8b40d19effd6d22659990f7"
Commit 2a32ec48, authored Sep 28, 2023 by Adam Osewski
Parent: 271269a5

Linear B2C tile map along K dim.

Showing 2 changed files with 202 additions and 2 deletions.
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp (+132 -2)
test/block_to_ctile_map/test_block_to_ctile_map.cpp (+70 -0)
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
...
@@ -697,12 +697,19 @@ struct LocalBlockToCTileMap
 {
     using underlying_type = UnderlyingBlockToCTileMap;

-    __host__ __device__ LocalBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
+    __host__ __device__ LocalBlockToCTileMap(const UnderlyingBlockToCTileMap& block_to_ctile_map,
                                              index_t local_id)
         : block_to_ctile_map_{block_to_ctile_map}, local_block_id_{local_id}
     {
     }

+    __host__ __device__ LocalBlockToCTileMap(const UnderlyingBlockToCTileMap& block_to_ctile_map)
+        : LocalBlockToCTileMap(block_to_ctile_map, 0)
+    {
+    }
+
+    __host__ __device__ void SetLocalBlockId(index_t local_id) { local_block_id_ = local_id; }
+
     __host__ __device__ constexpr auto CalculateBottomIndex() const
     {
         return block_to_ctile_map_.CalculateBottomIndex(make_multi_index(local_block_id_));
...
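Note: the new delegating constructor defaults the local id to 0, and SetLocalBlockId lets the same map be re-pointed at a different local block afterwards. A minimal sketch, assuming LocalBlockToCTileMap is templated on the underlying map type and `underlying` is a previously constructed UnderlyingBlockToCTileMap (both names illustrative, not from this commit):

    LocalBlockToCTileMap<UnderlyingBlockToCTileMap> local_map(underlying); // local_id == 0
    local_map.SetLocalBlockId(7);            // retarget the wrapper to local block 7
    const auto tile_idx = local_map.CalculateBottomIndex();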
@@ -727,7 +734,7 @@ struct LocalBlockToCTileMap
         return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
     }

-    UnderlyingBlockToCTileMap block_to_ctile_map_;
+    const UnderlyingBlockToCTileMap& block_to_ctile_map_;
     index_t local_block_id_;
 };
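Storing block_to_ctile_map_ as a const reference rather than a by-value member avoids copying the underlying map into every wrapper; the trade-off is that the referenced map must now outlive the LocalBlockToCTileMap.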
...
@@ -1133,4 +1140,127 @@ struct BlockToCTileMap_GemmStreamK
     }
 };

+/**
+ * @brief Linear workgroup mapping along fastest (reduced K) dim.
+ *
+ * @tparam MPerBlock Number of M rows per output data tile.
+ * @tparam NPerBlock Number of N columns per output data tile.
+ */
+template <index_t MPerBlock, index_t NPerBlock>
+struct BlockToCTileMap_LinearKSplit
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    __host__ __device__ BlockToCTileMap_LinearKSplit() = default;
+    __host__ __device__ BlockToCTileMap_LinearKSplit(const BlockToCTileMap_LinearKSplit&) = default;
+    __host__ __device__ BlockToCTileMap_LinearKSplit(BlockToCTileMap_LinearKSplit&&) = default;
+    __host__ __device__ BlockToCTileMap_LinearKSplit&
+    operator=(const BlockToCTileMap_LinearKSplit&) = default;
+    __host__ __device__ BlockToCTileMap_LinearKSplit&
+    operator=(BlockToCTileMap_LinearKSplit&&) = default;
+
+    __host__ __device__ BlockToCTileMap_LinearKSplit(index_t M, index_t N, index_t KSplit)
+        : M_{M}, N_{N}, KSplit_{KSplit}
+    {
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ __device__ BlockToCTileMap_LinearKSplit(const CGridDesc_M_N& c_grid_desc_m_n,
+                                                     index_t KSplit = 1)
+        : BlockToCTileMap_LinearKSplit(
+              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KSplit)
+    {
+    }
+
+    __host__ constexpr index_t CalculateGridSize(index_t M, index_t N)
+    {
+        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
+        const auto N0 = math::integer_divide_ceil(N, NPerBlock);
+
+        return M0 * N0 * KSplit_;
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
+    {
+        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
+    }
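For example, with the shape used by the tests below (M = 768, N = 384, MPerBlock = 128, NPerBlock = 64, KSplit = 3): M0 = ceil(768 / 128) = 6 and N0 = ceil(384 / 64) = 6, so CalculateGridSize returns 6 * 6 * 3 = 108 workgroups.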
+
+    template <typename CGridDesc_M_N>
+    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
+    {
+        return true;
+    }
+
+    __host__ constexpr auto CalculateBottomIndex(index_t block_1d_id)
+    {
+        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
+
+        M0_idx_     = block_1d_id / (N0 * KSplit_);
+        block_1d_id = block_1d_id % (N0 * KSplit_);
+        N0_idx_     = block_1d_id / KSplit_;
+        K0_idx_     = block_1d_id % KSplit_;
+
+        return make_tuple(M0_idx_, N0_idx_, K0_idx_);
+    }
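K is the fastest-varying coordinate here: consecutive block IDs first sweep the KSplit partial tiles of one (M0, N0) output tile, then advance along N, and finally along M. For instance, with N0 = 6 and KSplit = 3, block_1d_id = 3 decomposes as M0_idx = 3 / 18 = 0, remainder 3, N0_idx = 3 / 3 = 1, K0_idx = 3 % 3 = 0, i.e. tile (0, 1, 0), which is exactly what the first test below asserts.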
+
+    __device__ constexpr auto CalculateBottomIndex(index_t block_1d_id)
+    {
+        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
+
+        M0_idx_     = __builtin_amdgcn_readfirstlane(block_1d_id / (N0 * KSplit_));
+        block_1d_id = block_1d_id % (N0 * KSplit_);
+        N0_idx_     = __builtin_amdgcn_readfirstlane(block_1d_id / KSplit_);
+        K0_idx_     = __builtin_amdgcn_readfirstlane(block_1d_id % KSplit_);
+
+        return make_tuple(M0_idx_, N0_idx_, K0_idx_);
+    }
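The device overload is identical except that each index passes through __builtin_amdgcn_readfirstlane, which reads the value from the first active lane of the wavefront. Since the block ID is uniform across a workgroup this does not change the result; it marks the indices as wavefront-uniform so the compiler can keep them in scalar registers.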
+
+    __host__ __device__ auto GetBottomIndex() const { return make_tuple(M0_idx_, N0_idx_, K0_idx_); }
+
+    template <typename CTileIdx, typename CTileDim>
+    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
+                                             const CTileDim& /* c_tile_dim */) const
+    {
+        return true; // always valid provided that user gets grid size from CalculateGridSize()
+    }
+
+    __host__ __device__ bool GetNextKTileIdx()
+    {
+        K0_idx_++;
+        return K0_idx_ < KSplit_;
+    }
+
+    ///
+    /// @brief Determines whether the current workgroup has processed the first tile in the K
+    ///        dimension.
+    ///
+    /// @param[in] tiles_per_block The number of K tiles each workgroup processes.
+    ///
+    /// @return True if the current workgroup has processed the first tile, false otherwise.
+    ///
+    __host__ __device__ bool IsFirstKSplitBlock(index_t tiles_per_block) const
+    {
+        return (K0_idx_ - tiles_per_block) <= 0;
+    }
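A plausible kernel-side consumption pattern (an illustrative sketch under assumptions, not code from this commit; `tile_map` and the mainloop placeholder are hypothetical):

    // Decode this workgroup's starting (M0, N0, K0) tile, then walk the
    // remaining K tiles; GetNextKTileIdx() returns false once K0 reaches KSplit.
    tile_map.CalculateBottomIndex(get_block_1d_id());
    do
    {
        // ... compute the partial GEMM result for K tile tile_map.GetTileKIdx() ...
    } while(tile_map.GetNextKTileIdx());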
+
+    __host__ __device__ index_t GetTileMIdx() const { return M0_idx_; }
+    __host__ __device__ index_t GetTileNIdx() const { return N0_idx_; }
+    __host__ __device__ index_t GetTileKIdx() const { return K0_idx_; }
+
+    private:
+    index_t M_;
+    index_t N_;
+    index_t KSplit_;
+
+    index_t M0_idx_;
+    index_t N0_idx_;
+    index_t K0_idx_;
+};
+
 } // namespace ck
test/block_to_ctile_map/test_block_to_ctile_map.cpp
...
@@ -320,3 +320,73 @@ TEST(BlockToCTileMap, TestBlockToCTileMap_KSplit_M00_N0_M01Adapt)
         EXPECT_TRUE(equal);
     }
 }
+
+TEST(BlockToCTileMap, BlockToCTileMap_LinearKSplit_BottomIndex)
+{
+    const index_t M         = 768;
+    const index_t N         = 384;
+    const index_t MPerBlock = 128;
+    const index_t NPerBlock = 64;
+    const index_t MBlock    = M / MPerBlock;
+    const index_t NBlock    = N / NPerBlock;
+    const index_t KSplit    = 3;
+
+    auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));
+
+    BlockToCTileMap_LinearKSplit<MPerBlock, NPerBlock> tile_map(c_grid_desc_m_n, KSplit);
+
+    EXPECT_EQ(tile_map.CalculateGridSize(c_grid_desc_m_n), MBlock * NBlock * KSplit);
+
+    std::vector<std::vector<int>> expected_m0idx_n0idx_k0idx = {
+        {0, 0, 0}, {0, 0, 1}, {0, 0, 2}, {0, 1, 0}, {0, 1, 1}, {0, 1, 2},
+        {0, 2, 0}, {0, 2, 1}, {0, 2, 2}, {0, 3, 0}, {0, 3, 1}, {0, 3, 2},
+        {0, 4, 0}, {0, 4, 1}, {0, 4, 2}, {0, 5, 0}, {0, 5, 1}, {0, 5, 2},
+        {1, 0, 0}, {1, 0, 1}, {1, 0, 2}, {1, 1, 0}, {1, 1, 1}, {1, 1, 2},
+        {1, 2, 0}, {1, 2, 1}, {1, 2, 2}, {1, 3, 0}, {1, 3, 1}, {1, 3, 2},
+        {1, 4, 0}, {1, 4, 1}, {1, 4, 2}, {1, 5, 0}, {1, 5, 1}, {1, 5, 2},
+        {2, 0, 0}, {2, 0, 1}, {2, 0, 2}, {2, 1, 0}, {2, 1, 1}, {2, 1, 2},
+        {2, 2, 0}, {2, 2, 1}, {2, 2, 2}, {2, 3, 0}, {2, 3, 1}, {2, 3, 2},
+        {2, 4, 0}, {2, 4, 1}, {2, 4, 2}, {2, 5, 0}, {2, 5, 1}, {2, 5, 2},
+        {3, 0, 0}, {3, 0, 1}, {3, 0, 2}, {3, 1, 0}, {3, 1, 1}, {3, 1, 2},
+        {3, 2, 0}, {3, 2, 1}, {3, 2, 2}, {3, 3, 0}, {3, 3, 1}, {3, 3, 2},
+        {3, 4, 0}, {3, 4, 1}, {3, 4, 2}, {3, 5, 0}, {3, 5, 1}, {3, 5, 2},
+        {4, 0, 0}, {4, 0, 1}, {4, 0, 2}, {4, 1, 0}, {4, 1, 1}, {4, 1, 2},
+        {4, 2, 0}, {4, 2, 1}, {4, 2, 2}, {4, 3, 0}, {4, 3, 1}, {4, 3, 2},
+        {4, 4, 0}, {4, 4, 1}, {4, 4, 2}, {4, 5, 0}, {4, 5, 1}, {4, 5, 2},
+        {5, 0, 0}, {5, 0, 1}, {5, 0, 2}, {5, 1, 0}, {5, 1, 1}, {5, 1, 2},
+        {5, 2, 0}, {5, 2, 1}, {5, 2, 2}, {5, 3, 0}, {5, 3, 1}, {5, 3, 2},
+        {5, 4, 0}, {5, 4, 1}, {5, 4, 2}, {5, 5, 0}, {5, 5, 1}, {5, 5, 2}};
+
+    for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); ++i)
+    {
+        auto m0n0k0_idx = tile_map.CalculateBottomIndex(i);
+
+        EXPECT_EQ((std::vector<int>{m0n0k0_idx[I0], m0n0k0_idx[I1], m0n0k0_idx[I2]}),
+                  expected_m0idx_n0idx_k0idx[i]);
+    }
+}
+
+TEST(BlockToCTileMap, BlockToCTileMap_LinearKSplit_NextKTile)
+{
+    const index_t M         = 768;
+    const index_t N         = 384;
+    const index_t MPerBlock = 128;
+    const index_t NPerBlock = 64;
+    const index_t KSplit    = 3;
+
+    auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));
+
+    BlockToCTileMap_LinearKSplit<MPerBlock, NPerBlock> tile_map(c_grid_desc_m_n, KSplit);
+
+    auto m0n0k0_idx = tile_map.CalculateBottomIndex(3);
+    EXPECT_EQ((std::vector<int>{m0n0k0_idx[I0], m0n0k0_idx[I1], m0n0k0_idx[I2]}),
+              (std::vector<int>{0, 1, 0}));
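The loop below then walks the remaining K tiles: GetNextKTileIdx() succeeds for K0 = 1 and 2, and fails once K0 reaches KSplit, at which point GetBottomIndex() reports the one-past-the-end index (0, 1, 3).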
+
+    for(index_t i = 0; i < KSplit - 1; i++)
+    {
+        EXPECT_TRUE(tile_map.GetNextKTileIdx());
+        m0n0k0_idx = tile_map.GetBottomIndex();
+        EXPECT_EQ((std::vector<int>{m0n0k0_idx[I0], m0n0k0_idx[I1], m0n0k0_idx[I2]}),
+                  (std::vector<int>{0, 1, i + 1}));
+    }
+
+    EXPECT_FALSE(tile_map.GetNextKTileIdx());
+    m0n0k0_idx = tile_map.GetBottomIndex();
+    EXPECT_EQ((std::vector<int>{m0n0k0_idx[I0], m0n0k0_idx[I1], m0n0k0_idx[I2]}),
+              (std::vector<int>{0, 1, 3}));
+}