gaoqiong / composable_kernel_ROCM / Commits

Commit c400e5b3, authored Jan 09, 2025 by Adam Osewski

Introduce static encoding pattern

Parent: 6fe9e964
Showing 4 changed files with 265 additions and 180 deletions (+265, -180)
include/ck_tile/core.hpp  (+2, -1)
include/ck_tile/core/algorithm/static_encoding_pattern.hpp  (+178, -0)
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp  (+85, -178)
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp  (+0, -1)
include/ck_tile/core.hpp  (view file @ c400e5b3)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -7,6 +7,7 @@
 #include "ck_tile/core/algorithm/coordinate_transform.hpp"
 #include "ck_tile/core/algorithm/indexing_adaptor.hpp"
 #include "ck_tile/core/algorithm/space_filling_curve.hpp"
+#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
 #include "ck_tile/core/arch/amd_buffer_addressing.hpp"
 #include "ck_tile/core/arch/arch.hpp"
 #include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
 ...
include/ck_tile/core/algorithm/static_encoding_pattern.hpp  (new file, 0 → 100644, view file @ c400e5b3)

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"

namespace ck_tile {

/**
 * @brief Enumeration describing tile distribution patterns.
 */
enum struct tile_distribution_pattern
{
    /**
     * @brief Thread-raked pattern.
     */
    thread_raked,
    /**
     * @brief Warp-raked pattern.
     */
    warp_raked,
    /**
     * @brief Block-raked pattern, aka linear.
     */
    block_raked,
    // TODO: pattern taking into account MFMA attributes:
    // block_fmha_pipeline_qx_ks_vs_custom_policy.hpp::51 MakeQDramTileDistribution()
};

struct TileDistributionEcodingPattern
{
};

/**
 * @brief Class creating a 2D static tile distribution with different load/store patterns.
 *
 * @note We always assume that the tile is YPerTile x XPerTile, where the X dim (rightmost)
 *       is contiguous and we can do vector loads on this dimension.
 *
 * @tparam BlockSize           Number of threads in a workgroup.
 * @tparam YPerTile            The tile size of the outer/leftmost dimension.
 * @tparam XPerTile            The tile size of the inner/rightmost (contiguous) dimension.
 * @tparam VecSize             The vector access size.
 * @tparam DistributionPattern The enumeration describing the used access pattern.
 */
template <index_t BlockSize,
          index_t YPerTile,
          index_t XPerTile,
          index_t VecSize,
          tile_distribution_pattern DistributionPattern>
struct TileDistributionEncodingPattern2D : public TileDistributionEcodingPattern
{
};
// Thread raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
                                         YPerTile,
                                         XPerTile,
                                         VecSize,
                                         tile_distribution_pattern::thread_raked>
    : public TileDistributionEcodingPattern
{
    CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
    {
        // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
        static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");

        constexpr index_t warp_size = get_warp_size();
        constexpr index_t num_warps = BlockSize / get_warp_size();
        constexpr index_t X1        = VecSize;
        constexpr index_t X0        = XPerTile / X1; // # of threads in X dim
        // # of rows in Y dim accessed by a single wavefront in one iteration
        constexpr index_t Y1 = warp_size / X0;
        static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
        constexpr index_t Y0 = num_warps;
        // YPerWarp = YPerTile / Y0;
        // Y2 = YPerWarp / Y1;
        constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
        static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * Y1 * Y0 must cover whole workgroup!");
        static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<0>, sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<2, 1>>{});
    }
};
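To make the decomposition concrete, here is a worked instantiation with hypothetical sizes (not taken from the commit), assuming AMD's 64-lane wavefront:

```cpp
#include "ck_tile/core.hpp" // pulls in static_encoding_pattern.hpp after this commit
using namespace ck_tile;

// Hypothetical sizes: 256 threads, 128 x 32 tile, 8-element vector loads.
// X1 = VecSize               = 8
// X0 = XPerTile / X1         = 32 / 8       = 4   threads along X
// Y1 = warp_size / X0        = 64 / 4       = 16  rows per wavefront iteration
// Y0 = BlockSize / warp_size = 256 / 64     = 4   warps
// Y2 = YPerTile / (Y1 * Y0)  = 128 / (16*4) = 2   iterations per thread
using ThreadRaked = TileDistributionEncodingPattern2D<256, 128, 32, 8,
                                                      tile_distribution_pattern::thread_raked>;
constexpr auto thread_raked_dist = ThreadRaked::Make2DStaticTileDistribution();
```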
// Warp raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
                                         YPerTile,
                                         XPerTile,
                                         VecSize,
                                         tile_distribution_pattern::warp_raked>
    : public TileDistributionEcodingPattern
{
    CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
    {
        static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");

        constexpr index_t warp_size = get_warp_size();
        constexpr index_t num_warps = BlockSize / get_warp_size();
        constexpr index_t X1        = VecSize;
        constexpr index_t X0        = XPerTile / X1; // # of threads in X dim
        constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
        static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
        constexpr index_t Y0 = num_warps;
        static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y0 must cover whole workgroup!");
        constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
        static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<0>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<1, 1>>{});
    }
};
// Block raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
                                         YPerTile,
                                         XPerTile,
                                         VecSize,
                                         tile_distribution_pattern::block_raked>
    : public TileDistributionEcodingPattern
{
    CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
    {
        // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
        static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");

        constexpr index_t warp_size = get_warp_size();
        constexpr index_t num_warps = BlockSize / get_warp_size();
        constexpr index_t X1        = VecSize;
        constexpr index_t X0        = XPerTile / X1; // # of threads in X dim
        constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
        static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
        constexpr index_t Y1 = num_warps;
        static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
        constexpr index_t Y0 = YPerTile / (Y2 * Y1);
        static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }
};

} // namespace ck_tile
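The three specializations share the same Y0/Y1/Y2 x X0/X1 splits; they differ in which Y component carries the per-thread iteration count (Y2 for thread_raked, Y1 for warp_raked, and, by analogy, Y0 for block_raked) and in how the P dimensions map onto those splits. A minimal side-by-side sketch with hypothetical sizes (64-lane wavefront assumed):

```cpp
#include "ck_tile/core.hpp"
using namespace ck_tile;

// Hypothetical parameters, for illustration only: all three constraints hold
// (X0 = 8, and the Y splits are 4/8/4, 4/4/8, 4/4/8 respectively).
constexpr index_t kBlockSize = 256;
constexpr index_t kYPerTile  = 128;
constexpr index_t kXPerTile  = 64;
constexpr index_t kVecSize   = 8;

template <tile_distribution_pattern P>
using Pattern =
    TileDistributionEncodingPattern2D<kBlockSize, kYPerTile, kXPerTile, kVecSize, P>;

// Same tile, three access orders; only the iteration dimension moves.
constexpr auto thread_raked_d = Pattern<tile_distribution_pattern::thread_raked>::Make2DStaticTileDistribution();
constexpr auto warp_raked_d   = Pattern<tile_distribution_pattern::warp_raked>::Make2DStaticTileDistribution();
constexpr auto block_raked_d  = Pattern<tile_distribution_pattern::block_raked>::Make2DStaticTileDistribution();
```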
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp  (view file @ c400e5b3)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -17,15 +17,23 @@ struct UniversalGemmPipelineAgBgCrPolicy
     static constexpr auto I1 = number<1>{};
     static constexpr auto I2 = number<2>{};

+    /**
+     * @brief Get the maximum global memory vector load size.
+     *
+     * @tparam Problem    The UniversalGemmPipelineProblem object.
+     * @tparam DataType   The tensor data type we're considering.
+     * @tparam MNPerBlock The MPerBlock or NPerBlock value, depending on tensor (A/B).
+     * @return Maximum DRAM vector load size.
+     */
     template <typename Problem, typename DataType, index_t MNPerBlock>
-    CK_TILE_HOST_DEVICE static constexpr auto GetVectorLoadSize()
+    CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
     {
+        // TODO: this does not take into account the size of the contiguous dim!
         constexpr index_t BlockSize = Problem::kBlockSize;
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
         constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
         // Assume DataType is even!
         if constexpr(elements_per_thread % (16 / sizeof(DataType)) == 0)
         {
             return (16 / sizeof(DataType));
 ...
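To make the arithmetic concrete, a worked example with hypothetical sizes and a 2-byte data type (fp16-like), not taken from the commit:

```cpp
#include <cstdint>

using half_t = std::uint16_t; // stand-in for a 2-byte fp16 type

constexpr int BlockSize  = 256;
constexpr int MNPerBlock = 128;
constexpr int KPerBlock  = 32;

constexpr int elements_per_thread = MNPerBlock * KPerBlock / BlockSize; // 16
static_assert(elements_per_thread % (16 / sizeof(half_t)) == 0);        // 16 % 8 == 0
// => GetGlobalVectorLoadSize() would return 16 / sizeof(half_t) == 8 elements,
//    i.e. one 128-bit (dwordx4) load per thread per access.
```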
@@ -56,7 +64,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
         using ADataType = remove_cvref_t<typename Problem::ADataType>;
         constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;

-        return GetVectorLoadSize<Problem, ADataType, MPerBlock>();
+        return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock>();
     }

     template <typename Problem>
 ...
@@ -65,7 +73,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
         using BDataType = remove_cvref_t<typename Problem::BDataType>;
         constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;

-        return GetVectorLoadSize<Problem, BDataType, NPerBlock>();
+        return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock>();
     }
     /**
 ...
@@ -90,9 +98,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
         using CLayout   = typename Problem::CLayout;
         using CWarpDstr = typename WG::CWarpDstr;

-        // constexpr auto c_warp_x_lengths = CWarpDstr::get_lengths();
-        // using c_warp_hs_lengths = typename CWarpDstrEncoding::HsLengthss;
         // N is contiguous dimension
         if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
         {
 ...
@@ -100,11 +105,11 @@ struct UniversalGemmPipelineAgBgCrPolicy
             {
                 // In this case each thread has multiple consecutive elements in
                 // N dimension, however consecutive threads' elements have stride.
-                // static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
-                //               c_warp_y_lengths.get(number<NDimY-1>{}));
                 constexpr index_t NDimY = CWarpDstr::NDimY;
                 constexpr auto c_warp_y_lengths =
                     CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
+                static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
+                              c_warp_y_lengths.get(number<NDimY - 1>{}));

                 return c_warp_y_lengths.get(number<NDimY - 1>{});
             }
             else
 ...
@@ -125,11 +130,11 @@ struct UniversalGemmPipelineAgBgCrPolicy
             {
                 // In this case each thread has multiple consecutive elements in
                 // M dimension, however consecutive threads' elements have stride.
-                // static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
-                //               c_warp_y_lengths.get(number<NDimY-1>{}));
                 constexpr index_t NDimY = CWarpDstr::NDimY;
                 constexpr auto c_warp_y_lengths =
                     CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
+                static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
+                              c_warp_y_lengths.get(number<NDimY - 1>{}));

                 return c_warp_y_lengths.get(number<NDimY - 1>{});
             }
         }
 ...
@@ -139,6 +144,26 @@ struct UniversalGemmPipelineAgBgCrPolicy
         }
     }

+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
+    {
+        using ADataType = remove_cvref_t<typename Problem::ADataType>;
+        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
+        // TODO: this does not always have to be true; sometimes we may want a different KPack value.
+        return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock>();
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
+    {
+        using BDataType = remove_cvref_t<typename Problem::BDataType>;
+        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
+        // TODO: this does not always have to be true; sometimes we may want a different KPack value.
+        return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock>();
+    }
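For now both helpers simply forward to GetGlobalVectorLoadSize, as the TODOs note, but they give the LDS descriptors below a dedicated seam for the shared-memory K-pack. A minimal sketch of how a custom policy might diverge (hypothetical derived policy, assuming the pipeline dispatches these helpers through its Policy template parameter, as ck_tile pipelines do):

```cpp
// Hypothetical override: pin A's LDS K-pack instead of deriving it from the
// DRAM vector load size. Because callers invoke Policy::GetSmemPackA<Problem>(),
// the derived policy's static function shadows the base one.
struct MyGemmPolicy : ck_tile::UniversalGemmPipelineAgBgCrPolicy
{
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
    {
        return ck_tile::index_t{4}; // fixed K-pack of 4 elements (illustrative)
    }
};
```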
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
     {
 ...
@@ -147,7 +172,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
         constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        constexpr index_t KPack     = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
+        constexpr index_t KPack     = GetSmemPackA<Problem>();

         constexpr auto DataTypeSize = sizeof(ADataType);
         constexpr auto MLdsLayer =
 ...
@@ -198,7 +223,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
         constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        constexpr index_t KPack     = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
+        constexpr index_t KPack     = GetSmemPackB<Problem>();

         constexpr auto DataTypeSize = sizeof(BDataType);
         constexpr auto NLdsLayer =
 ...
@@ -237,6 +262,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
                     make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
             make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
             make_tuple(sequence<0>{}, sequence<1>{}));
+        // make_tuple(sequence<1>{}, sequence<0>{}));

         return b_lds_block_desc;
     }
 ...
@@ -276,88 +302,30 @@ struct UniversalGemmPipelineAgBgCrPolicy
         constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

-        if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
+        // Tile: MPerBlock X KPerBlock
+        if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
         {
-            constexpr index_t M1 = GetVectorSizeA<Problem>();
-            constexpr index_t M0 = MPerBlock / M1;
-            constexpr index_t elem_per_thr = MPerBlock * KPerBlock / BlockSize;
-            constexpr index_t K3 = elem_per_thr / M1; // # of loads per thr
-            constexpr index_t KPack = GetVectorSizeA<Problem>();
-            static_assert(KPack % K3 == 0);
-            constexpr index_t K2 = KPack / K3;
-            if constexpr(get_warp_size() % (K2 * M0) == 0)
-            {
-                constexpr index_t K1 = get_warp_size() / (K2 * M0);
-                constexpr index_t K0 = BlockSize / get_warp_size();
-                static_assert(KPerBlock == K0 * K1 * K2 * K3);
-                return make_static_tile_distribution(
-                    tile_distribution_encoding<
-                        sequence<1>,
-                        tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
-                        tuple<sequence<2>, sequence<2, 1, 2>>,
-                        tuple<sequence<0>, sequence<1, 0, 2>>,
-                        sequence<2, 1>,
-                        sequence<3, 1>>{});
-            }
-            else
-            {
-                constexpr index_t K1   = (K2 * M0) / get_warp_size();
-                constexpr index_t K2_m = K2 / K1;
-                constexpr index_t K0   = BlockSize / get_warp_size() / K1;
-                static_assert(KPerBlock == K0 * K1 * K2_m * K3);
-                return make_static_tile_distribution(
-                    tile_distribution_encoding<
-                        sequence<1>,
-                        tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
-                        tuple<sequence<2, 2>, sequence<1, 2>>,
-                        tuple<sequence<0, 1>, sequence<0, 2>>,
-                        sequence<2, 1>,
-                        sequence<3, 1>>{});
-            }
+            // We should take layout into account!
+            constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
+            constexpr auto AccessPattern  = tile_distribution_pattern::thread_raked;
+            using TileEncodingPattern     = TileDistributionEncodingPattern2D<BlockSize,
+                                                                              MPerBlock,
+                                                                              KPerBlock,
+                                                                              VecLoadSize,
+                                                                              AccessPattern>;
+            return TileEncodingPattern::Make2DStaticTileDistribution();
         }
+        // Tile: KPerBlock X MPerBlock
         else
         {
-            // In RowMajor scenario we usually want to read whole KPerBlock tile dim
-            constexpr index_t K1 = GetVectorSizeA<Problem>();
-            constexpr index_t K0 = KPerBlock / K1;
-            constexpr index_t M2 = get_warp_size() / K0;
-            static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
-            // Coalesce reading for whole workgroup - workgroup raked pattern
-            if constexpr(get_warp_size() % (M2 * K0) == 0)
-            {
-                constexpr index_t M1 = BlockSize / get_warp_size();
-                static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
-                constexpr index_t M0 = MPerBlock / (M2 * M1);
-                static_assert(M0 * M1 * M2 == MPerBlock,
-                              "Incorrect M0, M2, M1 configuration! "
-                              "M0, M1, M2 must cover whole MPerBlock!");
-                return make_static_tile_distribution(
-                    tile_distribution_encoding<
-                        sequence<1>,
-                        tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
-                        tuple<sequence<1>, sequence<1, 2>>,
-                        tuple<sequence<1>, sequence<2, 0>>,
-                        sequence<1, 2>,
-                        sequence<0, 1>>{});
-            }
-            // Coalesce reading for each wavefront - wavefront raked pattern
-            else
-            {
-                constexpr index_t M0 = BlockSize / get_warp_size();
-                constexpr index_t M1 = MPerBlock / (M2 * M0);
-                static_assert(M0 * M1 * M2 == MPerBlock,
-                              "Incorrect M0, M2, M1 configuration! "
-                              "M0, M1, M2 must cover whole MPerBlock!");
-                return make_static_tile_distribution(
-                    tile_distribution_encoding<
-                        sequence<1>,
-                        tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
-                        tuple<sequence<1>, sequence<1, 2>>,
-                        tuple<sequence<0>, sequence<2, 0>>,
-                        sequence<1, 2>,
-                        sequence<1, 1>>{});
-            }
+            constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
+            constexpr auto AccessPattern  = tile_distribution_pattern::thread_raked;
+            using TileEncodingPattern     = TileDistributionEncodingPattern2D<BlockSize,
+                                                                              KPerBlock,
+                                                                              MPerBlock,
+                                                                              VecLoadSize,
+                                                                              AccessPattern>;
+            return TileEncodingPattern::Make2DStaticTileDistribution();
         }
     }
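With the pattern in place, each layout branch of MakeADramTileDistribution reduces to a single instantiation. A sketch of what the new RowMajor branch boils down to, with hypothetical sizes substituted (not from the commit):

```cpp
// For row-major A the K dimension is contiguous, so the pattern is built as
// Y = MPerBlock, X = KPerBlock.
using APattern = ck_tile::TileDistributionEncodingPattern2D<
    256,  // BlockSize
    128,  // MPerBlock (Y)
    64,   // KPerBlock (X, contiguous)
    8,    // VecLoadSize, i.e. what GetVectorSizeA<Problem>() would return
    ck_tile::tile_distribution_pattern::thread_raked>;

constexpr auto a_dram_dist = APattern::Make2DStaticTileDistribution();
```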
 ...
@@ -370,107 +338,47 @@ struct UniversalGemmPipelineAgBgCrPolicy
         constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

-        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
+        // Tile: KPerBlock X NPerBlock
+        if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
         {
-            constexpr index_t N1 = GetVectorSizeB<Problem>();
-            constexpr index_t N0 = NPerBlock / N1;
-            constexpr index_t elem_per_thr = NPerBlock * KPerBlock / BlockSize;
-            static_assert(elem_per_thr % N1 == 0);
-            constexpr index_t K3 = elem_per_thr / N1;
-            constexpr index_t KPack = GetVectorSizeB<Problem>();
-            static_assert(KPack % K3 == 0);
-            constexpr index_t K2 = KPack / K3;
-            if constexpr(get_warp_size() % (K2 * N0) == 0)
-            {
-                constexpr index_t K1 = get_warp_size() / (K2 * N0);
-                constexpr index_t K0 = BlockSize / get_warp_size();
-                static_assert(KPerBlock == K0 * K1 * K2 * K3);
-                return make_static_tile_distribution(
-                    tile_distribution_encoding<
-                        sequence<1>,
-                        tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
-                        tuple<sequence<2>, sequence<2, 1, 2>>,
-                        tuple<sequence<0>, sequence<1, 0, 2>>,
-                        sequence<2, 1>,
-                        sequence<3, 1>>{});
-            }
-            else
-            {
-                constexpr index_t K1   = (K2 * N0) / get_warp_size();
-                constexpr index_t K2_m = K2 / K1;
-                constexpr index_t K0   = BlockSize / get_warp_size() / K1;
-                static_assert(KPerBlock == K0 * K1 * K2_m * K3);
-                return make_static_tile_distribution(
-                    tile_distribution_encoding<
-                        sequence<1>,
-                        tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
-                        tuple<sequence<2, 2>, sequence<1, 2>>,
-                        tuple<sequence<0, 1>, sequence<0, 2>>,
-                        sequence<2, 1>,
-                        sequence<3, 1>>{});
-            }
+            constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
+            constexpr auto AccessPattern  = tile_distribution_pattern::thread_raked;
+            using TileEncodingPattern     = TileDistributionEncodingPattern2D<BlockSize,
+                                                                              KPerBlock,
+                                                                              NPerBlock,
+                                                                              VecLoadSize,
+                                                                              AccessPattern>;
+            return TileEncodingPattern::Make2DStaticTileDistribution();
         }
+        // Tile: NPerBlock X KPerBlock
         else
         {
-            constexpr index_t K1 = GetVectorSizeB<Problem>();
-            constexpr index_t K0 = KPerBlock / K1;
-            constexpr index_t N2 = get_warp_size() / K0;
-            // coalesce reading for each blocks
-            if constexpr(get_warp_size() % (N2 * K0) == 0)
-            {
-                constexpr index_t N1 = BlockSize / get_warp_size();
-                static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
-                static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
-                constexpr index_t N0 = NPerBlock / (N2 * N1);
-                static_assert(N0 * N1 * N2 == NPerBlock,
-                              "Incorrect N0, N1, N2 configuration! "
-                              "N0, N1, N2 must cover whole NPerBlock!");
-                return make_static_tile_distribution(
-                    tile_distribution_encoding<
-                        sequence<1>,
-                        tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
-                        tuple<sequence<1>, sequence<1, 2>>,
-                        tuple<sequence<1>, sequence<2, 0>>,
-                        sequence<1, 2>,
-                        sequence<0, 1>>{});
-            }
-            // coalesce reading for each warps
-            else
-            {
-                constexpr index_t N0 = BlockSize / get_warp_size();
-                constexpr index_t N1 = NPerBlock / (N2 * N0);
-                static_assert(N0 * N1 * N2 == NPerBlock,
-                              "Incorrect N0, N1, N2 configuration! "
-                              "N0, N1, N2 must cover whole NPerBlock!");
-                return make_static_tile_distribution(
-                    tile_distribution_encoding<
-                        sequence<1>,
-                        tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
-                        tuple<sequence<1>, sequence<1, 2>>,
-                        tuple<sequence<0>, sequence<2, 0>>,
-                        sequence<1, 2>,
-                        sequence<1, 1>>{});
-            }
+            constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
+            constexpr auto AccessPattern  = tile_distribution_pattern::thread_raked;
+            using TileEncodingPattern     = TileDistributionEncodingPattern2D<BlockSize,
+                                                                              NPerBlock,
+                                                                              KPerBlock,
+                                                                              VecLoadSize,
+                                                                              AccessPattern>;
+            return TileEncodingPattern::Make2DStaticTileDistribution();
         }
     }
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
     {
         using ALayout = remove_cvref_t<typename Problem::ALayout>;
-        using ADataType = remove_cvref_t<typename Problem::ADataType>;
         static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
         constexpr index_t BlockSize = Problem::kBlockSize;
         constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

-        constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
+        constexpr index_t M1 = GetVectorSizeA<Problem>();
         constexpr index_t M0 = MPerBlock / M1;
         constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
         static_assert(total_pixels % M1 == 0);

         constexpr index_t K3 = total_pixels / M1;
-        constexpr index_t kKPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
+        constexpr index_t kKPack = GetSmemPackA<Problem>();
         static_assert(kKPack % K3 == 0);
         constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside single wave
         constexpr index_t warp_size = get_warp_size();
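Worked arithmetic for the shuffled descriptor above, with hypothetical sizes (not from the commit):

```cpp
// Hypothetical: BlockSize = 256, MPerBlock = 128, KPerBlock = 32.
constexpr int BlockSize = 256, MPerBlock = 128, KPerBlock = 32;
constexpr int M1 = 8;                                           // GetVectorSizeA<Problem>()
constexpr int M0 = MPerBlock / M1;                              // 16
constexpr int total_pixels = MPerBlock * KPerBlock / BlockSize; // 16 elements per thread
static_assert(total_pixels % M1 == 0);                          // 16 % 8 == 0
constexpr int K3 = total_pixels / M1;                           // 2 vector loads per thread
constexpr int kKPack = 8;                                       // GetSmemPackA<Problem>()
static_assert(kKPack % K3 == 0);                                // 8 % 2 == 0
constexpr int K2 = kKPack / K3;                                 // 4 (the dim the TODO flags)
```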
 ...
@@ -506,19 +414,18 @@ struct UniversalGemmPipelineAgBgCrPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
     {
         using BLayout = remove_cvref_t<typename Problem::BLayout>;
-        using BDataType = remove_cvref_t<typename Problem::BDataType>;
         static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
         constexpr index_t BlockSize = Problem::kBlockSize;
         constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

-        constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
+        constexpr index_t N1 = GetVectorSizeB<Problem>();
         constexpr index_t N0 = NPerBlock / N1;
         constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
         static_assert(total_pixels % N1 == 0);

         constexpr index_t K3 = total_pixels / N1;
-        constexpr index_t kKPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
+        constexpr index_t kKPack = GetSmemPackB<Problem>();
         static_assert(kKPack % K3 == 0);
         constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside single wave
         constexpr index_t warp_size = get_warp_size();
 ...
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp  (view file @ c400e5b3)

@@ -4,7 +4,6 @@
 #pragma once

 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"

 namespace ck_tile {
 ...