Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c400e5b3
Commit
c400e5b3
authored
Jan 09, 2025
by
Adam Osewski
Browse files
Introduce static encoding pattern
parent
6fe9e964
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
265 additions
and
180 deletions
+265
-180
include/ck_tile/core.hpp
include/ck_tile/core.hpp
+2
-1
include/ck_tile/core/algorithm/static_encoding_pattern.hpp
include/ck_tile/core/algorithm/static_encoding_pattern.hpp
+178
-0
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+85
-178
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+0
-1
No files found.
include/ck_tile/core.hpp
View file @
c400e5b3
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
...
...
include/ck_tile/core/algorithm/static_encoding_pattern.hpp
0 → 100644
View file @
c400e5b3
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
namespace
ck_tile
{
/**
* @brief Enumeration describing tile distribution patterns.
*
*/
enum
struct
tile_distribution_pattern
{
/**
* @brief Thread raked pattern.
*
*/
thread_raked
,
/**
* @brief Warp raked pattern.
*
*/
warp_raked
,
/**
* @brief Block raked pattern - aka linear.
*
*/
block_raked
,
// TODO pattern taking into account MFMA attributes:
// block_fmha_pipeline_qx_ks_vs_custom_policy.hpp::51 MakeQDramTileDistribution()
};
struct
TileDistributionEcodingPattern
{
};
/**
* @brief Class creating 2D static tile distribution with different load/store patterns.
*
* @note We always assume that Tile is YPerTile x XPerTile where X dim (rightmost)
* is contiguous and we can do vector load on this dimension.
*
* @tparam BlockSize Number of threads in a workgroup.
* @tparam YPerTile The tile size of outer/leftmost dimension.
* @tparam XPerTile The tile size of inner/rightmost dimension (contiguous).
* @tparam VecSize The vector access size.
* @tparam DistributionPattern The enumeration describing used access pattern.
*/
template
<
index_t
BlockSize
,
index_t
YPerTile
,
index_t
XPerTile
,
index_t
VecSize
,
tile_distribution_pattern
DistributionPattern
>
struct
TileDistributionEncodingPattern2D
:
public
TileDistributionEcodingPattern
{
};
// Thread raked
template
<
index_t
BlockSize
,
index_t
YPerTile
,
index_t
XPerTile
,
index_t
VecSize
>
struct
TileDistributionEncodingPattern2D
<
BlockSize
,
YPerTile
,
XPerTile
,
VecSize
,
tile_distribution_pattern
::
thread_raked
>
:
public
TileDistributionEcodingPattern
{
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make2DStaticTileDistribution
()
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert
(
XPerTile
%
VecSize
==
0
,
"XPerTile must be a multiple of VecSize!"
);
constexpr
index_t
warp_size
=
get_warp_size
();
constexpr
index_t
num_warps
=
BlockSize
/
get_warp_size
();
constexpr
index_t
X1
=
VecSize
;
constexpr
index_t
X0
=
XPerTile
/
X1
;
// # of threads in X dim
// # of rows in Y dim accessed by single wavefront in one iteration
constexpr
index_t
Y1
=
warp_size
/
X0
;
static_assert
(
X0
*
Y1
==
warp_size
,
"X0 * Y1 must cover whole wavefront!"
);
constexpr
index_t
Y0
=
num_warps
;
// YPerWarp = YPerTile / Y0;
// Y2 = YPerWarp / Y1;
constexpr
index_t
Y2
=
YPerTile
/
(
Y1
*
Y0
);
// # of iters within wavefront
static_assert
(
X0
*
Y1
*
Y0
==
BlockSize
,
"X0 * warp_ys * Y0 must cover whole workgroup!"
);
static_assert
(
Y0
*
Y1
*
Y2
==
YPerTile
,
"Y0, Y1, Y2 must cover whole YPerTile"
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
Y0
,
Y1
,
Y2
>
,
sequence
<
X0
,
X1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
2
,
1
>>
{});
}
};
// Warp raked
template
<
index_t
BlockSize
,
index_t
YPerTile
,
index_t
XPerTile
,
index_t
VecSize
>
struct
TileDistributionEncodingPattern2D
<
BlockSize
,
YPerTile
,
XPerTile
,
VecSize
,
tile_distribution_pattern
::
warp_raked
>
:
public
TileDistributionEcodingPattern
{
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make2DStaticTileDistribution
()
{
static_assert
(
XPerTile
%
VecSize
==
0
,
"XPerTile must be a multiple of VecSize!"
);
constexpr
index_t
warp_size
=
get_warp_size
();
constexpr
index_t
num_warps
=
BlockSize
/
get_warp_size
();
constexpr
index_t
X1
=
VecSize
;
constexpr
index_t
X0
=
XPerTile
/
X1
;
// # of threads in X dim
constexpr
index_t
Y2
=
warp_size
/
X0
;
// # of rows in Y dim to cover whole wavefront
static_assert
(
X0
*
Y2
==
warp_size
,
"X0 * Y2 must cover whole wavefront!"
);
constexpr
index_t
Y0
=
num_warps
;
static_assert
(
X0
*
Y2
*
Y0
==
BlockSize
,
"X0 * Y2 * Y1 must cover whole workgroup!"
);
constexpr
index_t
Y1
=
YPerTile
/
(
Y2
*
Y0
);
// # of iters within wavefront
static_assert
(
Y0
*
Y1
*
Y2
==
YPerTile
,
"Y0, Y1, Y2 must cover whole YPerTile"
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
Y0
,
Y1
,
Y2
>
,
sequence
<
X0
,
X1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
};
// Block raked
template
<
index_t
BlockSize
,
index_t
YPerTile
,
index_t
XPerTile
,
index_t
VecSize
>
struct
TileDistributionEncodingPattern2D
<
BlockSize
,
YPerTile
,
XPerTile
,
VecSize
,
tile_distribution_pattern
::
block_raked
>
:
public
TileDistributionEcodingPattern
{
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make2DStaticTileDistribution
()
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert
(
XPerTile
%
VecSize
==
0
,
"XPerTile must be a multiple of VecSize!"
);
constexpr
index_t
warp_size
=
get_warp_size
();
constexpr
index_t
num_warps
=
BlockSize
/
get_warp_size
();
constexpr
index_t
X1
=
VecSize
;
constexpr
index_t
X0
=
XPerTile
/
X1
;
// # of threads in X dim
constexpr
index_t
Y2
=
warp_size
/
X0
;
// # of rows in Y dim to cover whole wavefront
static_assert
(
X0
*
Y2
==
warp_size
,
"X0 * Y2 must cover whole wavefront!"
);
constexpr
index_t
Y1
=
num_warps
;
static_assert
(
X0
*
Y2
*
Y1
==
BlockSize
,
"X0 * Y2 * Y1 must cover whole workgroup!"
);
constexpr
index_t
Y0
=
YPerTile
/
(
Y2
*
Y1
);
static_assert
(
Y0
*
Y1
*
Y2
==
YPerTile
,
"Y0, Y1, Y2 must cover whole YPerTile"
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
Y0
,
Y1
,
Y2
>
,
sequence
<
X0
,
X1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
View file @
c400e5b3
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
View file @
c400e5b3
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment