Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bd689f40
Commit
bd689f40
authored
Aug 20, 2024
by
illsilin
Browse files
merge from public repo
parents
c160c6cf
a94113a9
Changes
333
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2046 additions
and
1132 deletions
+2046
-1132
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+1531
-954
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
...k_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
+2
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
...ile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
+37
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+3
-1
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
+10
-0
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+3
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
+202
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp
...gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp
+36
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp
...emm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp
+33
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
+6
-9
include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
+6
-9
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+6
-0
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
+6
-6
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+18
-4
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp
...erence_tensor_operation/cpu/reference_column_to_image.hpp
+59
-57
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
...eference_tensor_operation/cpu/reference_conv_bwd_data.hpp
+12
-12
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp
...erence_tensor_operation/cpu/reference_conv_bwd_weight.hpp
+12
-12
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
...ary/reference_tensor_operation/cpu/reference_conv_fwd.hpp
+13
-13
library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp
...erence_tensor_operation/cpu/reference_image_to_column.hpp
+50
-49
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+1
-0
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
bd689f40
...
...
@@ -11,6 +11,8 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
...
...
@@ -18,60 +20,215 @@
namespace
ck_tile
{
template
<
bool
QLoadOnce_
,
bool
QTLoadOnce_
,
bool
KLoadOnce_
,
bool
KTLoadOnce_
,
bool
VLoadOnce_
,
bool
OGradLoadOnce_
,
bool
OGradTLoadOnce_
>
struct
BlockFmhaBwdPipelineDefaultPolicy
{
static
constexpr
bool
QLoadOnce
=
QLoadOnce_
;
// if q load whole block length (qkhdim) to LDS at once
static
constexpr
bool
QTLoadOnce
=
QTLoadOnce_
;
// if q^t load whole block length (qkhdim) to LDS at once
static
constexpr
bool
KLoadOnce
=
KLoadOnce_
;
// if k load whole block length (qkhdim) to LDS at once
static
constexpr
bool
KTLoadOnce
=
KTLoadOnce_
;
// if k^t load whole block length (qkhdim) to LDS at once
static
constexpr
bool
VLoadOnce
=
VLoadOnce_
;
// if v load whole block length (vhdim) to Vgprs at once
static
constexpr
bool
OGradLoadOnce
=
OGradLoadOnce_
;
// if do load whole block length (vhdim) to LDS at once
static
constexpr
bool
OGradTLoadOnce
=
OGradTLoadOnce_
;
// if do^t load whole block length (vhdim) to LDS at once
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
2
>
{}),
false
,
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
0
>
{})
==
16
?
false
:
true
>
;
using
BlockGemmPolicy
=
BlockGemmARegBRegCRegV1CustomPolicy
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
2
>
{}),
true
>
;
using
BlockGemmPolicy
=
BlockGemmARegBRegCRegV1CustomPolicy
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
::
at
(
number
<
2
>
{}),
false
,
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
0
>
{})
==
16
?
false
:
true
>
;
using
BlockGemmPolicy
=
BlockGemmARegBRegCRegV1CustomPolicy
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
2
>
{}),
true
>
;
using
BlockGemmPolicy
=
BlockGemmARegBRegCRegV1CustomPolicy
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
2
>
{}),
false
>
;
using
BlockGemmPolicy
=
BlockGemmARegBRegCRegV1CustomPolicy
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
// these are for global load
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentQ
()
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
return
16
/
sizeof
(
QDataType
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
QDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
QDataType
);
constexpr
index_t
total_pixels
=
kMNPerBlock
*
kKPerBlock
/
kBlockSize
;
constexpr
index_t
kVecLoad
=
((
total_pixels
/
kMaxVecLoad
)
>=
kMinVecLoad
)
?
kMaxVecLoad
:
(
total_pixels
/
kMinVecLoad
);
return
kVecLoad
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentK
()
{
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
return
16
/
sizeof
(
KDataType
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
KDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
KDataType
);
constexpr
index_t
total_pixels
=
kMNPerBlock
*
kKPerBlock
/
kBlockSize
;
constexpr
index_t
kVecLoad
=
((
total_pixels
/
kMaxVecLoad
)
>=
kMinVecLoad
)
?
kMaxVecLoad
:
(
total_pixels
/
kMinVecLoad
);
return
kVecLoad
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentV
()
{
if
constexpr
(
VLoadOnce
)
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetOGradVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
return
WG
::
kK
/
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
}
else
{
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
return
16
/
sizeof
(
VDataType
);
}
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
VDataType
);
constexpr
index_t
total_pixels
=
kMNPerBlock
*
kKPerBlock
/
kBlockSize
;
return
total_pixels
>
kMaxVecLoad
?
kMaxVecLoad
:
total_pixels
;
}
template
<
typename
Problem
>
...
...
@@ -85,19 +242,38 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentOGrad
()
{
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
return
16
/
sizeof
(
OGradDataType
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
OGradDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
OGradDataType
);
constexpr
index_t
total_pixels
=
kMNPerBlock
*
kKPerBlock
/
kBlockSize
;
constexpr
index_t
kVecLoad
=
((
total_pixels
/
kMaxVecLoad
)
>=
kMinVecLoad
)
?
kMaxVecLoad
:
(
total_pixels
/
kMinVecLoad
);
return
kVecLoad
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignment
QGrad
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignment
Bias
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetSGradKTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
constexpr
auto
vec
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
CWarpDstr
::
NDimY
-
1
>
{});
return
vec
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
BiasDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
BiasDataType
);
constexpr
index_t
total_pixels
=
kMPerBlock
*
kNPerBlock
/
kBlockSize
;
constexpr
index_t
kVecLoad
=
((
total_pixels
/
kMaxVecLoad
)
>=
kMinVecLoad
)
?
kMaxVecLoad
:
(
total_pixels
/
kMinVecLoad
);
return
kVecLoad
;
}
template
<
typename
Problem
>
...
...
@@ -128,60 +304,35 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetTransposedAlignmentQ
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK3
;
}();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
// TODO: not correct!
if
constexpr
(
total_pixels
>
4
)
return
4
;
else
return
2
;
return
total_pixels
/
GetAlignmentQ
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetTransposedAlignmentK
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kN0
;
else
return
Problem
::
BlockFmhaShape
::
kK4
;
}();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
// TODO: not correct!
if
constexpr
(
total_pixels
>
4
)
return
4
;
else
return
2
;
return
total_pixels
/
GetAlignmentK
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetTransposedAlignmentOGrad
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK1
;
}();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
// TODO: not correct!
if
constexpr
(
total_pixels
>
4
)
return
4
;
else
return
2
;
return
total_pixels
/
GetAlignmentOGrad
<
Problem
>
();
}
template
<
typename
Problem
>
...
...
@@ -193,1151 +344,1577 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
total_pixels
=
kMPerBlock
*
kNPerBlock
/
kBlockSize
;
// TODO: not correct!
if
constexpr
(
total_pixels
>
32
)
return
8
;
else
return
4
;
return
total_pixels
/
GetAlignmentBias
<
Problem
>
();
}
// these are for lds
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Get
SmemKPackQ
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Get
AlignmentPostQGradAcc
()
{
// TODO: this is for 3d layout
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
return
16
/
sizeof
(
QDataType
);
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
return
16
/
sizeof
(
AccDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Get
SmemKPackK
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Get
AlignmentPostQGrad
()
{
// TODO: this is for 3d layout
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
return
16
/
sizeof
(
KDataType
);
return
GetAlignmentPostQGradAcc
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackV
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKDramTileDistribution
()
{
// TODO: this is for 3d layout
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
return
16
/
sizeof
(
VDataType
);
}
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackBias
()
{
// TODO: this is for 3d layout
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
return
16
/
sizeof
(
BiasDataType
);
}
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackOGrad
()
{
// TODO: this is for 3d layout
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
return
16
/
sizeof
(
OGradDataType
);
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N1
=
get_warp_size
()
/
K0
;
constexpr
index_t
N0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N2
=
kNPerBlock
/
(
N1
*
N0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
2
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackSGrad
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVDramTileDistribution
()
{
// TODO: this is for 3d layout
using
GemmDataType
=
remove_cvref_t
<
typename
Problem
::
GemmDataType
>
;
return
16
/
sizeof
(
GemmDataType
);
}
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVInRegDramTileDistribution
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
constexpr
index_t
K1
=
GetAlignmentV
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQDramTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WG
::
k
N
)
;
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WG
::
kK
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
k
M0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK
0
;
constexpr
auto
v_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
v_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
v_block_outer_dstr_encoding
,
typename
WG
::
BWarpDstrEncoding
{});
constexpr
auto
v_block_dstr
=
make_static_tile_distribution
(
v_block_dstr_encode
);
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M1
=
get_warp_size
()
/
K0
;
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M2
=
kMPerBlock
/
(
M1
*
M0
);
return
v_block_dstr
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
2
,
1
>>
{});
}
// 3d + padding
template
<
index_t
MNPerBlock
,
index_t
KPerBlock
,
index_t
KPack
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeXLdsBlockDescriptor
()
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradDramTileDistribution
()
{
constexpr
auto
x_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
MNPerBlock
>
{},
number
<
KPack
>
{}),
make_tuple
(
number
<
(
MNPerBlock
+
1
)
*
KPack
>
{},
number
<
KPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
x_lds_block_desc
=
transform_tensor_descriptor
(
x_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
MNPerBlock
),
make_merge_transform
(
make_tuple
(
KPerBlock
/
KPack
,
KPack
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
x_lds_block_desc
;
}
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
// 3d + padding
template
<
index_t
MNPerBlock
,
index_t
KPerBlock
,
index_t
KPack
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeXLdsBlockDescriptorAsXT
()
{
constexpr
auto
x_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
MNPerBlock
>
{},
number
<
KPack
>
{}),
make_tuple
(
number
<
(
MNPerBlock
+
1
)
*
KPack
>
{},
number
<
KPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
auto
xt_lds_block_desc
=
transform_tensor_descriptor
(
x_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
MNPerBlock
),
make_merge_transform
(
make_tuple
(
KPerBlock
/
KPack
,
KPack
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M1
=
get_warp_size
()
/
K0
;
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M2
=
kMPerBlock
/
(
M1
*
M0
);
return
xt_lds_block_desc
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
2
,
1
>>
{});
}
template
<
index_t
MNPerBlock
,
index_t
KPerBlock
,
index_t
KPack
,
index_t
PixelsPerRow
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
XTLdsBlockDescriptor
()
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
LSEDDramTileDistribution
()
{
static_assert
(
PixelsPerRow
%
KPack
==
0
);
constexpr
index_t
NPerRow
=
PixelsPerRow
/
KPack
;
static_assert
(
MNPerBlock
%
NPerRow
==
0
);
static_assert
(
KPerBlock
%
KPack
==
0
);
constexpr
auto
xt_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
MNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{},
number
<
KPack
>
{}),
make_tuple
(
number
<
(
MNPerBlock
/
NPerRow
)
*
(
PixelsPerRow
+
KPack
)
>
{},
number
<
PixelsPerRow
+
KPack
>
{},
number
<
KPack
>
{},
number
<
1
>
{}),
number
<
KPack
>
{},
number
<
1
>
{});
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
auto
xt_lds_block_desc
=
transform_tensor_descriptor
(
xt_lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
MNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{})),
make_merge_transform
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
return
xt_lds_block_desc
;
}
// Duplicate dimension
constexpr
index_t
N0
=
NWarp
;
constexpr
index_t
N1
=
(
get_warp_size
()
/
kMPerBlock
)
>
1
?
(
get_warp_size
()
/
kMPerBlock
)
:
1
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQLdsBlockDescriptor
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
constexpr
index_t
M0
=
MWarp
;
constexpr
index_t
M1
=
(
get_warp_size
()
/
kMPerBlock
)
>
1
?
kMPerBlock
:
get_warp_size
();
constexpr
index_t
M2
=
(
get_warp_size
()
/
kMPerBlock
)
>
1
?
1
:
(
kMPerBlock
/
get_warp_size
());
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
N0
,
N1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
0
>
,
sequence
<
1
,
1
>>
,
sequence
<
1
>
,
sequence
<
2
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
QLdsBlockDescriptorAsQT
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
BiasTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
k
K
PerBlock
=
[
&
]()
{
if
constexpr
(
QLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK
0
;
}
();
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
(
);
constexpr
index_t
k
N
PerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
N1
=
GetAlignmentBias
<
Problem
>
()
;
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
M1
=
get_warp_size
()
/
N
0
;
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M2
=
kMPerBlock
/
(
M1
*
M0
);
return
MakeXLdsBlockDescriptorAsXT
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
N0
,
N1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
2
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
KLdsBlockDescriptor
()
template
<
typename
DataType
,
index_t
MPerBlock
,
index_t
KPerBlock
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
PreXDramTileDistribution
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
constexpr
index_t
K1
=
16
/
sizeof
(
DataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
1
;
constexpr
index_t
M1
=
get_warp_size
();
constexpr
index_t
M0
=
MPerBlock
/
M1
;
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
>
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
2
,
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
KLdsBlockDescriptorAsKT
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
PreODramTileDistribution
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
kVHeaddim
;
return
Make
XLdsBlockDescriptorAsXT
<
kNPer
Block
,
kKPerBlock
,
kKPack
>
();
return
Make
PreXDramTileDistribution
<
ODataType
,
k
Block
Size
,
kKPerBlock
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
VLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
PreOGradDramTileDistribution
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
kKPack
=
GetSmemKPackV
<
Problem
>
();
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
>
();
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
kVHeaddim
;
return
MakePreXDramTileDistribution
<
OGradDataType
,
kBlockSize
,
kKPerBlock
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
OGradLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
PostQGradAccDramTileDistribution
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kVHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK2
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
}
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
kQKHeaddim
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradLdsBlockDescriptorAsOGradT
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kVHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK2
;
}();
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
constexpr
index_t
K1
=
16
/
sizeof
(
AccDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
kMPerBlock
/
(
M1
*
M2
);
return
MakeXLdsBlockDescriptorAsXT
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
1
>
,
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
3
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
,
3
>
,
sequence
<
0
,
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
SGradLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
PostQGradDramTileDistribution
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPack
=
GetSmemKPackSGrad
<
Problem
>
();
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
}
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
kQKHeaddim
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQTLdsBlockDescriptor
()
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
QDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK3
;
}();
constexpr
index_t
K1
=
16
/
sizeof
(
AccDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
kMPerBlock
/
(
M1
*
M2
);
return
MakeXTLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
,
PixelsPerRow
>
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
// these are for lds
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKTLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackQ
()
{
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
KDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kN0
;
else
return
Problem
::
BlockFmhaShape
::
kK4
;
}();
return
MakeXTLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
,
PixelsPerRow
>
();
return
GetAlignmentQ
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradTLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackQT
()
{
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
OGradDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK1
;
}();
return
MakeXTLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
,
PixelsPerRow
>
();
return
GetTransposedAlignmentQ
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBiasTLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackK
()
{
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
BiasDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackBias
<
Problem
>
();
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
static_assert
(
PixelsPerRow
%
kKPack
==
0
);
constexpr
index_t
NPerRow
=
PixelsPerRow
/
kKPack
;
static_assert
(
kNPerBlock
%
NPerRow
==
0
);
static_assert
(
kMPerBlock
%
kKPack
==
0
);
constexpr
auto
biast_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kMPerBlock
/
kKPack
>
{},
number
<
kNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{},
number
<
kKPack
>
{}),
make_tuple
(
number
<
(
kNPerBlock
/
NPerRow
)
*
(
PixelsPerRow
+
kKPack
)
>
{},
number
<
PixelsPerRow
+
kKPack
>
{},
number
<
kKPack
>
{},
number
<
1
>
{}),
number
<
kKPack
>
{},
number
<
1
>
{});
constexpr
auto
biast_lds_block_desc
=
transform_tensor_descriptor
(
biast_lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
kNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{})),
make_merge_transform
(
make_tuple
(
number
<
kMPerBlock
/
kKPack
>
{},
number
<
kKPack
>
{}))),
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
biast_lds_block_desc
;
return
GetAlignmentK
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmem
SizeQ
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmem
KPackKT
()
{
constexpr
index_t
smem_size_q
=
sizeof
(
typename
Problem
::
QDataType
)
*
MakeQLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_q
;
return
GetTransposedAlignmentK
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmem
SizeQT
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmem
KPackV
()
{
constexpr
index_t
smem_size_qt
=
[
&
]()
{
if
constexpr
(
QLoadOnce
&&
!
QTLoadOnce
)
return
0
;
else
return
sizeof
(
typename
Problem
::
QDataType
)
*
MakeQTLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}();
return
smem_size_qt
;
return
GetAlignmentV
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeK
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackBias
()
{
constexpr
index_t
smem_size_k
=
sizeof
(
typename
Problem
::
KDataType
)
*
MakeKLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_k
;
return
GetAlignmentBias
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeK
T
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackBias
T
()
{
constexpr
index_t
smem_size_kt
=
[
&
]()
{
if
constexpr
(
KLoadOnce
&&
!
KTLoadOnce
)
return
0
;
else
return
sizeof
(
typename
Problem
::
KDataType
)
*
MakeKTLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}();
return
smem_size_kt
;
return
GetTransposedAlignmentBias
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeV
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackOGrad
()
{
constexpr
index_t
smem_size_v
=
[
&
]()
{
if
constexpr
(
VLoadOnce
)
return
0
;
else
return
sizeof
(
typename
Problem
::
VDataType
)
*
MakeVLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}();
return
smem_size_v
;
return
GetAlignmentOGrad
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmem
Size
OGrad
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmem
KPack
OGrad
T
()
{
constexpr
index_t
smem_size_do
=
sizeof
(
typename
Problem
::
OGradDataType
)
*
MakeOGradLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_do
;
return
GetTransposedAlignmentOGrad
<
Problem
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmem
SizeO
Grad
T
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmem
KPackS
Grad
()
{
constexpr
index_t
smem_size_dot
=
[
&
]()
{
if
constexpr
(
OGradLoadOnce
&&
!
OGradTLoadOnce
)
return
0
;
else
return
sizeof
(
typename
Problem
::
OGradDataType
)
*
MakeOGradTLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}();
return
smem_size_dot
;
// TODO: this is for 3d layout
using
GemmDataType
=
remove_cvref_t
<
typename
Problem
::
GemmDataType
>
;
return
16
/
sizeof
(
GemmDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeSGrad
()
template
<
index_t
MNPerBlock
,
index_t
KPerBlock
,
index_t
KPack
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeXLdsBlockDescriptor
()
{
constexpr
index_t
smem_size_ds
=
sizeof
(
typename
Problem
::
GemmDataType
)
*
MakeSGradLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_ds
;
constexpr
auto
DataTypeSize
=
2
;
// sizeof(F16/BF16)
constexpr
auto
MNLdsLayer
=
(
32
*
4
/
KPerBlock
/
DataTypeSize
)
<
1
?
1
:
(
32
*
4
/
KPerBlock
/
DataTypeSize
);
constexpr
auto
x_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
KPerBlock
/
KPack
*
MNLdsLayer
>
{},
number
<
MNPerBlock
/
MNLdsLayer
>
{},
number
<
KPack
>
{}),
make_tuple
(
number
<
KPack
>
{},
number
<
KPerBlock
*
MNLdsLayer
>
{},
number
<
1
>
{}),
number
<
KPack
>
{},
number
<
1
>
{});
constexpr
auto
x_lds_block_desc_permuted
=
transform_tensor_descriptor
(
x_lds_block_desc_0
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
MNPerBlock
/
MNLdsLayer
>
{},
number
<
KPerBlock
/
KPack
*
MNLdsLayer
>
{})),
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
constexpr
auto
x_lds_block_desc_xk0_mnldslayer_mn_xk1
=
transform_tensor_descriptor
(
x_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
MNLdsLayer
>
{})),
make_pass_through_transform
(
number
<
MNPerBlock
/
MNLdsLayer
>
{}),
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
constexpr
auto
x_lds_block_desc
=
transform_tensor_descriptor
(
x_lds_block_desc_xk0_mnldslayer_mn_xk1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
MNPerBlock
/
MNLdsLayer
>
{},
number
<
MNLdsLayer
>
{})),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
x_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeBias
()
template
<
typename
Problem
,
index_t
MNPerBlock
,
index_t
KPerBlock
,
index_t
KPack
,
index_t
KPackT
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeXTLdsBlockDescriptor
()
{
constexpr
index_t
smem_size_bias
=
[
&
]()
{
if
constexpr
(
Problem
::
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
sizeof
(
typename
Problem
::
BiasDataType
)
*
MakeBiasTLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
else
return
0
;
}();
return
smem_size_bias
;
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr
auto
MNPerXDL
=
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
0
>
{});
constexpr
auto
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
auto
MN0
=
MNPerBlock
/
KPack
;
constexpr
auto
MN1
=
KPack
;
constexpr
auto
KThreadWrite
=
kBlockSize
/
MN0
;
constexpr
auto
K0Number
=
KPerBlock
/
KPackT
;
constexpr
auto
K0PerThreadWrite
=
K0Number
/
KThreadWrite
;
constexpr
auto
KThreadRead
=
get_warp_size
()
/
MNPerXDL
;
// assume 32x32x8 mfma
constexpr
auto
K0PerThreadRead
=
K0Number
/
KThreadRead
;
constexpr
auto
kfold
=
(
KPackT
*
MN0
*
2
>
128
)
?
1
:
128
/
(
KPackT
*
MN0
*
2
);
constexpr
auto
KThreadReadPerm
=
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
>
1
?
KThreadRead
/
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
:
KThreadRead
;
// 1<=mnpair<=n0
constexpr
auto
mnpair
=
(
KPackT
*
MNPerXDL
*
2
>
128
)
?
1
:
((
128
/
(
KPackT
*
MNPerXDL
*
2
))
>
MN0
?
MN0
:
128
/
(
KPackT
*
MNPerXDL
*
2
));
constexpr
auto
xt_lds_block_desc_raw
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
K0PerThreadWrite
>
{},
number
<
KThreadReadPerm
*
MN1
>
{},
number
<
kfold
*
MN0
/
mnpair
>
{},
number
<
mnpair
>
{},
KPackT
),
make_tuple
(
number
<
KPackT
*
kfold
*
MN0
*
KThreadReadPerm
*
MN1
*
K0PerThreadWrite
>
{},
number
<
KPackT
*
kfold
*
MN0
*
KThreadReadPerm
*
MN1
>
{},
number
<
KPackT
*
kfold
*
MN0
>
{},
number
<
KPackT
*
mnpair
>
{},
number
<
KPackT
>
{},
number
<
1
>
{}),
number
<
KPackT
>
{},
number
<
1
>
{});
constexpr
auto
xt_lds_block_desc_permuted
=
transform_tensor_descriptor
(
xt_lds_block_desc_raw
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_tuple
(
number
<
KThreadReadPerm
*
MN1
>
{},
number
<
kfold
*
MN0
/
mnpair
>
{})),
make_pass_through_transform
(
number
<
mnpair
>
{}),
make_pass_through_transform
(
KPackT
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}));
constexpr
auto
xt_lds_block_desc_unmerged
=
transform_tensor_descriptor
(
xt_lds_block_desc_permuted
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_unmerge_transform
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
MN1
>
{})),
make_unmerge_transform
(
make_tuple
(
number
<
kfold
>
{},
number
<
MN0
/
mnpair
>
{})),
make_pass_through_transform
(
number
<
mnpair
>
{}),
make_pass_through_transform
(
KPackT
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
0
,
3
>
{},
sequence
<
4
,
5
>
{},
sequence
<
6
>
{},
sequence
<
7
>
{}));
constexpr
auto
xt_lds_block_desc
=
transform_tensor_descriptor
(
xt_lds_block_desc_unmerged
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
kfold
>
{},
number
<
K0PerThreadWrite
>
{},
number
<
KPackT
>
{})),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
MN0
/
mnpair
>
{},
number
<
mnpair
>
{},
number
<
MN1
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
4
,
2
,
7
>
{},
sequence
<
5
,
6
,
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
xt_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsWriteBlockDescriptor
()
{
constexpr
index_t
smem_size_q
=
GetSmemSizeQ
<
Problem
>
();
constexpr
index_t
smem_size_qt
=
GetSmemSizeQT
<
Problem
>
();
constexpr
index_t
smem_size_k
=
GetSmemSizeK
<
Problem
>
();
constexpr
index_t
smem_size_kt
=
GetSmemSizeKT
<
Problem
>
();
constexpr
index_t
smem_size_v
=
GetSmemSizeV
<
Problem
>
();
constexpr
index_t
smem_size_do
=
GetSmemSizeOGrad
<
Problem
>
();
constexpr
index_t
smem_size_dot
=
GetSmemSizeOGradT
<
Problem
>
();
constexpr
index_t
smem_size_ds
=
GetSmemSizeSGrad
<
Problem
>
();
constexpr
index_t
smem_size_bias
=
GetSmemSizeBias
<
Problem
>
();
constexpr
index_t
smem_size_transpose
=
max
(
smem_size_ds
,
smem_size_bias
);
index_t
smem_size
=
0
;
if
constexpr
(
QLoadOnce
&&
OGradLoadOnce
)
smem_size
+=
smem_size_q
+
smem_size_qt
+
smem_size_do
+
smem_size_dot
+
smem_size_transpose
;
// 1~4 & 10
else
if
(
QLoadOnce
&&
!
OGradLoadOnce
&&
!
OGradTLoadOnce
)
smem_size
+=
smem_size_q
+
smem_size_qt
+
max
(
smem_size_do
,
smem_size_dot
,
smem_size_transpose
);
// 5/7/11 TODO: Multiple buffers strategy
else
if
(
!
QLoadOnce
&&
!
QTLoadOnce
&&
OGradLoadOnce
)
smem_size
+=
smem_size_do
+
smem_size_dot
+
max
(
smem_size_q
,
smem_size_qt
,
smem_size_transpose
);
// 6/8/12 TODO: Multiple buffers strategy
else
if
(
!
QLoadOnce
&&
!
QTLoadOnce
&&
!
OGradLoadOnce
&&
!
OGradTLoadOnce
)
smem_size
+=
max
(
smem_size_q
,
smem_size_qt
,
smem_size_do
,
smem_size_dot
,
smem_size_transpose
);
// 9/13 TODO: Multiple buffers strategy
// 14/15 needs to be adjusted
if
constexpr
(
KLoadOnce
)
smem_size
+=
(
smem_size_k
+
smem_size_kt
);
// 1~13
else
smem_size
=
max
(
smem_size_k
,
smem_size_kt
,
smem_size
);
// 14/15 TODO: Multiple buffers strategy
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
return
max
(
smem_size
,
smem_size_v
);
// 15
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
LSEDDramTileDistribution
()
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
KRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
N1
=
WG
::
WarpGemmAttribute
::
Impl
::
kCNLane
;
constexpr
index_t
N0
=
NWarp
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
M4
=
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
*
2
;
constexpr
index_t
M3
=
WG
::
WarpGemmAttribute
::
Impl
::
kCMLane
;
constexpr
index_t
M2
=
WG
::
WarpGemmAttribute
::
Impl
::
kCM0PerLane
/
2
;
constexpr
index_t
M1
=
MWarp
;
constexpr
index_t
M0
=
kMPerBlock
/
(
M1
*
WG
::
WarpGemmAttribute
::
Impl
::
kM
);
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
N0
,
N1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
,
M3
,
M4
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
3
,
1
>>
,
sequence
<
1
,
1
,
1
>
,
sequence
<
0
,
2
,
4
>>
{});
constexpr
auto
k_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
k_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
k_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
k_block_dstr
=
make_static_tile_distribution
(
k_block_dstr_encode
);
return
k_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
VDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
KRegBlockDescriptor
()
{
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
16
/
sizeof
(
VDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
return
make_static_tile
_d
i
str
ibution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N
0
,
N1
,
N2
>
,
sequence
<
K
0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
constexpr
auto
k_block_outer
_dstr
_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
N
IterPerWarp
,
NWarp
>
,
sequence
<
K
IterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
sequence
<
0
,
0
>>
{};
constexpr
auto
k_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
k_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
k_block_dstr
=
make_static_tile_distribution
(
k_block_dstr_encode
);
return
k_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
QDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
VLdsWriteBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
kVPack
=
GetSmemKPackV
<
Problem
>
();
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kVPack
>
();
}
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetOGradVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
v_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
sequence
<
0
,
0
>>
{};
constexpr
auto
v_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
v_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
v_block_dstr
=
make_static_tile_distribution
(
v_block_dstr_encode
);
return
v_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
KDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
VRegBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetOGradVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK0
;
}();
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
return
make_static_tile
_d
i
str
ibution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N
0
,
N1
,
N2
>
,
sequence
<
K
0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
constexpr
auto
v_block_outer
_dstr
_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
N
IterPerWarp
,
NWarp
>
,
sequence
<
K
IterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
sequence
<
0
,
0
>>
{};
constexpr
auto
v_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
v_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
v_block_dstr
=
make_static_tile_distribution
(
v_block_dstr_encode
);
return
v_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
OGradDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
ShuffledKRegWriteBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kVHeaddim
;
else
return
Problem
::
BlockFmhaShape
::
kK2
;
}();
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
K1
=
GetAlignment
OGrad
<
Problem
>
();
constexpr
index_t
K1
=
GetAlignment
K
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
constexpr
index_t
N2
=
GetTransposedAlignmentK
<
Problem
>
();
constexpr
index_t
N1
=
get_warp_size
()
/
K0
;
constexpr
index_t
N0
=
kBlockSize
/
get_warp_size
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M
0
,
M
1
,
M
2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
N
0
,
N
1
,
N
2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
2
,
1
>
,
sequence
<
1
,
2
>>
{});
}
template
<
typename
DataType
,
index_t
MPerBlock
,
index_t
KPerBlock
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
PreXDramTileDistribution
()
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
ShuffledKLdsWriteBlockDescriptor
()
{
constexpr
index_t
K1
=
16
/
sizeof
(
DataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
1
;
constexpr
index_t
M1
=
get_warp_size
();
constexpr
index_t
M0
=
MPerBlock
/
M1
;
// Hold all data
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
2
,
0
,
1
>>
{});
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
constexpr
index_t
kKPackT
=
GetSmemKPackKT
<
Problem
>
();
return
MakeXTLdsBlockDescriptor
<
Problem
,
kNPerBlock
,
kKPerBlock
,
kKPack
,
kKPackT
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
PreODramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
KTLdsReadBlockDescriptor
()
{
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
kVHeaddim
;
auto
shuffled_k_lds_block_desc
=
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>
();
return
MakePreXDramTileDistribution
<
ODataType
,
kBlockSize
,
kKPerBlock
>
();
return
transform_tensor_descriptor
(
shuffled_k_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
kNPerBlock
>
{}),
make_pass_through_transform
(
number
<
kKPerBlock
>
{})),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
PreOGradDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
KTRegBlockDescriptor
()
{
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetSGradKTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
kVHeaddim
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
0
>
{})
;
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
1
>
{})
;
return
MakePreXDramTileDistribution
<
OGradDataType
,
kBlockSize
,
kKPerBlock
>
();
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
kt_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
kt_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
kt_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
kt_block_dstr
=
make_static_tile_distribution
(
kt_block_dstr_encode
);
return
kt_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeQ
TDramTileDistribution
()
CK_TILE_
HOST_
DEVICE
static
constexpr
auto
MakeQ
LdsBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK3
;
}();
constexpr
index_t
N1
=
GetTransposedAlignmentQ
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
// P
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
ShuffledQTReg
BlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
QRegSlice
BlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
QTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK3
;
}();
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
N1
=
GetTransposedAlignmentQ
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
q_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
sequence
<
0
,
0
>>
{};
constexpr
auto
q_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
q_block_outer_dstr_encoding
,
typename
WarpGemm
::
AWarpDstrEncoding
{});
constexpr
auto
q_block_dstr
=
make_static_tile_distribution
(
q_block_dstr_encode
);
return
q_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
Make
KTDramTileDistribution
()
CK_TILE_
HOST_
DEVICE
static
constexpr
auto
Make
ShuffledQRegWriteBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kN0
;
else
return
Problem
::
BlockFmhaShape
::
kK4
;
}();
constexpr
index_t
N1
=
GetTransposedAlignmentK
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
// P
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2
*
K3
);
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
GetTransposedAlignmentQ
<
Problem
>
();
constexpr
index_t
N1
=
get_warp_size
()
/
K0
;
constexpr
index_t
N0
=
kBlockSize
/
get_warp_size
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
sequence
<
1
,
2
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffled
KTReg
BlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffled
QLdsWrite
BlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
// Hold full block data
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
KTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kN0
;
else
return
Problem
::
BlockFmhaShape
::
kK4
;
}();
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
N1
=
GetTransposedAlignmentK
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
constexpr
index_t
kKPackT
=
GetSmemKPackQT
<
Problem
>
();
return
MakeXTLdsBlockDescriptor
<
Problem
,
kNPerBlock
,
kKPerBlock
,
kKPack
,
kKPackT
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQTLdsReadBlockDescriptor
()
{
// Hold full block data
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
auto
shuffled_q_lds_block_desc
=
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>
();
return
transform_tensor_descriptor
(
shuffled_q_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
kNPerBlock
>
{}),
make_pass_through_transform
(
number
<
kKPerBlock
>
{})),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQTRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetSGradTQTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK3
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
qt_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
qt_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
qt_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
qt_block_dstr
=
make_static_tile_distribution
(
qt_block_dstr_encode
);
return
qt_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeSGradTRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetSGradTQTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK3
;
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
dst_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
dst_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
dst_block_outer_dstr_encoding
,
typename
WarpGemm
::
AWarpDstrEncoding
{});
constexpr
auto
dst_block_dstr
=
make_static_tile_distribution
(
dst_block_dstr_encode
);
return
dst_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEDLdsWriteBlockDescriptor
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
using
LSEDType
=
remove_cvref_t
<
typename
Problem
::
DDataType
>
;
constexpr
index_t
kMPack
=
16
/
sizeof
(
LSEDType
);
constexpr
auto
lsed_lds_block_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kMPerBlock
>
{}),
make_tuple
(
number
<
1
>
{}),
number
<
kMPack
>
{},
number
<
1
>
{});
return
lsed_lds_block_desc
;
}
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEDLdsReadBlockDescriptor
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
N1
=
WG
::
WarpGemmAttribute
::
Impl
::
kCNLane
;
constexpr
index_t
N0
=
NWarp
;
// M4 *2 and M2 /2 when swizzle mode enabled
constexpr
index_t
SwizzleConfig
=
WG
::
kM
==
16
?
1
:
2
;
// constexpr index_t SwizzleConfig = 1;
constexpr
index_t
M4
=
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
*
SwizzleConfig
;
constexpr
index_t
M3
=
WG
::
WarpGemmAttribute
::
Impl
::
kCMLane
;
constexpr
index_t
M2
=
WG
::
WarpGemmAttribute
::
Impl
::
kCM0PerLane
/
SwizzleConfig
;
constexpr
index_t
M1
=
MWarp
;
constexpr
index_t
M0
=
kMPerBlock
/
(
M1
*
WG
::
WarpGemmAttribute
::
Impl
::
kM
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
tile_distribution_encoding
<
sequence
<
N0
,
N1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
,
M3
,
M4
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
3
,
1
>>
,
sequence
<
1
,
1
,
1
>
,
sequence
<
0
,
2
,
4
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradLdsBlockDescriptor
()
{
// Hold full block data
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetOGradVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
do_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
sequence
<
0
,
0
>>
{};
constexpr
auto
do_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
do_block_outer_dstr_encoding
,
typename
WarpGemm
::
AWarpDstrEncoding
{});
constexpr
auto
do_block_dstr
=
make_static_tile_distribution
(
do_block_dstr_encode
);
return
do_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
Make
OGradTDramTileDistribution
()
CK_TILE_
HOST_
DEVICE
static
constexpr
auto
Make
ShuffledOGradRegWriteBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK1
;
}();
constexpr
index_t
N1
=
GetTransposedAlignmentOGrad
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
// P
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2
*
K3
);
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
GetTransposedAlignmentOGrad
<
Problem
>
();
constexpr
index_t
N1
=
get_warp_size
()
/
K0
;
constexpr
index_t
N0
=
kBlockSize
/
get_warp_size
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
sequence
<
1
,
2
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledOGrad
TReg
BlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledOGrad
LdsWrite
BlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
// Hold all data
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
kKPerBlock
=
[
&
]()
{
if
constexpr
(
OGradTLoadOnce
)
return
Problem
::
BlockFmhaShape
::
kM0
;
else
return
Problem
::
BlockFmhaShape
::
kK1
;
}();
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
N1
=
GetTransposedAlignmentOGrad
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
kKPackT
=
GetSmemKPackOGradT
<
Problem
>
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
return
MakeXTLdsBlockDescriptor
<
Problem
,
kNPerBlock
,
kKPerBlock
,
kKPack
,
kKPackT
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradTLdsReadBlockDescriptor
()
{
// Hold all data
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
auto
shuffled_do_lds_block_desc
=
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>
();
return
transform_tensor_descriptor
(
shuffled_do_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
kNPerBlock
>
{}),
make_pass_through_transform
(
number
<
kKPerBlock
>
{})),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOGradTRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetPTOGradTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
// constexpr index_t kNPerBlock = 32;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
dot_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
dot_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
dot_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
dot_block_dstr
=
make_static_tile_distribution
(
dot_block_dstr_encode
);
return
dot_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakePTRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetPTOGradTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
pt_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
sequence
<
0
,
0
>>
{};
constexpr
auto
pt_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
pt_block_outer_dstr_encoding
,
typename
WarpGemm
::
AWarpDstrEncoding
{});
constexpr
auto
pt_block_dstr
=
make_static_tile_distribution
(
pt_block_dstr_encode
);
return
pt_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
Make
BiasTileDistribution
()
CK_TILE_
HOST_
DEVICE
static
constexpr
auto
Make
SGradLdsBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPack
=
GetSmemKPackSGrad
<
Problem
>
();
constexpr
index_t
N1
=
GetTransposedAlignmentBias
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
// P
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
}
constexpr
index_t
total_pixels
=
kMPerBlock
*
kNPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
M3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackBias
<
Problem
>
();
static_assert
(
kKPack
%
M3
==
0
);
constexpr
index_t
M2
=
kKPack
/
M3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
M1
=
get_warp_size
()
/
(
M2
*
N0
);
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
static_assert
(
kMPerBlock
==
M0
*
M1
*
M2
*
M3
);
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeSGradRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetSGradKTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
,
M3
>
,
sequence
<
N0
,
N1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
,
1
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK4
;
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
ds_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
3
,
1
>>
{});
sequence
<
0
,
0
>>
{};
constexpr
auto
ds_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
ds_block_outer_dstr_encoding
,
typename
WarpGemm
::
AWarpDstrEncoding
{});
constexpr
auto
ds_block_dstr
=
make_static_tile_distribution
(
ds_block_dstr_encode
);
return
ds_block_dstr
;
}
template
<
typename
Problem
,
typename
PTOutTensor
,
typename
PInTensor
>
CK_TILE_DEVICE
static
constexpr
void
PTFromGemm0CToGemm1A
(
PTOutTensor
&
pt_out
,
const
PInTensor
&
p_in
)
{
if
constexpr
(
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
0
>
{})
==
16
)
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetPTOGradTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
using
AWarpDstr
=
typename
WarpGemm
::
AWarpDstr
;
using
CWarpDstr
=
typename
WarpGemm
::
CWarpDstr
;
auto
pt_warp_tensor
=
make_static_distributed_tensor
<
typename
Problem
::
GemmDataType
>
(
CWarpDstr
{});
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
a_warp_y_index_zeros
=
uniform_sequence_gen_t
<
AWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
pt_warp_tensor
.
get_thread_buffer
()
=
p_in
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
kIter
,
mIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
pt_out
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
kIter
>
{},
a_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
a_warp_y_lengths
),
pt_warp_tensor
.
get_thread_buffer
());
});
});
}
else
{
pt_out
.
get_thread_buffer
()
=
p_in
.
get_thread_buffer
();
}
}
template
<
typename
Problem
,
typename
SGradTOutTensor
,
typename
SGradInTensor
>
CK_TILE_DEVICE
static
constexpr
void
SGradTFromGemm2CToGemm3A
(
SGradTOutTensor
&
dst_out
,
const
SGradInTensor
&
ds_in
)
{
if
constexpr
(
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
0
>
{})
==
16
)
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetSGradTQTBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK3
;
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
using
AWarpDstr
=
typename
WarpGemm
::
AWarpDstr
;
using
CWarpDstr
=
typename
WarpGemm
::
CWarpDstr
;
auto
dst_warp_tensor
=
make_static_distributed_tensor
<
typename
Problem
::
GemmDataType
>
(
CWarpDstr
{});
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
a_warp_y_index_zeros
=
uniform_sequence_gen_t
<
AWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
dst_warp_tensor
.
get_thread_buffer
()
=
ds_in
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
kIter
,
mIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
dst_out
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
kIter
>
{},
a_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
a_warp_y_lengths
),
dst_warp_tensor
.
get_thread_buffer
());
});
});
}
else
{
dst_out
.
get_thread_buffer
()
=
ds_in
.
get_thread_buffer
();
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBiasTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
N1
=
GetTransposed
AlignmentBias
<
Problem
>
();
constexpr
index_t
N1
=
Get
AlignmentBias
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
kMPerBlock
*
kNPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
M3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackBias
<
Problem
>
();
static_assert
(
kKPack
%
M3
==
0
);
constexpr
index_t
M2
=
kKPack
/
M3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
M1
=
get_warp_size
()
/
(
M2
*
N0
);
constexpr
index_t
M2
=
GetTransposedAlignmentBias
<
Problem
>
();
constexpr
index_t
M1
=
get_warp_size
()
/
N0
;
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
M0
,
M1
,
M2
,
M3
>
,
sequence
<
N0
,
N1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
,
1
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
N0
,
N1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
>>
,
sequence
<
2
,
1
>
,
sequence
<
1
,
3
>>
{});
sequence
<
1
,
2
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBiasLdsBlockDescriptor
()
{
// Hold full block data
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPack
=
GetSmemKPackBias
<
Problem
>
();
constexpr
index_t
kKPackT
=
GetSmemKPackBiasT
<
Problem
>
();
return
MakeXTLdsBlockDescriptor
<
Problem
,
kNPerBlock
,
kMPerBlock
,
kKPack
,
kKPackT
>
();
}
template
<
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBias
T
TileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBias
S
TileDistribution
()
{
using
c_block_tensor_type
=
decltype
(
BlockGemm
{}.
MakeCBlockTile
());
return
c_block_tensor_type
::
get_tile_distribution
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeQ
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>>
;
constexpr
index_t
smem_size_q
=
sizeof
(
typename
Problem
::
QDataType
)
*
MakeQLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_q
;
}
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
AccDataType
,
float
>
)
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeQT
()
{
return
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
{};
constexpr
index_t
smem_size_qt
=
sizeof
(
typename
Problem
::
QDataType
)
*
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_qt
;
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
AccDataType
,
float
>
)
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeK
(
)
{
return
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
{};
constexpr
index_t
smem_size_k
=
sizeof
(
typename
Problem
::
KDataType
)
*
MakeKLdsWriteBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_k
;
}
}();
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeKT
()
{
constexpr
index_t
smem_size_kt
=
sizeof
(
typename
Problem
::
KDataType
)
*
MakeKTLdsReadBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_kt
;
}
return
BlockGemmASmemBSmemCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeLSE
()
{
constexpr
index_t
smem_size_lse
=
sizeof
(
typename
Problem
::
LSEDataType
)
*
MakeLSEDLdsWriteBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_lse
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeD
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>>
;
constexpr
index_t
smem_size_d
=
sizeof
(
typename
Problem
::
DDataType
)
*
MakeLSEDLdsWriteBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_d
;
}
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
2
>
{}),
true
>
;
using
BlockGemmPolicy
=
BlockGemmARegBSmemCRegV1CustomPolicy
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBSmemCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeV
()
{
constexpr
index_t
smem_size_v
=
sizeof
(
typename
Problem
::
VDataType
)
*
MakeVLdsWriteBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_v
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeOGrad
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>>
;
constexpr
index_t
smem_size_do
=
sizeof
(
typename
Problem
::
OGradDataType
)
*
MakeOGradLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_do
;
}
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
OGradDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
VDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
AccDataType
,
float
>
)
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeOGradT
()
{
return
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
{};
constexpr
index_t
smem_size_dot
=
sizeof
(
typename
Problem
::
OGradDataType
)
*
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_dot
;
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
OGradDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
VDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
AccDataType
,
float
>
)
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeSGrad
(
)
{
return
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
{};
constexpr
index_t
smem_size_ds
=
sizeof
(
typename
Problem
::
GemmDataType
)
*
MakeSGradLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_ds
;
}
}();
using
BlockGemmPolicy
=
BlockGemmASmemBRegCRegV1CustomPolicy
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmASmemBRegCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
// {
// using BlockGemmProblem =
// BlockGemmPipelineProblem<typename Problem::OGradDataType,
// typename Problem::VDataType,
// typename Problem::AccDataType,
// Problem::kBlockSize,
// TileGemmShape<Problem::BlockFmhaShape::kM0,
// Problem::BlockFmhaShape::kN0,
// Problem::BlockFmhaShape::kK2>>;
// constexpr auto warp_gemm = []() {
// if constexpr(std::is_same_v<typename Problem::OGradDataType, half_t> &&
// std::is_same_v<typename Problem::VDataType, half_t> &&
// std::is_same_v<typename Problem::AccDataType, float>)
// {
// return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{};
// }
// else if constexpr(std::is_same_v<typename Problem::OGradDataType, bf16_t> &&
// std::is_same_v<typename Problem::VDataType, bf16_t> &&
// std::is_same_v<typename Problem::AccDataType, float>)
// {
// return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{};
// }
// }();
// using BlockGemmPolicy =
// BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::OGradDataType,
// typename Problem::VDataType,
// typename Problem::AccDataType,
// typename
// Problem::BlockFmhaShape::Gemm2BlockWarps,
// decltype(warp_gemm)>;
// return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
// }
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeBias
()
{
constexpr
index_t
smem_size_bias
=
[
&
]()
{
if
constexpr
(
Problem
::
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
sizeof
(
typename
Problem
::
BiasDataType
)
*
MakeBiasLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
else
return
0
;
}();
return
smem_size_bias
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>>
;
constexpr
index_t
smem_size_q
=
GetSmemSizeQ
<
Problem
>
();
constexpr
index_t
smem_size_qt
=
GetSmemSizeQT
<
Problem
>
();
constexpr
index_t
smem_size_lse
=
GetSmemSizeLSE
<
Problem
>
();
constexpr
index_t
smem_size_k
=
GetSmemSizeK
<
Problem
>
();
constexpr
index_t
smem_size_kt
=
GetSmemSizeKT
<
Problem
>
();
constexpr
index_t
smem_size_v
=
GetSmemSizeV
<
Problem
>
();
constexpr
index_t
smem_size_do
=
GetSmemSizeOGrad
<
Problem
>
();
constexpr
index_t
smem_size_dot
=
GetSmemSizeOGradT
<
Problem
>
();
constexpr
index_t
smem_size_d
=
GetSmemSizeD
<
Problem
>
();
constexpr
index_t
smem_size_ds
=
GetSmemSizeSGrad
<
Problem
>
();
constexpr
index_t
smem_size_bias
=
GetSmemSizeBias
<
Problem
>
();
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
2
>
{}),
true
>
;
using
BlockGemmPolicy
=
BlockGemmARegBSmemCRegV1CustomPolicy
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBSmemCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
constexpr
index_t
smem_size_stage0_0
=
smem_size_k
+
smem_size_kt
;
constexpr
index_t
smem_size_stage0_1
=
smem_size_v
;
constexpr
index_t
smem_size_stage1
=
smem_size_qt
+
smem_size_q
+
+
smem_size_dot
+
smem_size_do
+
smem_size_lse
+
smem_size_d
+
max
(
smem_size_bias
,
smem_size_ds
);
return
max
(
smem_size_stage0_0
,
smem_size_stage0_1
,
smem_size_stage1
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
template
<
typename
Problem
_
>
struct
HotLoopScheduler
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>>
;
using
Problem
=
Problem_
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
2
>
{}),
true
>
;
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
WarpGemm
>
;
return
BlockGemmASmemBSmemCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
template
<
index_t
GemmStage
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
()
{
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
0
>
()
{
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
// Comp: Q x K
constexpr
index_t
VMEM_READ_INST
=
Q_VMEM_READ
+
OGrad_VMEM_READ
+
LSE_VMEM_READ
+
D_VMEM_READ
;
constexpr
index_t
LDS_READ_INST
=
OGradT_LDS_READ
;
constexpr
index_t
MFMA_INST
=
Gemm0MFMA
;
// Evenly distributed to relieve SQ->TA FIFO pressure
constexpr
index_t
MFMA_PER_VMEM_READ
=
MFMA_INST
/
VMEM_READ_INST
;
constexpr
index_t
MFMA_Remainder
=
MFMA_INST
-
MFMA_PER_VMEM_READ
*
VMEM_READ_INST
;
// To hide instruction issue latency
constexpr
index_t
LDS_READ_PER_MFMA
=
LDS_READ_INST
/
MFMA_INST
;
static_for
<
0
,
VMEM_READ_INST
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
static_for
<
0
,
MFMA_PER_VMEM_READ
,
1
>
{}([
&
](
auto
j
)
{
ignore
=
j
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS read
});
});
static_for
<
0
,
MFMA_Remainder
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS read
});
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
1
>
()
{
// Mem: Q^T LDS load
// Comp: OGrad x V
constexpr
index_t
LDS_READ_INST
=
QT_LDS_READ
;
constexpr
index_t
MFMA_INST
=
Gemm1MFMA
;
// To hide instruction issue latency
constexpr
index_t
LDS_READ_PER_MFMA
=
LDS_READ_INST
/
MFMA_INST
;
static_for
<
0
,
MFMA_INST
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS read
});
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
2
>
()
{
// Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
// Comp: PT x OGrad
constexpr
index_t
LDS_WRITE_INST
=
Q_LDS_WRITE
+
QT_LDS_WRITE
+
OGrad_LDS_WRITE
+
OGradT_LDS_WRITE
+
LSE_LDS_WRITE
+
D_LDS_WRITE
;
constexpr
index_t
MFMA_INST
=
Gemm2MFMA
;
// To hide instruction issue latency
constexpr
index_t
LDS_WRITE_PER_MFMA
=
LDS_WRITE_INST
/
MFMA_INST
;
static_for
<
0
,
MFMA_INST
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x200
,
LDS_WRITE_PER_MFMA
,
0
);
// DS write
});
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
3
>
()
{
// Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT
constexpr
index_t
LDS_WRITE_INST
=
SGradT_LDS_WRITE
;
constexpr
index_t
LDS_READ_INST
=
SGradT_LDS_READ_P1
+
Q_LDS_READ
+
LSE_LDS_READ
;
constexpr
index_t
MFMA_INST
=
Gemm3MFMA
;
// To hide instruction issue latency
constexpr
index_t
LDS_WRITE_PER_MFMA
=
LDS_WRITE_INST
/
MFMA_INST
>=
1
?
LDS_WRITE_INST
/
MFMA_INST
:
1
;
constexpr
index_t
MFMA_INST_LDS_WRITE
=
LDS_WRITE_INST
/
LDS_WRITE_PER_MFMA
;
constexpr
index_t
LDS_READ_PER_MFMA
=
(
MFMA_INST
-
MFMA_INST_LDS_WRITE
)
>
0
?
LDS_READ_INST
/
(
MFMA_INST
-
MFMA_INST_LDS_WRITE
)
>
0
?
LDS_READ_INST
/
(
MFMA_INST
-
MFMA_INST_LDS_WRITE
)
:
1
:
0
;
static_for
<
0
,
MFMA_INST_LDS_WRITE
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x200
,
LDS_WRITE_PER_MFMA
,
0
);
// DS Write
});
static_for
<
0
,
MFMA_INST
-
MFMA_INST_LDS_WRITE
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS Read
});
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
4
>
()
{
// Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT
constexpr
index_t
LDS_READ_INST
=
SGradT_LDS_READ_P2
+
OGrad_LDS_READ
+
D_LDS_READ
;
constexpr
index_t
MFMA_INST
=
Gemm4MFMA
;
// To hide instruction issue latency
constexpr
index_t
LDS_READ_PER_MFMA
=
LDS_READ_INST
/
MFMA_INST
>
0
?
LDS_READ_INST
/
MFMA_INST
:
1
;
static_for
<
0
,
MFMA_INST
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x100
,
LDS_READ_PER_MFMA
,
0
);
// DS Read
});
}
private:
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
Problem
::
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
Problem
::
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kQKHeaddim
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kVHeaddim
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
static
constexpr
index_t
kK4
=
Problem
::
BlockFmhaShape
::
kK4
;
static
constexpr
index_t
WarpGemmM
=
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
WarpGemmN
=
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
WarpGemmK
=
WarpGemmM
==
16
?
16
:
8
;
static
constexpr
index_t
Gemm4MWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Gemm4NWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
1
>
{});
// Compute
static
constexpr
index_t
Gemm0MFMA
=
kM0
*
kN0
*
kQKHeaddim
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm1MFMA
=
kM0
*
kN0
*
kVHeaddim
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm2MFMA
=
kN0
*
kVHeaddim
*
kM0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm3MFMA
=
kN0
*
kQKHeaddim
*
kM0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm4MFMA
=
kM0
*
kQKHeaddim
*
kN0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
// VMEM
static
constexpr
index_t
Q_VMEM_READ
=
kM0
*
kQKHeaddim
/
kBlockSize
/
GetAlignmentQ
<
Problem
>
();
static
constexpr
index_t
OGrad_VMEM_READ
=
kM0
*
kVHeaddim
/
kBlockSize
/
GetAlignmentOGrad
<
Problem
>
();
static
constexpr
index_t
LSE_VMEM_READ
=
1
;
static
constexpr
index_t
D_VMEM_READ
=
1
;
// LDS Read
static
constexpr
index_t
OGradT_LDS_READ
=
kM0
*
kVHeaddim
/
get_warp_size
()
/
GetTransposedAlignmentOGrad
<
Problem
>
();
static
constexpr
index_t
QT_LDS_READ
=
kM0
*
kQKHeaddim
/
get_warp_size
()
/
GetTransposedAlignmentQ
<
Problem
>
();
static
constexpr
index_t
SGradT_LDS_READ_P1
=
kM0
*
kK4
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
GetSmemKPackSGrad
<
Problem
>
();
static
constexpr
index_t
Q_LDS_READ
=
kM0
*
kQKHeaddim
/
kBlockSize
/
GetAlignmentQ
<
Problem
>
();
static
constexpr
index_t
LSE_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
static
constexpr
index_t
SGradT_LDS_READ_P2
=
kM0
*
(
kN0
-
kK4
)
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
GetSmemKPackSGrad
<
Problem
>
();
static
constexpr
index_t
OGrad_LDS_READ
=
kM0
*
kVHeaddim
/
kBlockSize
/
GetAlignmentOGrad
<
Problem
>
();
static
constexpr
index_t
D_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
// LDS Write
static
constexpr
index_t
Q_LDS_WRITE
=
kM0
*
kQKHeaddim
/
Problem
::
kBlockSize
/
GetAlignmentQ
<
Problem
>
();
static
constexpr
index_t
QT_LDS_WRITE
=
kM0
*
kQKHeaddim
/
kBlockSize
/
GetTransposedAlignmentQ
<
Problem
>
();
static
constexpr
index_t
OGrad_LDS_WRITE
=
kM0
*
kVHeaddim
/
kBlockSize
/
GetAlignmentOGrad
<
Problem
>
();
static
constexpr
index_t
OGradT_LDS_WRITE
=
kM0
*
kVHeaddim
/
kBlockSize
/
GetTransposedAlignmentOGrad
<
Problem
>
();
static
constexpr
index_t
LSE_LDS_WRITE
=
1
;
static
constexpr
index_t
D_LDS_WRITE
=
1
;
static
constexpr
index_t
SGradT_LDS_WRITE
=
kM0
*
kN0
/
kBlockSize
;
};
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
View file @
bd689f40
...
...
@@ -8,9 +8,8 @@ namespace ck_tile {
// This class is used for codegen pattern matching
enum
class
BlockFmhaBwdPipelineEnum
{
KSKTSVR
=
0
,
QSKSVROGradS
,
KSVR
,
KRKTRVR_IGLP
=
0
,
KRKTRVR
,
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
View file @
bd689f40
...
...
@@ -24,7 +24,9 @@ template <typename QDataType_,
typename
BiasGradDataType_
,
typename
BlockFmhaShape_
,
bool
kIsGroupMode_
,
bool
kIsDeterministic_
,
typename
FmhaMask_
,
typename
FmhaDropout_
,
typename
Traits_
>
struct
BlockFmhaBwdPipelineProblem
{
...
...
@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem
using
BiasGradDataType
=
remove_cvref_t
<
BiasGradDataType_
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
BlockFmhaShape_
>
;
using
FmhaMask
=
remove_cvref_t
<
FmhaMask_
>
;
using
FmhaDropout
=
remove_cvref_t
<
FmhaDropout_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kBlockSize
=
BlockFmhaShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
bool
kIsDeterministic
=
kIsDeterministic_
;
// attributes from traits
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
...
...
@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem
static
constexpr
bool
kPadHeadDimV
=
Traits
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Traits
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
Traits
::
kHasBiasGrad
;
static
constexpr
bool
kHasDropout
=
Traits
::
kHasDropout
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
...
...
@@ -88,4 +91,35 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
template
<
typename
AccDataType_
,
typename
QGradDataType_
,
index_t
kBlockSize_
,
index_t
kM0_
,
index_t
kN0_
,
index_t
kQKHeaddim_
,
bool
kIsGroupMode_
,
bool
kIsDeterministic_
,
typename
Traits_
>
struct
BlockFmhaBwdConvertQGradPipelineProblem
{
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
QGradDataType
=
remove_cvref_t
<
QGradDataType_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static_assert
(
0
<
kBlockSize_
&&
kBlockSize_
%
get_warp_size
()
==
0
,
"kBlockSize should be divisible by get_warp_size()"
);
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
static
constexpr
index_t
kM0
=
kM0_
;
static
constexpr
index_t
kN0
=
kN0_
;
static
constexpr
index_t
kQKHeaddim
=
kQKHeaddim_
;
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
bool
kIsDeterministic
=
kIsDeterministic_
;
// attributes from traits
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
static
constexpr
bool
kPadHeadDimQ
=
Traits
::
kPadHeadDimQ
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
bd689f40
...
...
@@ -231,7 +231,9 @@ struct BlockFmhaPipelineQRKSVSAsync
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
auto
q
=
decltype
(
load_tile
(
q_dram_window
)){};
set_tile
(
q
,
number
<
0
>
{});
// use per-dword clear to avoid scratch
// TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
// however, q would be cleared in the constructor of static distributed tensor
// set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
load_tile_raw
(
q
,
q_dram_window
);
__builtin_amdgcn_sched_barrier
(
0
);
...
...
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
View file @
bd689f40
...
...
@@ -86,4 +86,14 @@ struct TileFmhaBwdOGradDotOTraits
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
template
<
bool
kPadSeqLenQ_
/* padding for seqlen_q */
,
bool
kPadHeadDimQ_
/* paddding for hdim_q */
,
index_t
kBlockPerCu_
=
2
/* hint to occupancy */
>
struct
TileFmhaBwdConvertQGradTraits
{
static
constexpr
bool
kPadSeqLenQ
=
kPadSeqLenQ_
;
static
constexpr
bool
kPadHeadDimQ
=
kPadHeadDimQ_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
}
// namespace ck_tile
include/ck_tile/ops/gemm.hpp
View file @
bd689f40
...
...
@@ -5,6 +5,9 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
0 → 100644
View file @
bd689f40
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
namespace
ck_tile
{
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
template
<
typename
Problem_
,
typename
Policy_
=
BlockGemmARegBRegCRegV1DefaultPolicy
>
struct
BlockGemmARegBRegCRegV1
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
// C += A * B
template
<
typename
CBlockTensor
,
typename
ABlockTensor
,
typename
BBlockTensor
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ABlockTensor
&
a_block_tensor
,
const
BBlockTensor
&
b_block_tensor
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cv_t
<
typename
ABlockTensor
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cv_t
<
typename
BBlockTensor
::
DataType
>>
&&
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
// M->N Warp
constexpr
auto
a_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
b_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
constexpr
auto
b_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
b_block_outer_dstr_encoding
,
typename
WG
::
BWarpDstrEncoding
{});
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
// check ABC-block-distribution
static_assert
(
std
::
is_same_v
<
remove_cvref_t
<
decltype
(
a_block_dstr_encode
)
>
,
remove_cvref_t
<
decltype
(
ABlockTensor
::
get_tile_distribution
()
.
get_static_tile_distribution_encoding
())
>>
,
"A distribution is wrong!"
);
static_assert
(
std
::
is_same_v
<
remove_cvref_t
<
decltype
(
b_block_dstr_encode
)
>
,
remove_cvref_t
<
decltype
(
BBlockTensor
::
get_tile_distribution
()
.
get_static_tile_distribution_encoding
())
>>
,
"B distribution is wrong!"
);
static_assert
(
std
::
is_same_v
<
remove_cvref_t
<
decltype
(
c_block_dstr_encode
)
>
,
remove_cvref_t
<
decltype
(
CBlockTensor
::
get_tile_distribution
()
.
get_static_tile_distribution_encoding
())
>>
,
"C distribution is wrong!"
);
using
AWarpDstr
=
typename
WG
::
AWarpDstr
;
using
BWarpDstr
=
typename
WG
::
BWarpDstr
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
using
AWarpTensor
=
typename
WG
::
AWarpTensor
;
using
BWarpTensor
=
typename
WG
::
BWarpTensor
;
using
CWarpTensor
=
typename
WG
::
CWarpTensor
;
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
b_warp_y_lengths
=
to_sequence
(
BWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
a_warp_y_index_zeros
=
uniform_sequence_gen_t
<
AWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
b_warp_y_index_zeros
=
uniform_sequence_gen_t
<
BWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
// read A warp tensor from A Block window
AWarpTensor
a_warp_tensor
;
a_warp_tensor
.
get_thread_buffer
()
=
a_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
kIter
>
{},
a_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
a_warp_y_lengths
));
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read B warp tensor from B block tensor
BWarpTensor
b_warp_tensor
;
b_warp_tensor
.
get_thread_buffer
()
=
b_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
nIter
,
kIter
>
{},
b_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
b_warp_y_lengths
));
// read C warp tensor from C block tensor
CWarpTensor
c_warp_tensor
;
c_warp_tensor
.
get_thread_buffer
()
=
c_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// warp GEMM
WG
{}(
c_warp_tensor
,
a_warp_tensor
,
b_warp_tensor
);
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
c_warp_tensor
.
get_thread_buffer
());
});
});
});
}
CK_TILE_DEVICE
constexpr
auto
MakeCBlockTile
()
const
{
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
}
// C = A * B
template
<
typename
ABlockTensor
,
typename
BBlockTensor
>
CK_TILE_DEVICE
auto
operator
()(
const
ABlockTensor
&
a_block_tensor
,
const
BBlockTensor
&
b_block_tensor
)
const
{
auto
c_block_tensor
=
MakeCBlockTile
();
operator
()(
c_block_tensor
,
a_block_tensor
,
b_block_tensor
);
return
c_block_tensor
;
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp
0 → 100644
View file @
bd689f40
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
AType_
,
typename
BType_
,
typename
CType_
,
typename
BlockWarps_
,
typename
WarpGemm_
>
struct
BlockGemmARegBRegCRegV1CustomPolicy
{
using
AType
=
remove_cvref_t
<
AType_
>
;
using
BType
=
remove_cvref_t
<
BType_
>
;
using
CType
=
remove_cvref_t
<
CType_
>
;
using
BlockWarps
=
remove_cvref_t
<
BlockWarps_
>
;
static
constexpr
index_t
kMWarps
=
BlockWarps
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kNWarps
=
BlockWarps
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kKWarps
=
BlockWarps
::
at
(
number
<
2
>
{});
using
WarpGemm
=
remove_cvref_t
<
WarpGemm_
>
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetWarpGemmMWarpNWarp
()
{
return
make_tuple
(
WarpGemm
{},
kMWarps
,
kNWarps
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp
0 → 100644
View file @
bd689f40
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace
ck_tile
{
// Default policy for BlockGemmARegBRegCRegV1
// Default policy class should not be templated, put template on member functions instead
struct
BlockGemmARegBRegCRegV1DefaultPolicy
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetWarpGemmMWarpNWarp
()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
BDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
CDataType
,
float
>
)
{
return
make_tuple
(
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
{},
4
,
1
);
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
BDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
CDataType
,
float
>
)
{
return
make_tuple
(
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
{},
4
,
1
);
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
View file @
bd689f40
...
...
@@ -35,16 +35,13 @@ struct BlockGemmARegBSmemCRegV1
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
constexpr
index_t
MPerBlock
=
ABlockTensorTmp
{}.
get_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockTensorTmp
{}.
get_lengths
()[
number
<
1
>
{}];
//
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
//
KPerBlock == BlockGemmShape::kK,
//
"wrong!");
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
...
...
include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
View file @
bd689f40
...
...
@@ -35,16 +35,13 @@ struct BlockGemmASmemBRegCRegV1
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
// constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
constexpr
index_t
MPerBlock
=
ABlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockTensorTmp
{}.
get_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}];
//
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
//
KPerBlock == BlockGemmShape::kK,
//
"wrong!");
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
...
...
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
View file @
bd689f40
...
...
@@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 =
using
WarpGemmMfmaF16F16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
,
2
>>
;
using
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
1
>>
;
using
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
2
>>
;
...
...
@@ -59,6 +62,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 =
using
WarpGemmMfmaBf16Bf16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
1
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
2
>>
;
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
View file @
bd689f40
...
...
@@ -119,9 +119,9 @@ struct WarpGemmAtrributeMfmaIterateK
static_for
<
0
,
kKIter
,
1
>
{}([
&
](
auto
iKIter
)
{
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_a
>
(
a_vec
)
reinterpret_cast
<
const
buf_a
&
>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
>
(
b_vec
)
reinterpret_cast
<
const
buf_b
&
>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
]);
});
}
...
...
@@ -135,15 +135,15 @@ struct WarpGemmAtrributeMfmaIterateK
// c = a * b
auto
c_vec
=
Impl
{}(
reinterpret_cast
<
const
buf_a
>
(
a_vec
).
template
get_as
<
typename
Impl
::
AVecType
>()[
I0
],
reinterpret_cast
<
const
buf_b
>
(
b_vec
).
template
get_as
<
typename
Impl
::
BVecType
>()[
I0
]);
reinterpret_cast
<
const
buf_a
&
>
(
a_vec
).
template
get_as
<
typename
Impl
::
AVecType
>()[
I0
],
reinterpret_cast
<
const
buf_b
&
>
(
b_vec
).
template
get_as
<
typename
Impl
::
BVecType
>()[
I0
]);
// c += a * b
static_for
<
1
,
kKIter
,
1
>
{}([
&
](
auto
iKIter
)
{
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_a
>
(
a_vec
)
reinterpret_cast
<
const
buf_a
&
>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
>
(
b_vec
)
reinterpret_cast
<
const
buf_b
&
>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
]);
});
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
View file @
bd689f40
...
...
@@ -15,7 +15,8 @@ template <typename AType,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KPerWave
,
bool
TransposeC
>
bool
TransposeC
,
bool
SwizzleA
=
false
>
struct
WarpGemmMfmaDispatcher
;
// clang-format off
...
...
@@ -29,6 +30,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
;
};
// bf16
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
;
};
...
...
@@ -39,6 +43,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
;
};
// fp8
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
;
};
...
...
@@ -58,8 +65,15 @@ template <typename AType,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KPerWave
,
bool
TransposeC
>
using
WarpGemmMfmaDispatcher
=
typename
impl
::
WarpGemmMfmaDispatcher
<
AType
,
BType
,
CType
,
MPerWave
,
NPerWave
,
KPerWave
,
TransposeC
>::
Type
;
bool
TransposeC
,
bool
SwizzleA
=
false
>
using
WarpGemmMfmaDispatcher
=
typename
impl
::
WarpGemmMfmaDispatcher
<
AType
,
BType
,
CType
,
MPerWave
,
NPerWave
,
KPerWave
,
TransposeC
,
SwizzleA
>::
Type
;
}
// namespace ck_tile
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp
View file @
bd689f40
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -39,11 +39,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
public:
Argument
(
const
Tensor
<
InDataType
>&
input
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
std
::
vector
<
ck
::
long_
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
)
:
input_
{
input
},
output_
{
output
},
conv_strides_
{
conv_filter_strides
},
...
...
@@ -58,24 +58,25 @@ struct ReferenceColumnToImage : public device::BaseOperator
const
Tensor
<
InDataType
>&
input_
;
Tensor
<
OutDataType
>&
output_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
std
::
vector
<
index_t
>
in_left_pads_
;
std
::
vector
<
index_t
>
in_right_pads_
;
std
::
vector
<
long_
index_t
>
conv_strides_
;
std
::
vector
<
long_
index_t
>
conv_dilations_
;
std
::
vector
<
long_
index_t
>
in_left_pads_
;
std
::
vector
<
long_
index_t
>
in_right_pads_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
output_spatial_lengths_
;
std
::
vector
<
long_
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
long_
index_t
>
output_spatial_lengths_
;
private:
void
initOutputSpatialLengths
()
{
constexpr
auto
input_offset_to_spatial
=
3
;
for
(
ck
::
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
for
(
ck
::
long_
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck
::
index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_dilations_
[
i
]
+
1
;
const
ck
::
long_index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_dilations_
[
i
]
+
1
;
output_spatial_lengths_
.
push_back
(
(
output_
.
GetLengths
()[
i
+
input_offset_to_spatial
]
+
in_left_pads_
[
i
]
+
...
...
@@ -98,26 +99,26 @@ struct ReferenceColumnToImage : public device::BaseOperator
throw
std
::
runtime_error
(
"wrong! inconsistent dimension"
);
}
const
index_t
G
=
arg
.
output_
.
GetLengths
()[
0
];
const
index_t
N
=
arg
.
output_
.
GetLengths
()[
1
];
const
index_t
C
=
arg
.
output_
.
GetLengths
()[
2
];
const
long_
index_t
G
=
arg
.
output_
.
GetLengths
()[
0
];
const
long_
index_t
N
=
arg
.
output_
.
GetLengths
()[
1
];
const
long_
index_t
C
=
arg
.
output_
.
GetLengths
()[
2
];
if
constexpr
(
NDimSpatial
==
1
)
{
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
0
];
const
long_
index_t
Wo
=
arg
.
output_spatial_lengths_
[
0
];
auto
func
=
[
&
](
auto
g
,
auto
n
)
{
for
(
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
for
(
long_
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
{
index_t
row
=
n
*
Wo
+
wo
;
index_t
column
=
0
;
long_
index_t
row
=
n
*
Wo
+
wo
;
long_
index_t
column
=
0
;
for
(
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
0
];
++
x
)
for
(
long_
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
0
];
++
x
)
{
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
index_t
c
=
0
;
c
<
C
;
++
c
)
for
(
long_
index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
output_
.
GetLengths
()[
3
])
...
...
@@ -140,32 +141,32 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
else
if
constexpr
(
NDimSpatial
==
2
)
{
const
index_t
Ho
=
arg
.
output_spatial_lengths_
[
0
];
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
1
];
const
long_
index_t
Ho
=
arg
.
output_spatial_lengths_
[
0
];
const
long_
index_t
Wo
=
arg
.
output_spatial_lengths_
[
1
];
auto
func
=
[
&
](
auto
g
,
auto
n
)
{
for
(
index_t
ho
=
0
;
ho
<
Ho
;
++
ho
)
for
(
long_
index_t
ho
=
0
;
ho
<
Ho
;
++
ho
)
{
for
(
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
for
(
long_
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
{
index_t
row
=
n
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
index_t
column
=
0
;
long_
index_t
row
=
n
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
long_
index_t
column
=
0
;
for
(
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
0
];
++
y
)
for
(
long_
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
0
];
++
y
)
{
auto
hi
=
static_cast
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
1
];
++
x
)
for
(
long_
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
1
];
++
x
)
{
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
1
])
+
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
for
(
index_t
c
=
0
;
c
<
C
;
++
c
)
for
(
long_
index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
hi
>=
0
&&
...
...
@@ -196,27 +197,27 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
const
index_t
Do
=
arg
.
output_spatial_lengths_
[
0
];
const
index_t
Ho
=
arg
.
output_spatial_lengths_
[
1
];
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
2
];
const
long_
index_t
Do
=
arg
.
output_spatial_lengths_
[
0
];
const
long_
index_t
Ho
=
arg
.
output_spatial_lengths_
[
1
];
const
long_
index_t
Wo
=
arg
.
output_spatial_lengths_
[
2
];
auto
func
=
[
&
](
auto
g
,
auto
n
)
{
for
(
index_t
d_o
=
0
;
d_o
<
Do
;
++
d_o
)
for
(
long_
index_t
d_o
=
0
;
d_o
<
Do
;
++
d_o
)
{
for
(
index_t
ho
=
0
;
ho
<
Ho
;
++
ho
)
for
(
long_
index_t
ho
=
0
;
ho
<
Ho
;
++
ho
)
{
for
(
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
for
(
long_
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
{
index_t
row
=
n
*
Do
*
Ho
*
Wo
+
d_o
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
index_t
column
=
0
;
long_
index_t
row
=
n
*
Do
*
Ho
*
Wo
+
d_o
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
long_
index_t
column
=
0
;
for
(
index_t
z
=
0
;
z
<
arg
.
filter_spatial_lengths_
[
0
];
++
z
)
for
(
long_
index_t
z
=
0
;
z
<
arg
.
filter_spatial_lengths_
[
0
];
++
z
)
{
auto
di
=
static_cast
<
ck
::
long_index_t
>
(
d_o
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
z
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
1
];
++
y
)
for
(
long_
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
1
];
++
y
)
{
auto
hi
=
static_cast
<
ck
::
long_index_t
>
(
ho
*
...
...
@@ -224,7 +225,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
for
(
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
2
];
++
x
)
for
(
long_index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
2
];
++
x
)
{
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
...
...
@@ -232,7 +234,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
2
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
2
]);
for
(
index_t
c
=
0
;
c
<
C
;
++
c
)
for
(
long_
index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
di
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
di
)
<
...
...
@@ -294,15 +296,15 @@ struct ReferenceColumnToImage : public device::BaseOperator
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
const
ck
::
index_t
G
=
arg
.
output_
.
GetLengths
()[
0
];
const
ck
::
index_t
N
=
arg
.
output_
.
GetLengths
()[
1
];
const
ck
::
index_t
C
=
arg
.
output_
.
GetLengths
()[
2
];
const
ck
::
long_
index_t
G
=
arg
.
output_
.
GetLengths
()[
0
];
const
ck
::
long_
index_t
N
=
arg
.
output_
.
GetLengths
()[
1
];
const
ck
::
long_
index_t
C
=
arg
.
output_
.
GetLengths
()[
2
];
const
index_t
NDoHoWo
=
N
*
ck
::
accumulate_n
<
index_t
>
(
const
long_
index_t
NDoHoWo
=
N
*
ck
::
accumulate_n
<
long_
index_t
>
(
arg
.
output_spatial_lengths_
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
index_t
CZYX
=
C
*
ck
::
accumulate_n
<
index_t
>
(
const
long_
index_t
CZYX
=
C
*
ck
::
accumulate_n
<
long_
index_t
>
(
arg
.
filter_spatial_lengths_
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
if
(
!
(
arg
.
input_
.
GetLengths
()[
0
]
==
static_cast
<
std
::
size_t
>
(
G
)
&&
...
...
@@ -326,11 +328,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
static
auto
MakeArgument
(
const
Tensor
<
InDataType
>&
input
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
std
::
vector
<
ck
::
long_
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
)
{
return
Argument
{
input
,
output
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
View file @
bd689f40
...
...
@@ -38,10 +38,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
Tensor
<
InDataType
>&
input
,
const
Tensor
<
WeiDataType
>&
weight
,
const
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
@@ -72,10 +72,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
const
std
::
array
<
Tensor
<
WeiDataType
>
,
NumBElementwiseTensor
>&
elementwise_b_tensors_
;
const
std
::
array
<
Tensor
<
OutDataType
>
,
NumDElementwiseTensor
>&
elementwise_d_tensors_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
std
::
vector
<
index_t
>
in_left_pads_
;
std
::
vector
<
index_t
>
in_right_pads_
;
std
::
vector
<
long_
index_t
>
conv_strides_
;
std
::
vector
<
long_
index_t
>
conv_dilations_
;
std
::
vector
<
long_
index_t
>
in_left_pads_
;
std
::
vector
<
long_
index_t
>
in_right_pads_
;
InElementwiseOperation
in_element_op_
;
WeiElementwiseOperation
wei_element_op_
;
...
...
@@ -447,10 +447,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
Tensor
<
InDataType
>&
input
,
const
Tensor
<
WeiDataType
>&
weight
,
const
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp
View file @
bd689f40
...
...
@@ -40,10 +40,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const
Tensor
<
InDataType
>&
in_n_c_hi_wi
,
Tensor
<
WeiDataType
>&
wei_k_c_y_x
,
const
Tensor
<
OutDataType
>&
out_n_k_ho_wo
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
@@ -74,10 +74,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const
std
::
array
<
Tensor
<
InDataType
>
,
NumBElementwiseTensor
>&
elementwise_b_tensors_
;
const
std
::
array
<
Tensor
<
WeiDataType
>
,
NumDElementwiseTensor
>&
elementwise_d_tensors_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
std
::
vector
<
index_t
>
in_left_pads_
;
std
::
vector
<
index_t
>
in_right_pads_
;
std
::
vector
<
long_
index_t
>
conv_strides_
;
std
::
vector
<
long_
index_t
>
conv_dilations_
;
std
::
vector
<
long_
index_t
>
in_left_pads_
;
std
::
vector
<
long_
index_t
>
in_right_pads_
;
InElementwiseOperation
in_element_op_
;
WeiElementwiseOperation
wei_element_op_
;
...
...
@@ -402,10 +402,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const
Tensor
<
InDataType
>&
in_n_c_hi_wi
,
Tensor
<
WeiDataType
>&
wei_k_c_y_x
,
const
Tensor
<
OutDataType
>&
out_n_k_ho_wo
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
View file @
bd689f40
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -69,10 +69,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const
Tensor
<
InDataType
>&
input
,
const
Tensor
<
WeiDataType
>&
weight
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
@@ -103,10 +103,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const
std
::
array
<
Tensor
<
WeiDataType
>
,
NumBElementwiseTensor
>&
elementwise_b_tensors_
;
const
std
::
array
<
Tensor
<
OutDataType
>
,
NumDElementwiseTensor
>&
elementwise_d_tensors_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
std
::
vector
<
index_t
>
in_left_pads_
;
std
::
vector
<
index_t
>
in_right_pads_
;
std
::
vector
<
ck
::
long_
index_t
>
conv_strides_
;
std
::
vector
<
ck
::
long_
index_t
>
conv_dilations_
;
std
::
vector
<
ck
::
long_
index_t
>
in_left_pads_
;
std
::
vector
<
ck
::
long_
index_t
>
in_right_pads_
;
InElementwiseOperation
in_element_op_
;
WeiElementwiseOperation
wei_element_op_
;
...
...
@@ -416,10 +416,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const
Tensor
<
InDataType
>&
input
,
const
Tensor
<
WeiDataType
>&
weight
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp
View file @
bd689f40
...
...
@@ -40,11 +40,11 @@ struct ReferenceImageToColumn : public device::BaseOperator
public:
Argument
(
const
Tensor
<
InDataType
>&
input
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
std
::
vector
<
ck
::
long_
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
)
:
input_
{
input
},
output_
{
output
},
conv_strides_
{
conv_filter_strides
},
...
...
@@ -59,13 +59,13 @@ struct ReferenceImageToColumn : public device::BaseOperator
const
Tensor
<
InDataType
>&
input_
;
Tensor
<
OutDataType
>&
output_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
std
::
vector
<
index_t
>
in_left_pads_
;
std
::
vector
<
index_t
>
in_right_pads_
;
std
::
vector
<
long_
index_t
>
conv_strides_
;
std
::
vector
<
long_
index_t
>
conv_dilations_
;
std
::
vector
<
long_
index_t
>
in_left_pads_
;
std
::
vector
<
long_
index_t
>
in_right_pads_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
output_spatial_lengths_
;
std
::
vector
<
long_
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
long_
index_t
>
output_spatial_lengths_
;
private:
void
initOutputSpatialLengths
()
...
...
@@ -76,7 +76,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck
::
index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_dilations_
[
i
]
+
1
;
const
ck
::
long_index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_dilations_
[
i
]
+
1
;
output_spatial_lengths_
.
push_back
(
(
input_
.
GetLengths
()[
i
+
input_offset_to_spatial
]
+
in_left_pads_
[
i
]
+
...
...
@@ -99,24 +100,24 @@ struct ReferenceImageToColumn : public device::BaseOperator
throw
std
::
runtime_error
(
"wrong! inconsistent dimension"
);
}
const
index_t
G
=
arg
.
input_
.
GetLengths
()[
0
];
const
index_t
N
=
arg
.
input_
.
GetLengths
()[
1
];
const
index_t
C
=
arg
.
input_
.
GetLengths
()[
2
];
const
long_
index_t
G
=
arg
.
input_
.
GetLengths
()[
0
];
const
long_
index_t
N
=
arg
.
input_
.
GetLengths
()[
1
];
const
long_
index_t
C
=
arg
.
input_
.
GetLengths
()[
2
];
if
constexpr
(
NDimSpatial
==
1
)
{
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
0
];
const
long_
index_t
Wo
=
arg
.
output_spatial_lengths_
[
0
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
wo
)
{
index_t
row
=
n
*
Wo
+
wo
;
index_t
column
=
0
;
long_
index_t
row
=
n
*
Wo
+
wo
;
long_
index_t
column
=
0
;
for
(
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
0
];
++
x
)
for
(
long_
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
0
];
++
x
)
{
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
index_t
c
=
0
;
c
<
C
;
++
c
)
for
(
long_
index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
GetLengths
()[
3
])
...
...
@@ -135,26 +136,26 @@ struct ReferenceImageToColumn : public device::BaseOperator
}
else
if
constexpr
(
NDimSpatial
==
2
)
{
const
index_t
Ho
=
arg
.
output_spatial_lengths_
[
0
];
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
1
];
const
long_
index_t
Ho
=
arg
.
output_spatial_lengths_
[
0
];
const
long_
index_t
Wo
=
arg
.
output_spatial_lengths_
[
1
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
ho
,
auto
wo
)
{
index_t
row
=
n
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
index_t
column
=
0
;
long_
index_t
row
=
n
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
long_
index_t
column
=
0
;
for
(
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
0
];
++
y
)
for
(
long_
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
0
];
++
y
)
{
auto
hi
=
static_cast
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
1
];
++
x
)
for
(
long_
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
1
];
++
x
)
{
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
1
])
+
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
for
(
index_t
c
=
0
;
c
<
C
;
++
c
)
for
(
long_
index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
hi
>=
0
&&
...
...
@@ -178,31 +179,31 @@ struct ReferenceImageToColumn : public device::BaseOperator
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
const
index_t
Do
=
arg
.
output_spatial_lengths_
[
0
];
const
index_t
Ho
=
arg
.
output_spatial_lengths_
[
1
];
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
2
];
const
long_
index_t
Do
=
arg
.
output_spatial_lengths_
[
0
];
const
long_
index_t
Ho
=
arg
.
output_spatial_lengths_
[
1
];
const
long_
index_t
Wo
=
arg
.
output_spatial_lengths_
[
2
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
d_o
,
auto
ho
,
auto
wo
)
{
index_t
row
=
n
*
Do
*
Ho
*
Wo
+
d_o
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
index_t
column
=
0
;
long_
index_t
row
=
n
*
Do
*
Ho
*
Wo
+
d_o
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
long_
index_t
column
=
0
;
for
(
index_t
z
=
0
;
z
<
arg
.
filter_spatial_lengths_
[
0
];
++
z
)
for
(
long_
index_t
z
=
0
;
z
<
arg
.
filter_spatial_lengths_
[
0
];
++
z
)
{
auto
di
=
static_cast
<
ck
::
long_index_t
>
(
d_o
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
z
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
1
];
++
y
)
for
(
long_
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
1
];
++
y
)
{
auto
hi
=
static_cast
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
1
])
+
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
for
(
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
2
];
++
x
)
for
(
long_
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
2
];
++
x
)
{
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
2
])
+
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
2
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
2
]);
for
(
index_t
c
=
0
;
c
<
C
;
++
c
)
for
(
long_
index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
di
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
di
)
<
...
...
@@ -259,15 +260,15 @@ struct ReferenceImageToColumn : public device::BaseOperator
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
const
ck
::
index_t
G
=
arg
.
input_
.
GetLengths
()[
0
];
const
ck
::
index_t
N
=
arg
.
input_
.
GetLengths
()[
1
];
const
ck
::
index_t
C
=
arg
.
input_
.
GetLengths
()[
2
];
const
ck
::
long_
index_t
G
=
arg
.
input_
.
GetLengths
()[
0
];
const
ck
::
long_
index_t
N
=
arg
.
input_
.
GetLengths
()[
1
];
const
ck
::
long_
index_t
C
=
arg
.
input_
.
GetLengths
()[
2
];
const
index_t
NDoHoWo
=
N
*
ck
::
accumulate_n
<
index_t
>
(
const
long_
index_t
NDoHoWo
=
N
*
ck
::
accumulate_n
<
long_
index_t
>
(
arg
.
output_spatial_lengths_
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
index_t
CZYX
=
C
*
ck
::
accumulate_n
<
index_t
>
(
const
long_
index_t
CZYX
=
C
*
ck
::
accumulate_n
<
long_
index_t
>
(
arg
.
filter_spatial_lengths_
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
if
(
!
(
arg
.
output_
.
GetLengths
()[
0
]
==
static_cast
<
std
::
size_t
>
(
G
)
&&
...
...
@@ -291,11 +292,11 @@ struct ReferenceImageToColumn : public device::BaseOperator
static
auto
MakeArgument
(
const
Tensor
<
InDataType
>&
input
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
std
::
vector
<
ck
::
long_
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
)
{
return
Argument
{
input
,
output
,
...
...
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
bd689f40
...
...
@@ -108,6 +108,7 @@ using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using
MultiplyFastGelu
=
ck
::
tensor_operation
::
element_wise
::
MultiplyFastGelu
;
using
AddMultiply
=
ck
::
tensor_operation
::
element_wise
::
AddMultiply
;
using
MultiplyAdd
=
ck
::
tensor_operation
::
element_wise
::
MultiplyAdd
;
using
MultiplyMultiply
=
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
;
using
ScaleAdd
=
ck
::
tensor_operation
::
element_wise
::
ScaleAdd
;
using
Gelu
=
ck
::
tensor_operation
::
element_wise
::
Gelu
;
using
Swish
=
ck
::
tensor_operation
::
element_wise
::
Swish
;
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment