Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
20ddaeba
Commit
20ddaeba
authored
Apr 22, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
c5f1cdf7
43879b89
Changes
236
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4512 additions
and
0 deletions
+4512
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
...ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
+54
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
+596
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+694
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
...ine/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
+19
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp
.../pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp
+19
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp
...le/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp
+506
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
+587
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp
.../pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp
+19
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+956
-0
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
+46
-0
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
+30
-0
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+31
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp
...ile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp
+25
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp
+135
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp
...mm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp
+110
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp
...ile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp
+26
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
+340
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp
...emm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp
+36
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp
...mm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp
+56
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp
.../ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp
+227
-0
No files found.
Too many changes to show.
To preserve performance only
236 of 236+
files are displayed.
Plain diff
Email patch
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
QDataType_
,
typename
KDataType_
,
typename
VDataType_
,
typename
SaccDataType_
,
typename
SMPLComputeDataType_
,
typename
BiasDataType_
,
typename
LSEDataType_
,
typename
PDataType_
,
typename
OaccDataType_
,
typename
ODataType_
,
typename
BlockFmhaShape_
,
bool
kIsGroupMode_
,
typename
FmhaMask_
,
typename
Traits_
>
struct
BlockFmhaPipelineProblem
{
using
QDataType
=
remove_cvref_t
<
QDataType_
>
;
using
KDataType
=
remove_cvref_t
<
KDataType_
>
;
using
VDataType
=
remove_cvref_t
<
VDataType_
>
;
using
SaccDataType
=
remove_cvref_t
<
SaccDataType_
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
SMPLComputeDataType_
>
;
using
BiasDataType
=
remove_cvref_t
<
BiasDataType_
>
;
using
LSEDataType
=
remove_cvref_t
<
LSEDataType_
>
;
using
PDataType
=
remove_cvref_t
<
PDataType_
>
;
using
OaccDataType
=
remove_cvref_t
<
OaccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
BlockFmhaShape_
>
;
using
FmhaMask
=
remove_cvref_t
<
FmhaMask_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kBlockSize
=
BlockFmhaShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
// attributes from traits
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Traits
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Traits
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Traits
::
kPadHeadDimV
;
static
constexpr
bool
kHasBias
=
Traits
::
kHasBias
;
static
constexpr
bool
kStoreLSE
=
Traits
::
kStoreLSE
;
static
constexpr
bool
kDoFp8StaticQuant
=
Traits
::
kDoFp8StaticQuant
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaPipelineQRKSVSDefaultPolicy
>
struct
BlockFmhaPipelineQRKSVS
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kK0BlockLength
=
BlockFmhaShape
::
kK0BlockLength
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
bool
kHasBias
=
Problem
::
kHasBias
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static
constexpr
index_t
kAlignmentQ
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
return
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
else
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
kK0BlockLength
<=
32
)
{
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
64
)
{
return
3
;
}
else
if
constexpr
(
kK0BlockLength
<=
128
)
{
if
constexpr
(
kHasBias
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
256
)
{
return
1
;
}
}
}();
static
constexpr
const
char
*
name
=
"qr"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
QElementFunction
,
typename
KElementFunction
,
typename
VElementFunction
,
typename
BiasElementFunction
,
typename
LSEElementFunction
,
typename
SAccElementFunction
,
typename
PComputeElementFunction
,
typename
OAccElementFunction
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
KElementFunction
&
k_element_func
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasElementFunction
&
bias_element_func
,
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
// M0*1 tile
const
LSEElementFunction
&
lse_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
PComputeElementFunction
&
p_compute_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
FmhaMask
mask
,
float
scale_s
,
void
*
smem_ptr
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kN1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
// K tile in LDS
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()));
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
auto
k_lds_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
VDataType
*>
(
smem_ptr
),
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>());
auto
v_lds_window
=
make_tile_window
(
v_lds
,
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
q
=
load_tile
(
q_dram_window
);
using
SaccBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
auto
s_acc
=
SaccBlockTileType
{};
// reduction function for softmax
const
auto
f_max
=
[](
auto
e0
,
auto
e1
)
{
return
max
(
e0
,
e1
);
};
const
auto
f_sum
=
[](
auto
e0
,
auto
e1
)
{
return
e0
+
e1
;
};
// infer Sacc, S, P, M, L, Oacc type
using
SBlockTileType
=
decltype
(
cast_tile
<
SMPLComputeDataType
>
(
s_acc
));
using
MLBlockTileType
=
decltype
(
block_tile_reduce
<
SMPLComputeDataType
>
(
SBlockTileType
{},
sequence
<
1
>
{},
f_max
,
SMPLComputeDataType
{
0
}));
using
OaccBlockTileType
=
decltype
(
gemm_1
.
MakeCBlockTile
());
// init Oacc, M, L
auto
o_acc
=
OaccBlockTileType
{};
auto
m
=
MLBlockTileType
{};
auto
l
=
MLBlockTileType
{};
clear_tile
(
o_acc
);
set_tile
(
m
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
clear_tile
(
l
);
const
auto
q_origin
=
q_dram_window
.
get_window_origin
();
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
// check early exit if masked and no work to do.
if
constexpr
(
FmhaMask
::
IsMasking
)
{
if
(
num_total_loop
<=
0
)
{
if
constexpr
(
kStoreLSE
)
{
auto
lse
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
set_tile
(
lse
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse
));
}
// Note: here occ are all cleard, return it
// Note: q loaded but no fence, ignore it.
return
o_acc
;
}
}
auto
k_dram_block_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_k_start
,
0
});
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
auto
q_tile
=
tile_elementwise_in
(
q_element_func
,
q
);
// prefetch K tile
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kK0BlockLength
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
2
<=
k0_loops
);
static_assert
(
1
<=
k1_loops
);
do
{
// STAGE 1, QK gemm
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
auto
k_block_tile
=
load_tile
(
k_dram_window
);
{
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
clear_tile
(
s_acc
);
// initialize C
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
k_block_tile
=
load_tile
(
k_dram_window
);
}
if
constexpr
(
kHasBias
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
if
constexpr
(
kHasBias
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
if
constexpr
(
k0_loops
>
2
)
{
static_for
<
0
,
k0_loops
-
2
,
1
>
{}([
&
](
auto
i_k0
)
{
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
k_lds_window
);
block_sync_lds
();
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
// LDS write i + 1
k_block_tile
=
load_tile
(
k_dram_window
);
// global read i + 2
});
}
const
auto
v_prefetch
=
load_tile
(
v_dram_window
);
// prefetch load v tile
{
// tail
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
sequence
<
0
,
(
k0_loops
-
2
)
*
kK0
>
{},
sequence
<
kM0
,
(
k0_loops
-
1
)
*
kK0
>
{}),
k_lds_window
);
block_sync_lds
();
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
k_lds_window
);
}
// STAGE 2, scale_s, add bias, mask, softmax
if
constexpr
(
kHasBias
)
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x
+=
type_convert
<
SaccDataType
>
(
bias_element_func
(
y
));
#else
x
+=
log2e_v
<
SaccDataType
>
*
type_convert
<
SaccDataType
>
(
bias_element_func
(
y
));
#endif
},
s_acc
,
bias_tile
);
}
else
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
#endif
}
move_tile_window
(
bias_dram_window
,
{
0
,
kN0
});
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
}
const
auto
s
=
cast_tile
<
SMPLComputeDataType
>
(
s_acc
);
// S{j}
auto
m_local
=
block_tile_reduce
<
SMPLComputeDataType
>
(
s
,
sequence
<
1
>
{},
f_max
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
// m_local = rowmax(S{j})
block_tile_reduce_sync
(
m_local
,
f_max
,
bool_constant
<
false
>
{});
const
auto
m_old
=
m
;
// m{j-1}
tile_elementwise_inout
(
[](
auto
&
e0
,
auto
e1
,
auto
e2
)
{
e0
=
max
(
e1
,
e2
);
},
m
,
m_old
,
m_local
);
// m{j}
auto
p_compute
=
make_static_distributed_tensor
<
SMPLComputeDataType
>
(
s
.
get_tile_distribution
());
// Pcompute{j}
static
const
auto
get_validated_m
=
[](
SMPLComputeDataType
raw_m
)
{
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
if
constexpr
(
kHasBias
||
FmhaMask
::
IsMasking
)
{
return
raw_m
==
-
numeric
<
SMPLComputeDataType
>::
infinity
()
?
type_convert
<
SMPLComputeDataType
>
(
0.
f
)
:
raw_m
;
}
else
{
return
raw_m
;
}
};
constexpr
auto
p_spans
=
decltype
(
p_compute
)
::
get_distributed_spans
();
sweep_tile_span
(
p_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
#endif
sweep_tile_span
(
p_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
kHasBias
)
{
p_compute
(
i_j_idx
)
=
exp2
(
s
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
p_compute
(
i_j_idx
)
=
exp2
(
scale_s
*
s
[
i_j_idx
]
-
row_max
);
}
#else
p_compute
(
i_j_idx
)
=
exp
(
s
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
});
});
auto
rowsum_p
=
block_tile_reduce
<
SMPLComputeDataType
>
(
p_compute
,
sequence
<
1
>
{},
f_sum
,
SMPLComputeDataType
{
0
});
// rowsum(Pcompute{j})
block_tile_reduce_sync
(
rowsum_p
,
f_sum
,
bool_constant
<
false
>
{});
// l{j}, Oacc{j}
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
kHasBias
)
{
return
exp2
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
return
exp2
(
scale_s
*
m_old
[
i_idx
]
-
row_max
);
}
}();
#else
const
auto
tmp
=
exp
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
l
(
i_idx
)
=
tmp
*
l
[
i_idx
]
+
rowsum_p
[
i_idx
];
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
block_sync_lds
();
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_prefetch
);
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v_prefetch
));
// store the prefetch
}
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
const
auto
p
=
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
// STAGE 3, KV gemm
if
constexpr
(
k1_loops
>
1
)
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
const
auto
v
=
load_tile
(
v_dram_window
);
// load next v
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
i_k1
*
kK1
>
{},
sequence
<
kM0
,
(
i_k1
+
1
)
*
kK1
>
{}),
v_lds_window
);
block_sync_lds
();
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v
);
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v
));
// store next v
}
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
});
}
// move K tile windows
move_tile_window
(
k_dram_block_window
,
{
kN0
,
0
});
// tail
{
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
(
k1_loops
-
1
)
*
kK1
>
{},
sequence
<
kM0
,
kN0
>
{}),
v_lds_window
);
block_sync_lds
();
}
}
while
(
++
i_total_loops
<
num_total_loop
);
// store lse
if
constexpr
(
kStoreLSE
)
{
auto
lse
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
constexpr
auto
lse_spans
=
decltype
(
lse
)
::
get_distributed_spans
();
sweep_tile_span
(
lse_spans
[
number
<
0
>
{}],
[
&
,
m_
=
m
,
l_
=
l
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
kHasBias
)
{
lse
(
i_idx
)
=
m_
[
i_idx
]
/
C_LOG2E
+
log
(
l_
[
i_idx
]);
}
else
{
lse
(
i_idx
)
=
m_
[
i_idx
]
*
scale_s
/
C_LOG2E
+
log
(
l_
[
i_idx
]);
}
#else
lse
(
i_idx
)
=
m_
[
i_idx
]
+
log
(
l_
[
i_idx
]);
#endif
});
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse
));
}
// finally, O
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
FmhaMask
::
IsMasking
)
{
return
l
[
i_idx
]
==
0.
f
?
0.
f
:
1
/
l
[
i_idx
];
}
else
return
1
/
l
[
i_idx
];
}();
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
return
o_acc
;
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
// M0*1 tile
FmhaMask
mask
,
float
scale_s
,
void
*
smem_ptr
)
const
{
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
k_dram_block_window_tmp
,
identity
{},
v_dram_block_window_tmp
,
identity
{},
bias_dram_block_window_tmp
,
identity
{},
lse_dram_block_window_tmp
,
identity
{},
identity
{},
identity
{},
identity
{},
mask
,
scale_s
,
smem_ptr
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
>
struct
BlockFmhaPipelineQRKSVSAsync
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kK0BlockLength
=
BlockFmhaShape
::
kK0BlockLength
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
static_assert
(
Problem
::
kPadSeqLenQ
==
true
&&
Problem
::
kPadHeadDimQ
==
true
&&
Problem
::
kPadHeadDimV
==
true
);
static
constexpr
bool
kPadSeqLenQ
=
true
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
true
;
// support multiple of vector(like 8x)
static
constexpr
bool
kPadHeadDimV
=
true
;
// support multiple of vector(like 8x)
static
constexpr
bool
kHasBias
=
Problem
::
kHasBias
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static
constexpr
index_t
kAlignmentQ
=
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
return
Policy
::
template
GetAlignmentV
<
Problem
>();
else
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kAlignmentO
=
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
#if CK_TILE_FMHA_FWD_FAST_EXP2
static
constexpr
auto
R_LOG2E
=
1.0
/
log2e_v
<
SaccDataType
>
;
#endif
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
kK0BlockLength
<=
32
)
{
if
constexpr
(
kPadSeqLenK
&&
kHasBias
&&
FmhaMask
::
IsMasking
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
64
)
{
if
constexpr
(
kPadSeqLenK
&&
kHasBias
)
return
2
;
else
return
3
;
}
else
if
constexpr
(
kK0BlockLength
<=
128
)
{
if
constexpr
(
kPadSeqLenK
&&
kHasBias
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
256
)
{
return
1
;
}
}
}();
static
constexpr
const
char
*
name
=
"qr_async"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
QElementFunction
,
typename
KElementFunction
,
typename
VElementFunction
,
typename
BiasElementFunction
,
typename
LSEElementFunction
,
typename
SAccElementFunction
,
typename
PComputeElementFunction
,
typename
OAccElementFunction
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
KElementFunction
&
/*k_element_func*/
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasElementFunction
&
bias_element_func
,
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
// M0*1 tile
const
LSEElementFunction
&
lse_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
PComputeElementFunction
&
p_compute_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
FmhaMask
mask
,
float
scale_s
,
void
*
smem_ptr
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kN1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
constexpr
auto
LdsSeq
=
Policy
::
template
GetLdsBufferSequence
<
Problem
>();
// K tile in LDS
auto
k_lds_ptr
=
reinterpret_cast
<
KDataType
*>
(
smem_ptr
);
auto
k_lds_store
=
generate_tuple
(
[
&
](
auto
i_buf
)
{
return
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsStoreBlockDescriptor
<
Problem
>(
i_buf
)),
Policy
::
template
MakeKLdsStoreBlockDescriptor
<
Problem
>(
i_buf
).
get_lengths
(),
{
0
,
0
,
0
});
},
number
<
Policy
::
NumPrefetchK
>
{});
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
auto
k_lds_load
=
generate_tuple
(
[
&
](
auto
i_buf
)
{
return
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>(
i_buf
)),
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>(
i_buf
).
get_lengths
(),
{
0
,
0
});
},
number
<
Policy
::
NumPrefetchK
>
{});
#else
auto
k_lds_Load_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>());
auto
k_lds_load
=
make_tile_window
(
k_lds_Load_view
,
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
#endif
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
VDataType
*>
(
smem_ptr
),
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>());
auto
v_lds_window
=
make_tile_window
(
v_lds
,
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
auto
q
=
decltype
(
load_tile
(
q_dram_window
)){};
set_tile
(
q
,
number
<
0
>
{});
// use per-dword clear to avoid scratch
load_tile_raw
(
q
,
q_dram_window
);
__builtin_amdgcn_sched_barrier
(
0
);
using
SaccBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
auto
s_acc
=
SaccBlockTileType
{};
// reduction function for softmax
const
auto
f_max
=
[](
auto
e0
,
auto
e1
)
{
return
max
(
e0
,
e1
);
};
const
auto
f_sum
=
[](
auto
e0
,
auto
e1
)
{
return
e0
+
e1
;
};
// infer Sacc, S, P, M, L, Oacc type
using
SBlockTileType
=
decltype
(
cast_tile
<
SMPLComputeDataType
>
(
s_acc
));
using
MLBlockTileType
=
decltype
(
block_tile_reduce
<
SMPLComputeDataType
>
(
SBlockTileType
{},
sequence
<
1
>
{},
f_max
,
SMPLComputeDataType
{
0
}));
using
OaccBlockTileType
=
decltype
(
gemm_1
.
MakeCBlockTile
());
// init Oacc, M, L
auto
o_acc
=
OaccBlockTileType
{};
auto
m
=
MLBlockTileType
{};
auto
l
=
MLBlockTileType
{};
clear_tile
(
o_acc
);
set_tile
(
m
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
clear_tile
(
l
);
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
q_origin
=
q_dram_window
.
get_window_origin
();
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
// check early exit if masked and no work to do.
if
constexpr
(
FmhaMask
::
IsMasking
)
{
if
(
num_total_loop
<=
0
)
{
if
constexpr
(
kStoreLSE
)
{
auto
lse
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
set_tile
(
lse
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse
));
}
buffer_load_fence
(
0
);
// rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleard, return it
return
o_acc
;
}
__builtin_amdgcn_sched_barrier
(
0
);
// make sure sched_barrier(0) for this check
}
auto
k_dram_block_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_k_start
,
0
});
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
// prefetch K tile
async_load_tile_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})),
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_fence
(
k_dram_window
.
get_num_access
(),
q
.
get_thread_buffer
());
(
void
)
q_element_func
;
// ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kK0BlockLength
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
1
<=
k0_loops
);
static_assert
(
1
<=
k1_loops
);
// main loop
do
{
// STAGE 1, QK gemm
clear_tile
(
s_acc
);
// initialize C
if
constexpr
(
k0_loops
>
1
)
{
static_for
<
0
,
k0_loops
-
1
,
1
>
{}([
&
](
auto
i_k0
)
{
async_load_tile_raw
(
k_lds_store
(
number
<
LdsSeq
.
at
(
number
<
i_k0
+
1
>
{})
>
{}),
k_dram_window
);
if
constexpr
(
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
async_load_fence
(
k_dram_window
.
get_num_access
());
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
gemm_0
(
s_acc
,
get_slice_tile
(
q
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load
[
number
<
LdsSeq
.
at
(
number
<
i_k0
>
{})
>
{}]);
#else
get_slice_tile
(
k_lds_load
,
sequence
<
(
LdsSeq
.
at
(
number
<
i_k0
>
{}))
*
kN0
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
i_k0
>
{})
+
1
)
*
kN0
,
kK0
>
{}));
#endif
});
}
// TODO: this to fix a bug when loop smaller than 2,
// the following fence/barrier will be scheduled inside 1st loop
if
constexpr
(
k0_loops
<=
2
)
__builtin_amdgcn_sched_barrier
(
0
);
async_load_fence
();
__builtin_amdgcn_s_barrier
();
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
auto
v_buf
=
load_tile
(
v_dram_window
,
bool_constant
<
false
>
{});
__builtin_amdgcn_sched_barrier
(
0
);
{
// tail
gemm_0
(
s_acc
,
get_slice_tile
(
q
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load
[
number
<
LdsSeq
.
at
(
number
<
k0_loops
-
1
>
{})
>
{}]);
#else
get_slice_tile
(
k_lds_load
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
-
1
>
{}))
*
kN0
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
-
1
>
{})
+
1
)
*
kN0
,
kK0
>
{}));
#endif
}
__builtin_amdgcn_sched_barrier
(
1
);
// STAGE 2, scale_s, add bias, mask, softmax
if
constexpr
(
kHasBias
)
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x
+=
type_convert
<
SaccDataType
>
(
bias_element_func
(
y
));
#else
x
+=
log2e_v
<
SaccDataType
>
*
type_convert
<
SaccDataType
>
(
bias_element_func
(
y
));
#endif
},
s_acc
,
bias_tile
);
}
else
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
#endif
}
move_tile_window
(
bias_dram_window
,
{
0
,
kN0
});
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
}
const
auto
s
=
cast_tile
<
SMPLComputeDataType
>
(
s_acc
);
// S{j}
auto
m_local
=
block_tile_reduce
<
SMPLComputeDataType
>
(
s
,
sequence
<
1
>
{},
f_max
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
// m_local = rowmax(S{j})
block_tile_reduce_sync
(
m_local
,
f_max
,
bool_constant
<
false
>
{});
const
auto
m_old
=
m
;
// m{j-1}
tile_elementwise_inout
(
[](
auto
&
e0
,
auto
e1
,
auto
e2
)
{
e0
=
max
(
e1
,
e2
);
},
m
,
m_old
,
m_local
);
// m{j}
auto
p_compute
=
make_static_distributed_tensor
<
SMPLComputeDataType
>
(
s
.
get_tile_distribution
());
// Pcompute{j}
__builtin_amdgcn_sched_barrier
(
0x7F
);
// store & prefetch next v, after the max reduction
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_buf
);
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
>
{})
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
>
{})
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_buf
));
// store the prefetch
}
if
constexpr
(
k1_loops
>
1
)
{
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
// will have scratch if move this right after load_tile(v_dram)...
v_buf
=
load_tile
(
v_dram_window
,
bool_constant
<
false
>
{});
// load next v_buf
}
__builtin_amdgcn_sched_barrier
(
0
);
static
const
auto
get_validated_m
=
[](
SMPLComputeDataType
raw_m
)
{
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
if
constexpr
(
kHasBias
||
FmhaMask
::
IsMasking
)
{
return
raw_m
==
-
numeric
<
SMPLComputeDataType
>::
infinity
()
?
type_convert
<
SMPLComputeDataType
>
(
0.
f
)
:
raw_m
;
}
else
{
return
raw_m
;
}
};
constexpr
auto
p_spans
=
decltype
(
p_compute
)
::
get_distributed_spans
();
sweep_tile_span
(
p_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
#endif
sweep_tile_span
(
p_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
kHasBias
)
{
p_compute
(
i_j_idx
)
=
exp2
(
s
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
p_compute
(
i_j_idx
)
=
exp2
(
scale_s
*
s
[
i_j_idx
]
-
row_max
);
}
#else
p_compute
(
i_j_idx
)
=
exp
(
s
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
});
});
auto
rowsum_p
=
block_tile_reduce
<
SMPLComputeDataType
>
(
p_compute
,
sequence
<
1
>
{},
f_sum
,
SMPLComputeDataType
{
0
});
// rowsum(Pcompute{j})
block_tile_reduce_sync
(
rowsum_p
,
f_sum
,
bool_constant
<
false
>
{});
// l{j}, Oacc{j}
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
kHasBias
)
{
return
exp2
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
return
exp2
(
scale_s
*
m_old
[
i_idx
]
-
row_max
);
}
}();
#else
const
auto
tmp
=
exp
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
l
(
i_idx
)
=
tmp
*
l
[
i_idx
]
+
rowsum_p
[
i_idx
];
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
const
auto
p
=
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
// STAGE 3, KV gemm
if
constexpr
(
k1_loops
>
1
)
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
if
constexpr
(
i_k1
!=
0
&&
i_k1
<
k1_loops
-
1
)
{
v_buf
=
load_tile
(
v_dram_window
,
bool_constant
<
false
>
{});
// load next v_buf
}
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
i_k1
*
kK1
>
{},
sequence
<
kM0
,
(
i_k1
+
1
)
*
kK1
>
{}),
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
>
{})
+
1
)
*
kN1
,
kK1
>
{}));
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_buf
);
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
+
1
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
+
1
>
{})
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
+
1
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
+
1
>
{})
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_buf
));
// store next v_buf
}
if
constexpr
(
i_k1
<
k1_loops
-
1
)
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
});
}
i_total_loops
++
;
if
(
i_total_loops
<
num_total_loop
)
{
// move K tile windows
move_tile_window
(
k_dram_block_window
,
{
kN0
,
0
});
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
if
constexpr
(
k1_loops
>=
2
&&
LdsSeq
.
at
(
number
<
0
>
{})
==
LdsSeq
.
at
(
number
<
k0_loops
+
k1_loops
-
2
>
{}))
__builtin_amdgcn_s_barrier
();
async_load_tile_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})),
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
}
// tail
{
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
(
k1_loops
-
1
)
*
kK1
>
{},
sequence
<
kM0
,
kN0
>
{}),
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
k1_loops
-
1
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
k1_loops
-
1
>
{})
+
1
)
*
kN1
,
kK1
>
{}));
}
}
while
(
i_total_loops
<
num_total_loop
);
// store lse
if
constexpr
(
kStoreLSE
)
{
auto
lse
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
constexpr
auto
lse_spans
=
decltype
(
lse
)
::
get_distributed_spans
();
sweep_tile_span
(
lse_spans
[
number
<
0
>
{}],
[
&
,
m_
=
m
,
l_
=
l
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
kHasBias
)
{
lse
(
i_idx
)
=
m_
[
i_idx
]
*
R_LOG2E
+
log
(
l_
[
i_idx
]);
}
else
{
lse
(
i_idx
)
=
m_
[
i_idx
]
*
scale_s
*
R_LOG2E
+
log
(
l_
[
i_idx
]);
}
#else
lse
(
i_idx
)
=
m_
[
i_idx
]
+
log
(
l_
[
i_idx
]);
#endif
});
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse
));
}
// finally, O
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
FmhaMask
::
IsMasking
)
{
return
l
[
i_idx
]
==
0.
f
?
0.
f
:
1
/
l
[
i_idx
];
}
else
return
1
/
l
[
i_idx
];
}();
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
return
o_acc
;
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
// M0*1 tile
FmhaMask
mask
,
float
scale_s
,
void
*
smem_ptr
)
const
{
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
k_dram_block_window_tmp
,
identity
{},
v_dram_block_window_tmp
,
identity
{},
bias_dram_block_window_tmp
,
identity
{},
lse_dram_block_window_tmp
,
identity
{},
identity
{},
identity
{},
identity
{},
mask
,
scale_s
,
smem_ptr
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
using
BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
=
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopyK = */
true
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
3
,
/* NumPrefetchV = */
3
>
;
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
using
BlockFmhaPipelineQRKSVSDefaultPolicy
=
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopyK = */
false
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
;
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaPipelineQRKSVSDefaultPolicy
>
struct
[[
deprecated
]]
BlockFmhaPipelineQRKSVSFp8
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kK0BlockLength
=
BlockFmhaShape
::
kK0BlockLength
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
bool
kHasBias
=
Problem
::
kHasBias
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static
constexpr
index_t
kAlignmentQ
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
return
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
else
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
kK0BlockLength
<=
32
)
{
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
64
)
{
return
3
;
}
else
if
constexpr
(
kK0BlockLength
<=
128
)
{
if
constexpr
(
kHasBias
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
256
)
{
return
1
;
}
}
}();
static
constexpr
const
char
*
name
=
"qr_fp8"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
LSEDramBlockWindowTmp
&
/*lse_dram_window_tmp*/
,
// not supported
FmhaMask
mask
,
float
scale_s
,
float
descale_qk
,
float
descale_sv
,
void
*
smem_ptr
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kN1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
// K tile in LDS
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()));
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
auto
k_lds_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
VDataType
*>
(
smem_ptr
),
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>());
auto
v_lds_window
=
make_tile_window
(
v_lds
,
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
q
=
load_tile
(
q_dram_window
);
using
SaccBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
auto
s_acc
=
SaccBlockTileType
{};
// reduction function for softmax
const
auto
f_max
=
[](
auto
e0
,
auto
e1
)
{
return
max
(
e0
,
e1
);
};
const
auto
f_sum
=
[](
auto
e0
,
auto
e1
)
{
return
e0
+
e1
;
};
// infer Sacc, S, P, M, L, Oacc type
using
SBlockTileType
=
decltype
(
cast_tile
<
SMPLComputeDataType
>
(
s_acc
));
using
MLBlockTileType
=
decltype
(
block_tile_reduce
<
SMPLComputeDataType
>
(
SBlockTileType
{},
sequence
<
1
>
{},
f_max
,
SMPLComputeDataType
{
0
}));
using
OaccBlockTileType
=
decltype
(
gemm_1
.
MakeCBlockTile
());
// init Oacc, M, L
auto
o_acc
=
OaccBlockTileType
{};
auto
m
=
MLBlockTileType
{};
auto
l
=
MLBlockTileType
{};
clear_tile
(
o_acc
);
set_tile
(
m
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
clear_tile
(
l
);
const
auto
q_origin
=
q_dram_window
.
get_window_origin
();
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
// check early exit if masked and no work to do.
if
constexpr
(
FmhaMask
::
IsMasking
)
{
if
(
num_total_loop
<=
0
)
{
// Note: here occ are all cleard, return it
// Note: q loaded but no fence, ignore it.
return
o_acc
;
}
}
auto
k_dram_block_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_k_start
,
0
});
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
// auto q_tile = tile_elementwise_in(q_element_func, q);
auto
q_tile
=
q
;
// prefetch K tile
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kK0BlockLength
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
2
<=
k0_loops
);
static_assert
(
1
<=
k1_loops
);
scale_s
=
scale_s
*
descale_qk
;
do
{
// STAGE 1, QK gemm
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
auto
k_block_tile
=
load_tile
(
k_dram_window
);
{
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
clear_tile
(
s_acc
);
// initialize C
store_tile
(
k_lds_window
,
k_block_tile
);
k_block_tile
=
load_tile
(
k_dram_window
);
}
if
constexpr
(
kHasBias
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
if
constexpr
(
kHasBias
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
if
constexpr
(
k0_loops
>
2
)
{
static_for
<
0
,
k0_loops
-
2
,
1
>
{}([
&
](
auto
i_k0
)
{
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
k_lds_window
);
block_sync_lds
();
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
store_tile
(
k_lds_window
,
k_block_tile
);
// LDS write i + 1
k_block_tile
=
load_tile
(
k_dram_window
);
// global read i + 2
});
}
const
auto
v_prefetch
=
load_tile
(
v_dram_window
);
// prefetch load v tile
{
// tail
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
sequence
<
0
,
(
k0_loops
-
2
)
*
kK0
>
{},
sequence
<
kM0
,
(
k0_loops
-
1
)
*
kK0
>
{}),
k_lds_window
);
block_sync_lds
();
store_tile
(
k_lds_window
,
k_block_tile
);
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
k_lds_window
);
}
// STAGE 2, scale_s, add bias, mask, softmax
if
constexpr
(
kHasBias
)
{
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x
=
scale_s
*
x
+
type_convert
<
SaccDataType
>
((
y
));
#else
x
=
scale_s
*
x
+
log2e_v
<
SaccDataType
>
*
type_convert
<
SaccDataType
>
((
y
));
#endif
},
s_acc
,
bias_tile
);
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
#endif
}
move_tile_window
(
bias_dram_window
,
{
0
,
kN0
});
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
}
const
auto
s
=
cast_tile
<
SMPLComputeDataType
>
(
s_acc
);
// S{j}
auto
m_local
=
block_tile_reduce
<
SMPLComputeDataType
>
(
s
,
sequence
<
1
>
{},
f_max
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
// m_local = rowmax(S{j})
block_tile_reduce_sync
(
m_local
,
f_max
,
bool_constant
<
false
>
{});
const
auto
m_old
=
m
;
// m{j-1}
tile_elementwise_inout
(
[](
auto
&
e0
,
auto
e1
,
auto
e2
)
{
e0
=
max
(
e1
,
e2
);
},
m
,
m_old
,
m_local
);
// m{j}
auto
p_compute
=
make_static_distributed_tensor
<
SMPLComputeDataType
>
(
s
.
get_tile_distribution
());
// Pcompute{j}
static
const
auto
get_validated_m
=
[](
SMPLComputeDataType
raw_m
)
{
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
if
constexpr
(
kHasBias
||
FmhaMask
::
IsMasking
)
{
return
raw_m
==
-
numeric
<
SMPLComputeDataType
>::
infinity
()
?
type_convert
<
SMPLComputeDataType
>
(
0.
f
)
:
raw_m
;
}
else
{
return
raw_m
;
}
};
constexpr
auto
p_spans
=
decltype
(
p_compute
)
::
get_distributed_spans
();
sweep_tile_span
(
p_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
#endif
sweep_tile_span
(
p_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
kHasBias
)
{
p_compute
(
i_j_idx
)
=
exp2
(
s
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
p_compute
(
i_j_idx
)
=
exp2
(
scale_s
*
s
[
i_j_idx
]
-
row_max
);
}
#else
p_compute
(
i_j_idx
)
=
exp
(
s
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
});
});
auto
rowsum_p
=
block_tile_reduce
<
SMPLComputeDataType
>
(
p_compute
,
sequence
<
1
>
{},
f_sum
,
SMPLComputeDataType
{
0
});
// rowsum(Pcompute{j})
block_tile_reduce_sync
(
rowsum_p
,
f_sum
,
bool_constant
<
false
>
{});
// l{j}, Oacc{j}
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
kHasBias
)
{
return
exp2
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
return
exp2
(
scale_s
*
m_old
[
i_idx
]
-
row_max
);
}
}();
#else
const
auto
tmp
=
exp
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
l
(
i_idx
)
=
tmp
*
l
[
i_idx
]
+
rowsum_p
[
i_idx
];
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
block_sync_lds
();
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_prefetch
);
store_tile
(
v_lds_window
,
v_shuffle_tmp
);
// store the prefetch
}
else
{
store_tile
(
v_lds_window
,
v_prefetch
);
// store the prefetch
}
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
const
auto
p
=
cast_tile
<
PDataType
>
(
p_compute
);
// STAGE 3, KV gemm
if
constexpr
(
k1_loops
>
1
)
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
const
auto
v
=
load_tile
(
v_dram_window
);
// load next v
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
i_k1
*
kK1
>
{},
sequence
<
kM0
,
(
i_k1
+
1
)
*
kK1
>
{}),
v_lds_window
);
block_sync_lds
();
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v
);
store_tile
(
v_lds_window
,
v_shuffle_tmp
);
}
else
{
store_tile
(
v_lds_window
,
v
);
}
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
});
}
// move K tile windows
move_tile_window
(
k_dram_block_window
,
{
kN0
,
0
});
// tail
{
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
(
k1_loops
-
1
)
*
kK1
>
{},
sequence
<
kM0
,
kN0
>
{}),
v_lds_window
);
block_sync_lds
();
}
}
while
(
++
i_total_loops
<
num_total_loop
);
// finally, O
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
auto
tmp
=
[
&
]()
{
if
constexpr
(
FmhaMask
::
IsMasking
)
{
return
l
[
i_idx
]
==
0.
f
?
0.
f
:
1
/
l
[
i_idx
];
}
else
return
1
/
l
[
i_idx
];
}();
tmp
=
tmp
*
descale_sv
;
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
return
o_acc
;
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaPipelineQSKSVSDefaultPolicy
>
struct
BlockFmhaPipelineQSKSVS
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
static
constexpr
bool
kQLoadOnce
=
false
;
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kK0BlockLength
=
BlockFmhaShape
::
kK0BlockLength
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
bool
kHasBias
=
Problem
::
kHasBias
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
kK0BlockLength
<=
32
)
{
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
64
)
{
return
3
;
}
else
if
constexpr
(
kK0BlockLength
<=
128
)
{
if
constexpr
(
kHasBias
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
256
)
{
return
1
;
}
}
}();
static
constexpr
const
char
*
name
=
"qs"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeQ
()
{
return
Policy
::
template
GetSmemSizeQ
<
Problem
>();
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
QElementFunction
,
typename
KElementFunction
,
typename
VElementFunction
,
typename
BiasElementFunction
,
typename
LSEElementFunction
,
typename
SAccElementFunction
,
typename
PComputeElementFunction
,
typename
OAccElementFunction
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
KElementFunction
&
k_element_func
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasElementFunction
&
bias_element_func
,
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
// M0*1 tile
const
LSEElementFunction
&
lse_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
PComputeElementFunction
&
p_compute_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
FmhaMask
mask
,
float
scale_s
,
void
*
smem_ptr
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kN1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
// Q tile in LDS
auto
q_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
QDataType
*>
(
smem_ptr
),
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
// K tile in LDS
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()));
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
auto
k_lds_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
VDataType
*>
(
smem_ptr
),
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>());
auto
v_lds_window
=
make_tile_window
(
v_lds
,
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
using
SaccBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
auto
s_acc
=
SaccBlockTileType
{};
// reduction function for softmax
const
auto
f_max
=
[](
auto
e0
,
auto
e1
)
{
return
max
(
e0
,
e1
);
};
const
auto
f_sum
=
[](
auto
e0
,
auto
e1
)
{
return
e0
+
e1
;
};
// infer Sacc, S, P, M, L, Oacc type
using
SBlockTileType
=
decltype
(
cast_tile
<
SMPLComputeDataType
>
(
s_acc
));
using
MLBlockTileType
=
decltype
(
block_tile_reduce
<
SMPLComputeDataType
>
(
SBlockTileType
{},
sequence
<
1
>
{},
f_max
,
SMPLComputeDataType
{
0
}));
using
OaccBlockTileType
=
decltype
(
gemm_1
.
MakeCBlockTile
());
// init Oacc, M, L
auto
o_acc
=
OaccBlockTileType
{};
auto
m
=
MLBlockTileType
{};
auto
l
=
MLBlockTileType
{};
clear_tile
(
o_acc
);
set_tile
(
m
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
clear_tile
(
l
);
const
auto
q_origin
=
q_dram_block_window_tmp
.
get_window_origin
();
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
// check early exit if masked and no work to do.
if
constexpr
(
FmhaMask
::
IsMasking
)
{
if
(
num_total_loop
<=
0
)
{
if
constexpr
(
kStoreLSE
)
{
auto
lse
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
set_tile
(
lse
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse
));
}
// Note: here occ are all cleard, return it
// Note: q loaded but no fence, ignore it.
return
o_acc
;
}
}
auto
k_dram_block_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_k_start
,
0
});
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
// prefetch K tile
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kK0BlockLength
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
2
<=
k0_loops
);
static_assert
(
1
<=
k1_loops
);
do
{
// STAGE 1, QK gemm
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
auto
q_block_tile
=
load_tile
(
q_dram_window
);
auto
k_block_tile
=
load_tile
(
k_dram_window
);
{
move_tile_window
(
q_dram_window
,
{
0
,
kK0
});
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
clear_tile
(
s_acc
);
// initialize C
store_tile
(
q_lds_window
,
tile_elementwise_in
(
q_element_func
,
q_block_tile
));
q_block_tile
=
load_tile
(
q_dram_window
);
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
k_block_tile
=
load_tile
(
k_dram_window
);
}
if
constexpr
(
kHasBias
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
if
constexpr
(
kHasBias
)
{
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
}
if
constexpr
(
k0_loops
>
2
)
{
static_for
<
0
,
k0_loops
-
2
,
1
>
{}([
&
](
auto
)
{
block_sync_lds
();
gemm_0
(
s_acc
,
q_lds_window
,
k_lds_window
);
block_sync_lds
();
move_tile_window
(
q_dram_window
,
{
0
,
kK0
});
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
store_tile
(
q_lds_window
,
tile_elementwise_in
(
q_element_func
,
q_block_tile
));
// LDS write i + 1
q_block_tile
=
load_tile
(
q_dram_window
);
// global read i + 2
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
// LDS write i + 1
k_block_tile
=
load_tile
(
k_dram_window
);
// global read i + 2
});
}
const
auto
v_prefetch
=
load_tile
(
v_dram_window
);
// prefetch load v tile
{
// tail
block_sync_lds
();
gemm_0
(
s_acc
,
q_lds_window
,
k_lds_window
);
block_sync_lds
();
store_tile
(
q_lds_window
,
tile_elementwise_in
(
q_element_func
,
q_block_tile
));
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
block_sync_lds
();
gemm_0
(
s_acc
,
q_lds_window
,
k_lds_window
);
}
// STAGE 2, scale_s, add bias, mask, softmax
if
constexpr
(
kHasBias
)
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x
+=
type_convert
<
SaccDataType
>
(
bias_element_func
(
y
));
#else
x
+=
log2e_v
<
SaccDataType
>
*
type_convert
<
SaccDataType
>
(
bias_element_func
(
y
));
#endif
},
s_acc
,
bias_tile
);
}
else
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
#endif
}
move_tile_window
(
bias_dram_window
,
{
0
,
kN0
});
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
}
const
auto
s
=
cast_tile
<
SMPLComputeDataType
>
(
s_acc
);
// S{j}
auto
m_local
=
block_tile_reduce
<
SMPLComputeDataType
>
(
s
,
sequence
<
1
>
{},
f_max
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
// m_local = rowmax(S{j})
block_tile_reduce_sync
(
m_local
,
f_max
,
bool_constant
<
false
>
{});
const
auto
m_old
=
m
;
// m{j-1}
tile_elementwise_inout
(
[](
auto
&
e0
,
auto
e1
,
auto
e2
)
{
e0
=
max
(
e1
,
e2
);
},
m
,
m_old
,
m_local
);
// m{j}
auto
p_compute
=
make_static_distributed_tensor
<
SMPLComputeDataType
>
(
s
.
get_tile_distribution
());
// Pcompute{j}
static
const
auto
get_validated_m
=
[](
SMPLComputeDataType
raw_m
)
{
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
if
constexpr
(
kHasBias
||
FmhaMask
::
IsMasking
)
{
return
raw_m
==
-
numeric
<
SMPLComputeDataType
>::
infinity
()
?
type_convert
<
SMPLComputeDataType
>
(
0.
f
)
:
raw_m
;
}
else
{
return
raw_m
;
}
};
constexpr
auto
p_spans
=
decltype
(
p_compute
)
::
get_distributed_spans
();
sweep_tile_span
(
p_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
#endif
sweep_tile_span
(
p_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
kHasBias
)
{
p_compute
(
i_j_idx
)
=
exp2
(
s
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
p_compute
(
i_j_idx
)
=
exp2
(
scale_s
*
s
[
i_j_idx
]
-
row_max
);
}
#else
p_compute
(
i_j_idx
)
=
exp
(
s
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
});
});
auto
rowsum_p
=
block_tile_reduce
<
SMPLComputeDataType
>
(
p_compute
,
sequence
<
1
>
{},
f_sum
,
SMPLComputeDataType
{
0
});
// rowsum(Pcompute{j})
block_tile_reduce_sync
(
rowsum_p
,
f_sum
,
bool_constant
<
false
>
{});
// l{j}, Oacc{j}
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
kHasBias
)
{
return
exp2
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
return
exp2
(
scale_s
*
m_old
[
i_idx
]
-
row_max
);
}
}();
#else
const
auto
tmp
=
exp
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
l
(
i_idx
)
=
tmp
*
l
[
i_idx
]
+
rowsum_p
[
i_idx
];
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
block_sync_lds
();
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_prefetch
);
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v_prefetch
));
// store the prefetch
}
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
const
auto
p
=
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
// STAGE 3, KV gemm
if
constexpr
(
k1_loops
>
1
)
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
const
auto
v
=
load_tile
(
v_dram_window
);
// load next v
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
i_k1
*
kK1
>
{},
sequence
<
kM0
,
(
i_k1
+
1
)
*
kK1
>
{}),
v_lds_window
);
block_sync_lds
();
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v
);
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v
));
// store next v
}
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
});
}
// move K tile windows
move_tile_window
(
k_dram_block_window
,
{
kN0
,
0
});
// tail
{
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
(
k1_loops
-
1
)
*
kK1
>
{},
sequence
<
kM0
,
kN0
>
{}),
v_lds_window
);
block_sync_lds
();
}
}
while
(
++
i_total_loops
<
num_total_loop
);
// store lse
if
constexpr
(
kStoreLSE
)
{
auto
lse
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
constexpr
auto
lse_spans
=
decltype
(
lse
)
::
get_distributed_spans
();
sweep_tile_span
(
lse_spans
[
number
<
0
>
{}],
[
&
,
m_
=
m
,
l_
=
l
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
kHasBias
)
{
lse
(
i_idx
)
=
m_
[
i_idx
]
/
C_LOG2E
+
log
(
l_
[
i_idx
]);
}
else
{
lse
(
i_idx
)
=
m_
[
i_idx
]
*
scale_s
/
C_LOG2E
+
log
(
l_
[
i_idx
]);
}
#else
lse
(
i_idx
)
=
m_
[
i_idx
]
+
log
(
l_
[
i_idx
]);
#endif
});
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse
));
}
// finally, O
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
FmhaMask
::
IsMasking
)
{
return
l
[
i_idx
]
==
0.
f
?
0.
f
:
1
/
l
[
i_idx
];
}
else
return
1
/
l
[
i_idx
];
}();
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
return
o_acc
;
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
// M0*1 tile
FmhaMask
mask
,
float
scale_s
,
void
*
smem_ptr
)
const
{
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
k_dram_block_window_tmp
,
identity
{},
v_dram_block_window_tmp
,
identity
{},
bias_dram_block_window_tmp
,
identity
{},
lse_dram_block_window_tmp
,
identity
{},
identity
{},
identity
{},
identity
{},
mask
,
scale_s
,
smem_ptr
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
using
BlockFmhaPipelineQSKSVSDefaultPolicy
=
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
false
,
/* AsyncCopyK = */
false
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
;
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
// TODO: remove this
#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0
namespace
ck_tile
{
template
<
bool
QLoadOnce_
>
struct
BlockFmhaPipelineQXCustomPolicy
;
template
<
>
struct
BlockFmhaPipelineQXCustomPolicy
<
/* QLoadOnce = */
true
>
{
static
constexpr
bool
QLoadOnce
=
true
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeQ
()
{
return
0
;
}
// TODO: GetAlignment*() currently didn't consider if need padding or not
// so in pipeline still need check padding requirement
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentQ
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
return
WG
::
kK
/
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
}
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQDramTileDistribution
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0BlockLength
;
constexpr
index_t
K2
=
WG
::
kK
/
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
constexpr
index_t
K1
=
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
constexpr
index_t
K0
=
kKPerBlock
/
(
K1
*
K2
);
constexpr
index_t
M2
=
WG
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
constexpr
index_t
M1
=
MWarp
;
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
,
K2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
0
,
0
,
2
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>>
;
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
return
WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
return
WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
// TODO: hard coded here. Otherwise, it may incorrect result
constexpr
index_t
swizzle_factor
=
4
;
return
WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution
<
swizzle_factor
>
{};
}
// TODO - bf8_t
}();
using
BlockGemmPolicy
=
BlockGemmARegBSmemCRegV2CustomPolicy
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmARegBSmemCRegV2
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
};
template
<
>
struct
BlockFmhaPipelineQXCustomPolicy
<
/* QLoadOnce = */
false
>
{
static
constexpr
bool
QLoadOnce
=
false
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeQ
()
{
constexpr
index_t
lds_alignment
=
16
;
// optional
constexpr
index_t
q_smem_size
=
ck_tile
::
integer_divide_ceil
(
sizeof
(
typename
Problem
::
QDataType
)
*
MakeQLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
(),
lds_alignment
)
*
lds_alignment
;
return
q_smem_size
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentQ
()
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
return
16
/
sizeof
(
QDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQDramTileDistribution
()
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
K1
=
16
/
sizeof
(
QDataType
);
// use dwordx4. TODO: change this
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
// 3d + padding
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQLdsBlockDescriptor
()
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
kKPack
=
16
/
sizeof
(
QDataType
);
constexpr
auto
q_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kMPerBlock
>
{},
number
<
kKPack
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
kKPack
>
{},
number
<
kKPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
q_lds_block_desc
=
transform_tensor_descriptor
(
q_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
kMPerBlock
),
make_merge_transform
(
make_tuple
(
kKPerBlock
/
kKPack
,
kKPack
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
q_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>>
;
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
return
WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
return
WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
// TODO: hard coded here. Otherwise, it may incorrect result
constexpr
index_t
swizzle_factor
=
4
;
return
WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution
<
swizzle_factor
>
{};
}
// TODO - bf8_t
}();
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmASmemBSmemCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
};
// This pipeline is qkv all located in LDS
template
<
bool
QLoadOnce_
,
bool
AsyncCopyK_
,
bool
AsyncCopyV_
,
index_t
NumPrefetchK_
,
index_t
NumPrefetchV_
>
struct
BlockFmhaPipelineQXKSVSCustomPolicy
:
BlockFmhaPipelineQXCustomPolicy
<
QLoadOnce_
>
{
static
constexpr
bool
AsyncCopyK
=
AsyncCopyK_
;
static
constexpr
bool
AsyncCopyV
=
AsyncCopyV_
;
// TODO: this not supported yet
static
constexpr
index_t
NumPrefetchK
=
NumPrefetchK_
;
static
constexpr
index_t
NumPrefetchV
=
NumPrefetchK_
;
using
QXPolicy
=
BlockFmhaPipelineQXCustomPolicy
<
QLoadOnce_
>
;
template
<
index_t
k_prefetches_
,
index_t
v_prefetches_
,
index_t
k_loops_
,
index_t
v_loops_
>
struct
LdsBufferSequence
{
static
constexpr
auto
Make
()
{
return
transform_sequences
(
[
&
](
auto
i
)
{
if
(
i
<
k_loops_
)
return
i
%
k_prefetches_
;
return
(
i
-
k_loops_
)
%
v_prefetches_
;
},
typename
arithmetic_sequence_gen
<
0
,
k_loops_
+
v_loops_
,
1
>::
type
{});
};
using
type
=
remove_cvref_t
<
decltype
(
Make
())
>
;
};
// clang-format off
template
<
>
struct
LdsBufferSequence
<
3
,
3
,
4
,
4
>
{
using
type
=
sequence
<
1
,
2
,
0
,
1
,
0
,
1
,
2
,
0
>
;
};
template
<
>
struct
LdsBufferSequence
<
3
,
3
,
4
,
2
>
{
using
type
=
sequence
<
1
,
2
,
0
,
1
,
2
,
0
>
;
};
template
<
>
struct
LdsBufferSequence
<
3
,
3
,
2
,
4
>
{
using
type
=
sequence
<
1
,
2
,
0
,
1
,
2
,
0
>
;
};
template
<
>
struct
LdsBufferSequence
<
3
,
3
,
3
,
3
>
{
using
type
=
sequence
<
1
,
2
,
0
,
1
,
2
,
0
>
;
};
template
<
>
struct
LdsBufferSequence
<
3
,
3
,
2
,
2
>
{
using
type
=
sequence
<
1
,
2
,
1
,
0
>
;};
// clang-format on
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetLdsBufferSequence
()
{
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
constexpr
index_t
kK0BlockLength
=
BlockFmhaShape
::
kK0BlockLength
;
constexpr
index_t
k0_loops
=
kK0BlockLength
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
return
typename
LdsBufferSequence
<
NumPrefetchK
,
NumPrefetchV
,
k0_loops
,
k1_loops
>::
type
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackK
()
{
// TODO: this is for 3d layout
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
return
16
/
sizeof
(
KDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentK
()
{
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
if
constexpr
(
AsyncCopyK
)
{
return
4
/
sizeof
(
KDataType
);
}
else
{
return
16
/
sizeof
(
KDataType
);
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackV
()
{
// TODO: this is for 3d layout
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
return
16
/
sizeof
(
VDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentV
()
{
using
VLayout
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
::
VLayout
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
// TODO: not correct!
if
constexpr
(
total_pixels
>
4
)
return
4
;
else
return
2
;
}
else
{
return
16
/
sizeof
(
VDataType
);
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentBias
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
QXPolicy
::
template
GetQKBlockGemm
<
Problem
>())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
constexpr
auto
vec
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
CWarpDstr
::
NDimY
-
1
>
{});
return
vec
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentO
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetKVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
constexpr
auto
vec
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
CWarpDstr
::
NDimY
-
1
>
{});
return
vec
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSingleSmemElementSpaceSize
()
{
// this function assume K/V can share smem
constexpr
index_t
SingleKSize
=
[
&
]()
{
if
constexpr
(
!
AsyncCopyK
)
{
return
MakeKLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}
else
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
NumWarps
=
Problem
::
BlockFmhaShape
::
NumWarps
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
KPack
=
GetSmemKPackK
<
Problem
>
();
// this is for lds
constexpr
index_t
KVector
=
GetAlignmentK
<
Problem
>
();
// this is for global load
constexpr
index_t
kPad
=
KPack
;
static_assert
(
warpSize
*
KVector
>=
kKPerBlock
&&
warpSize
*
KVector
%
kKPerBlock
==
0
);
constexpr
index_t
LanesPerK
=
kKPerBlock
/
KVector
;
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
constexpr
index_t
NumIssues
=
kNPerBlock
/
(
LaneGroups
*
NumWarps
);
return
NumIssues
*
NumWarps
*
(
warpSize
*
KVector
+
kPad
);
}
}();
constexpr
index_t
SingleVSize
=
[
&
]()
{
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
VDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
static_assert
(
PixelsPerRow
%
kKPack
==
0
);
constexpr
index_t
NPerRow
=
PixelsPerRow
/
kKPack
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
static_assert
(
kNPerBlock
%
NPerRow
==
0
);
static_assert
(
kKPerBlock
%
kKPack
==
0
);
return
(
kKPerBlock
/
kKPack
)
*
(
kNPerBlock
/
NPerRow
)
*
(
PixelsPerRow
+
kKPack
);
}();
return
max
(
SingleKSize
,
SingleVSize
);
}
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQRegBlockDescriptor
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0BlockLength
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WG
::
kK
;
constexpr
auto
q_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
q_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
q_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
constexpr
auto
q_block_dstr
=
make_static_tile_distribution
(
q_block_dstr_encode
);
return
q_block_dstr
;
}
// TODO: this is used for non async copy desc. unify in the future
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
constexpr
auto
k_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kNPerBlock
>
{},
number
<
kKPack
>
{}),
make_tuple
(
number
<
(
kNPerBlock
+
1
)
*
kKPack
>
{},
number
<
kKPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
k_lds_block_desc
=
transform_tensor_descriptor
(
k_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
kNPerBlock
>
{}),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kKPack
>
{}))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
k_lds_block_desc
;
}
template
<
typename
Problem
,
index_t
IBuf
=
0
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsStoreBlockDescriptor
(
number
<
IBuf
>
=
number
<
0
>
{})
{
// K is always k-major, we use async-copy to load into LDS
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NumWarps
=
Problem
::
BlockFmhaShape
::
NumWarps
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
KPack
=
GetSmemKPackK
<
Problem
>
();
// this is for lds
constexpr
index_t
KVector
=
GetAlignmentK
<
Problem
>
();
// this is for global load
constexpr
index_t
kPad
=
KPack
;
// for async-copy, this pad is between warps. Optimize this for lds_read speed
static_assert
(
warpSize
*
KVector
>=
kKPerBlock
&&
warpSize
*
KVector
%
kKPerBlock
==
0
);
constexpr
index_t
LanesPerK
=
kKPerBlock
/
KVector
;
// how many lane (within a wave) to load K
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
// how many groups (within a wave), they may load different N, but same K
constexpr
index_t
NumIssues
=
kNPerBlock
/
(
LaneGroups
*
NumWarps
);
static_assert
(
NumIssues
==
kNPerBlock
*
kKPerBlock
/
(
kBlockSize
*
KVector
));
constexpr
auto
k_lds_block_desc_0
=
make_naive_tensor_descriptor_with_offset
(
make_tuple
(
number
<
NumIssues
>
{},
// n0
number
<
LaneGroups
>
{},
// n1
number
<
NumWarps
>
{},
// n2
number
<
LanesPerK
>
{},
// k0
number
<
KVector
>
{}),
// k1
make_tuple
(
number
<
NumWarps
*
(
warpSize
*
KVector
+
kPad
)
>
{},
number
<
kKPerBlock
>
{},
number
<
warpSize
*
KVector
+
kPad
>
{},
number
<
KVector
>
{},
number
<
1
>
{}),
number
<
IBuf
*
GetSingleSmemElementSpaceSize
<
Problem
>
()
>
{},
number
<
KVector
>
{},
number
<
1
>
{});
// TODO this layout is hard coded, and will be used in async copy buffer view load
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
constexpr
auto
k_lds_block_desc_issues_warps_lanes
=
transform_tensor_descriptor
(
k_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
NumIssues
>
{}),
make_pass_through_transform
(
number
<
NumWarps
>
{}),
make_merge_transform
(
make_tuple
(
number
<
LaneGroups
>
{},
number
<
LanesPerK
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
>
{},
sequence
<
2
>
{},
sequence
<
1
,
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}));
return
k_lds_block_desc_issues_warps_lanes
;
}
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
template
<
typename
Problem
,
index_t
IBuf
=
0
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsLoadBlockDescriptor
(
number
<
IBuf
>
=
number
<
0
>
{})
{
// K is always k-major, we use async-copy to load into LDS
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NumWarps
=
Problem
::
BlockFmhaShape
::
NumWarps
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
KPack
=
GetSmemKPackK
<
Problem
>
();
// this is for lds
constexpr
index_t
KVector
=
GetAlignmentK
<
Problem
>
();
// this is for global load
constexpr
index_t
kPad
=
KPack
;
// for async-copy, this pad is between warps
static_assert
(
warpSize
*
KVector
>=
kKPerBlock
&&
warpSize
*
KVector
%
kKPerBlock
==
0
);
constexpr
index_t
LanesPerK
=
kKPerBlock
/
KVector
;
// within a wave
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
// within a wave
constexpr
index_t
NumIssues
=
kNPerBlock
/
(
LaneGroups
*
NumWarps
);
static_assert
(
NumIssues
==
kNPerBlock
*
kKPerBlock
/
(
kBlockSize
*
KVector
));
constexpr
auto
k_lds_block_desc_0
=
make_naive_tensor_descriptor_with_offset
(
make_tuple
(
number
<
NumIssues
>
{},
// n0
number
<
NumWarps
>
{},
// n2
number
<
LaneGroups
>
{},
// n1
number
<
kKPerBlock
/
KPack
>
{},
// k0
number
<
KPack
>
{}),
// k1
make_tuple
(
number
<
NumWarps
*
(
warpSize
*
KVector
+
kPad
)
>
{},
number
<
warpSize
*
KVector
+
kPad
>
{},
number
<
kKPerBlock
>
{},
number
<
KPack
>
{},
number
<
1
>
{}),
number
<
IBuf
*
GetSingleSmemElementSpaceSize
<
Problem
>
()
>
{},
number
<
KPack
>
{},
number
<
1
>
{});
constexpr
auto
k_lds_block_desc
=
transform_tensor_descriptor
(
k_lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
NumIssues
>
{},
number
<
LaneGroups
>
{},
number
<
NumWarps
>
{})),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
0
,
2
,
1
>
{},
sequence
<
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
k_lds_block_desc
;
}
#else
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsLoadBlockDescriptor
()
{
// K is always k-major, we use async-copy to load into LDS
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NumWarps
=
Problem
::
BlockFmhaShape
::
NumWarps
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
KPack
=
GetSmemKPackK
<
Problem
>
();
// this is for lds
constexpr
index_t
KVector
=
GetAlignmentK
<
Problem
>
();
// this is for global load
constexpr
index_t
kPad
=
KPack
;
// for async-copy, this pad is between warps
static_assert
(
warpSize
*
KVector
>=
kKPerBlock
&&
warpSize
*
KVector
%
kKPerBlock
==
0
);
constexpr
index_t
LanesPerK
=
kKPerBlock
/
KVector
;
// within a wave
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
// within a wave
constexpr
index_t
NumIssues
=
kNPerBlock
/
(
LaneGroups
*
NumWarps
);
static_assert
(
NumIssues
==
kNPerBlock
*
kKPerBlock
/
(
kBlockSize
*
KVector
));
// constexpr index_t SingleKSize = NumIssues * NumWarps * (warpSize * KVector + kPad);
// constexpr index_t SingleVSize =
// MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
constexpr
index_t
BufferSize
=
GetSingleSmemElementSpaceSize
<
Problem
>
();
// max(SingleKSize, SingleVSize);
constexpr
auto
k_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumPrefetchK
>
{},
// num_buffers
number
<
NumIssues
>
{},
// n0
number
<
NumWarps
>
{},
// n2
number
<
LaneGroups
>
{},
// n1
number
<
kKPerBlock
/
KPack
>
{},
// k0
number
<
KPack
>
{}),
// k1
make_tuple
(
number
<
BufferSize
>
{},
number
<
NumWarps
*
(
warpSize
*
KVector
+
kPad
)
>
{},
number
<
warpSize
*
KVector
+
kPad
>
{},
number
<
kKPerBlock
>
{},
number
<
KPack
>
{},
number
<
1
>
{}),
number
<
KPack
>
{},
number
<
1
>
{});
constexpr
auto
k_lds_block_desc
=
transform_tensor_descriptor
(
k_lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
NumPrefetchK
>
{},
number
<
NumIssues
>
{},
number
<
LaneGroups
>
{},
number
<
NumWarps
>
{})),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
3
,
2
>
{},
sequence
<
4
,
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
k_lds_block_desc
;
}
#endif
// 3d + padding
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVLdsBlockDescriptor
()
{
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
VDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackV
<
Problem
>
();
static_assert
(
PixelsPerRow
%
kKPack
==
0
);
constexpr
index_t
NPerRow
=
PixelsPerRow
/
kKPack
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
static_assert
(
kNPerBlock
%
NPerRow
==
0
);
static_assert
(
kKPerBlock
%
kKPack
==
0
);
constexpr
auto
v_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumPrefetchV
>
{},
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{},
number
<
kKPack
>
{}),
make_tuple
(
number
<
GetSingleSmemElementSpaceSize
<
Problem
>
()
>
{},
number
<
(
kNPerBlock
/
NPerRow
)
*
(
PixelsPerRow
+
kKPack
)
>
{},
number
<
PixelsPerRow
+
kKPack
>
{},
number
<
kKPack
>
{},
number
<
1
>
{}),
number
<
kKPack
>
{},
number
<
1
>
{});
constexpr
auto
v_lds_block_desc
=
transform_tensor_descriptor
(
v_lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
NumPrefetchV
>
{},
number
<
kNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{})),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kKPack
>
{}))),
make_tuple
(
sequence
<
0
,
2
,
3
>
{},
sequence
<
1
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
v_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
// TODO: assume Q is in register
// TODO: assume K/V has same data type
constexpr
index_t
single_smem_size
=
GetSingleSmemElementSpaceSize
<
Problem
>
()
*
sizeof
(
typename
Problem
::
KDataType
);
return
QXPolicy
::
template
GetSmemSizeQ
<
Problem
>()
+
single_smem_size
*
max
(
NumPrefetchK
,
NumPrefetchV
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKDramTileDistribution
()
{
if
constexpr
(
!
AsyncCopyK
)
{
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
K1
=
16
/
sizeof
(
KDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
else
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NumWarps
=
Problem
::
BlockFmhaShape
::
NumWarps
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
KVector
=
GetAlignmentK
<
Problem
>
();
// this is for global load
static_assert
(
warpSize
*
KVector
>=
kKPerBlock
&&
warpSize
*
KVector
%
kKPerBlock
==
0
);
constexpr
index_t
LanesPerK
=
kKPerBlock
/
KVector
;
// within a wave
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
// within a wave
constexpr
index_t
NumIssues
=
kNPerBlock
/
(
LaneGroups
*
NumWarps
);
static_assert
(
NumIssues
==
kNPerBlock
*
kKPerBlock
/
(
kBlockSize
*
KVector
));
constexpr
index_t
N0
=
NumIssues
;
constexpr
index_t
N1
=
LaneGroups
;
constexpr
index_t
N2
=
NumWarps
;
constexpr
index_t
K0
=
LanesPerK
;
constexpr
index_t
K1
=
KVector
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeVDramTileDistribution
()
{
using
VLayout
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
::
VLayout
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
constexpr
index_t
N1
=
GetAlignmentV
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
// P
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackV
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
if
constexpr
(
get_warp_size
()
%
(
K2
*
N0
)
==
0
)
{
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
GetAlignmentV
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
static_assert
(
N0
!=
0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
}
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBiasDramTileDistribution
()
{
constexpr
index_t
MPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
// Construct C-Block-HostTensor
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
return
c_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledVRegBlockDescriptor
()
{
// This descriptor only used when V layout is seqlen * hdim
using
VLayout
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
::
VLayout
>
;
static_assert
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
N1
=
GetAlignmentV
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
// TODO: this is not always true?
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemKPackV
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
if
constexpr
(
get_warp_size
()
%
(
K2
*
N0
)
==
0
)
{
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
();
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKVBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kK1
>>
;
auto
warp_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
VDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
OaccDataType
,
float
>
)
{
return
WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution
<>
{};
// return
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
// Problem::PDataType, typename Problem::VDataType>>>{};
}
else
{
return
WarpGemmMfmaDispatcher
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
2
>
{}),
true
>
{};
}
}();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
warp_gemm
)
>
;
using
BlockGemmPolicy
=
BlockGemmARegBSmemCRegV2CustomPolicy
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBSmemCRegV2
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
BlockTile_
,
// sequence<...
typename
Gemm0BlockWarps_
,
typename
Gemm0WarpTile_
,
typename
Gemm1BlockWarps_
,
typename
Gemm1WarpTile_
,
bool
IsVLayoutRowMajor_
>
struct
TileFmhaShape
{
using
BlockTile
=
remove_cvref_t
<
BlockTile_
>
;
using
Gemm0BlockWarps
=
remove_cvref_t
<
Gemm0BlockWarps_
>
;
using
Gemm0WarpTile
=
remove_cvref_t
<
Gemm0WarpTile_
>
;
using
Gemm1BlockWarps
=
remove_cvref_t
<
Gemm1BlockWarps_
>
;
using
Gemm1WarpTile
=
remove_cvref_t
<
Gemm1WarpTile_
>
;
static
constexpr
index_t
NumWarps
=
reduce_on_sequence
(
Gemm0BlockWarps
{},
multiplies
{},
number
<
1
>
{});
static_assert
(
NumWarps
==
reduce_on_sequence
(
Gemm1BlockWarps
{},
multiplies
{},
number
<
1
>
{}));
static
constexpr
index_t
kM0
=
BlockTile
::
at
(
number
<
0
>
{});
// tile size along q seqlen
static
constexpr
index_t
kN0
=
BlockTile
::
at
(
number
<
1
>
{});
// tile size along k seqlen
static
constexpr
index_t
kK0
=
BlockTile
::
at
(
number
<
2
>
{});
// tile size along qk gemm unroll
static
constexpr
index_t
kN1
=
BlockTile
::
at
(
number
<
3
>
{});
// tile size along v head_dim
static
constexpr
index_t
kK1
=
BlockTile
::
at
(
number
<
4
>
{});
// tile size along kv gemm unroll
static
constexpr
index_t
kK0BlockLength
=
BlockTile
::
at
(
number
<
5
>
{});
// total length of K0, used for pipeline that need load Q at
// once (or repeately load Q as a whole tile)
static_assert
(
kK0BlockLength
%
kK0
==
0
,
"kK0BlockLength should be divisible by kK0"
);
// v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
static
constexpr
bool
IsVLayoutRowMajor
=
IsVLayoutRowMajor_
;
using
VLayout
=
std
::
conditional_t
<
IsVLayoutRowMajor
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
bool
kPadSeqLenQ_
/* padding for seqlen_q */
,
bool
kPadSeqLenK_
/* padding for seqlen_k */
,
bool
kPadHeadDimQ_
/* paddding for hdim_q */
,
bool
kPadHeadDimV_
/* paddding for hdim_v */
,
bool
kHasBias_
,
bool
kStoreLSE_
,
bool
kDoFp8StaticQuant_
,
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
struct
TileFmhaTraits
{
static
constexpr
bool
kPadSeqLenQ
=
kPadSeqLenQ_
;
static
constexpr
bool
kPadSeqLenK
=
kPadSeqLenK_
;
static
constexpr
bool
kPadHeadDimQ
=
kPadHeadDimQ_
;
static
constexpr
bool
kPadHeadDimV
=
kPadHeadDimV_
;
static
constexpr
bool
kHasBias
=
kHasBias_
;
static
constexpr
bool
kStoreLSE
=
kStoreLSE_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
}
// namespace ck_tile
include/ck_tile/ops/gemm.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// Problem Description for BlockGemmARegBGmemCReg
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
CDataType_
,
index_t
kBlockSize_
,
typename
BlockGemmShape_
>
struct
BlockGemmARegBGmemCRegProblem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
namespace
ck_tile
{
// A is block distributed tensor
// B is block window on global memory
// C is block distributed tensor
// This will:
// 1. load B from global memory into shared memory and then
// 2. Call BlockGemmARegSGmemCRegV1
template
<
typename
Problem_
,
typename
Policy_
=
BlockGemmARegBGmemCRegV1DefaultPolicy
>
struct
BlockGemmARegBGmemCRegV1
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
// use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
using
BlockGemmARegBSmemCRegImpl
=
BlockGemmARegBSmemCRegV1
<
BlockGemmARegBSmemCRegProblem
<
ADataType
,
BDataType
,
CDataType
,
kBlockSize
,
BlockGemmShape
>
,
BlockGemmARegBSmemCRegV1DefaultPolicy
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
{
return
sizeof
(
BDataType
)
*
Policy
::
template
MakeBSmemBlockDescriptor
<
Problem
>().
get_element_space_size
();
}
// C += A * B
template
<
typename
CBlockTensor
,
typename
ABlockTensor
,
typename
BBlockGmemWindowTmp
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ABlockTensor
&
a_block_tensor
,
const
BBlockGmemWindowTmp
&
b_block_gmem_window_tmp
,
void
*
smem_ptr
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cv_t
<
typename
ABlockTensor
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cv_t
<
typename
BBlockGmemWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
constexpr
index_t
MPerBlock
=
ABlockTensor
{}.
get_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockGmemWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockTensor
{}.
get_lengths
()[
number
<
1
>
{}];
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
const
auto
b_block_gmem_window
=
make_tile_window
(
b_block_gmem_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
b_block_gmem_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeBGmemTileDistribution
<
Problem
>());
// B LDS and LDS window
auto
b_block_smem
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
BDataType
*>
(
smem_ptr
),
Policy
::
template
MakeBSmemBlockDescriptor
<
Problem
>());
auto
b_block_smem_window
=
make_tile_window
(
b_block_smem
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
// load B tile from global mem
const
auto
b_block_tile
=
load_tile
(
b_block_gmem_window
);
// store B tile into shared mem
store_tile
(
b_block_smem_window
,
b_block_tile
);
// wait for store_tile to finish
block_sync_lds
();
// block GEMM
BlockGemmARegBSmemCRegImpl
{}(
c_block_tensor
,
a_block_tensor
,
b_block_smem_window
);
}
// C = A * B
template
<
typename
ABlockTensor
,
typename
BBlockGmemWindowTmp
>
CK_TILE_DEVICE
auto
operator
()(
const
ABlockTensor
&
a_block_tensor
,
const
BBlockGmemWindowTmp
&
b_block_gmem_window_tmp
,
void
*
smem_ptr
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cv_t
<
typename
ABlockTensor
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cv_t
<
typename
BBlockGmemWindowTmp
::
DataType
>>
,
"wrong!"
);
constexpr
index_t
MPerBlock
=
ABlockTensor
{}.
get_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockGmemWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockTensor
{}.
get_lengths
()[
number
<
1
>
{}];
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
const
auto
b_block_gmem_window
=
make_tile_window
(
b_block_gmem_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
b_block_gmem_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeBGmemTileDistribution
<
Problem
>());
// B LDS and LDS window
auto
b_block_smem
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
BDataType
*>
(
smem_ptr
),
Policy
::
template
MakeBSmemBlockDescriptor
<
Problem
>());
auto
b_block_smem_window
=
make_tile_window
(
b_block_smem
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
// load B tile from global mem
const
auto
b_block_tile
=
load_tile
(
b_block_gmem_window
);
// store B tile into shared mem
store_tile
(
b_block_smem_window
,
b_block_tile
);
// wait for store_tile to finish
block_sync_lds
();
// block GEMM
return
BlockGemmARegBSmemCRegImpl
{}(
a_block_tensor
,
b_block_smem_window
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// Default policy for BlockGemmARegBGmemCRegV1
// Default policy class should not be templated, put template on member functions instead
struct
BlockGemmARegBGmemCRegV1DefaultPolicy
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBGmemTileDistribution
()
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
16
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
#if 0
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), number<32>{});
return b_lds_block_desc;
}
#elif
0
// 3d + padding
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBSmemBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
auto
b_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
8
>
{},
number
<
kNPerBlock
>
{},
number
<
8
>
{}),
make_tuple
(
number
<
(
kNPerBlock
+
1
)
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
b_lds_block_desc
=
transform_tensor_descriptor
(
b_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
kNPerBlock
),
make_merge_transform
(
make_tuple
(
kKPerBlock
/
8
,
8
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
b_lds_block_desc
;
}
#elif 1
// fake XOR
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBSmemBlockDescriptor
()
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
auto
b_lds_block_desc_d1_d2_d3
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
kNPerBlock
/
2
>
{},
number
<
2
>
{},
number
<
kKPerBlock
>
{}),
number
<
kKPerBlock
>
{});
constexpr
index_t
kK1
=
16
/
sizeof
(
BDataType
);
constexpr
auto
b_lds_block_desc_d4_d5_d6
=
transform_tensor_descriptor
(
b_lds_block_desc_d1_d2_d3
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
kNPerBlock
/
2
>
{},
number
<
kKPerBlock
>
{}),
kK1
),
make_pass_through_transform
(
2
)),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{}));
constexpr
auto
b_lds_block_desc_n_k
=
transform_tensor_descriptor
(
b_lds_block_desc_d4_d5_d6
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
kNPerBlock
/
2
>
{},
number
<
2
>
{})),
make_pass_through_transform
(
kKPerBlock
)),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
b_lds_block_desc_n_k
;
}
#endif
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// Problem Description for BlockGemmARegBSmemCReg
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
CDataType_
,
index_t
kBlockSize_
,
typename
BlockGemmShape_
>
struct
BlockGemmARegBSmemCRegProblem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
namespace
ck_tile
{
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
template
<
typename
Problem_
,
typename
Policy_
=
BlockGemmARegBSmemCRegV1DefaultPolicy
>
struct
BlockGemmARegBSmemCRegV1
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
// C += A * B
template
<
typename
CBlockTensor
,
typename
ABlockTensorTmp
,
typename
BBlockWindowTmp
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ABlockTensorTmp
&
a_block_tensor_tmp
,
const
BBlockWindowTmp
&
b_block_window_tmp
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cv_t
<
typename
ABlockTensorTmp
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cv_t
<
typename
BBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
constexpr
index_t
MPerBlock
=
ABlockTensorTmp
{}.
get_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockTensorTmp
{}.
get_lengths
()[
number
<
1
>
{}];
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
constexpr
index_t
NPerBlockPerIter
=
NPerBlock
/
NIterPerWarp
;
constexpr
index_t
KPerBlockPerIter
=
KPerBlock
/
KIterPerWarp
;
const
index_t
iNWarp
=
get_warp_id
()
%
NWarp
;
constexpr
auto
a_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
a_block_dstr
=
make_static_tile_distribution
(
a_block_dstr_encode
);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto
a_block_tensor
=
make_static_distributed_tensor
<
typename
ABlockTensorTmp
::
DataType
>
(
a_block_dstr
);
a_block_tensor
.
get_thread_buffer
()
=
a_block_tensor_tmp
.
get_thread_buffer
();
// construct B-warp-window
auto
b_warp_window_tmp
=
make_tile_window
(
b_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WG
::
kN
>
{},
number
<
WG
::
kK
>
{}),
b_block_window_tmp
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
WG
::
kN
,
0
},
make_static_tile_distribution
(
typename
WG
::
BWarpDstrEncoding
{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array
<
statically_indexed_array
<
decltype
(
b_warp_window_tmp
),
KIterPerWarp
>
,
NIterPerWarp
>
b_warp_windows
;
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
b_warp_windows
(
nIter
)(
kIter
)
=
b_warp_window_tmp
;
move_tile_window
(
b_warp_windows
(
nIter
)(
kIter
),
{
nIter
*
NPerBlockPerIter
,
kIter
*
KPerBlockPerIter
});
});
});
#endif
// check C-block-distribution
static_assert
(
std
::
is_same_v
<
remove_cvref_t
<
decltype
(
c_block_dstr_encode
)
>
,
remove_cvref_t
<
decltype
(
CBlockTensor
::
get_tile_distribution
()
.
get_static_tile_distribution_encoding
())
>>
,
"wrong!"
);
using
AWarpDstr
=
typename
WG
::
AWarpDstr
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
using
AWarpTensor
=
typename
WG
::
AWarpTensor
;
using
CWarpTensor
=
typename
WG
::
CWarpTensor
;
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
a_warp_y_index_zeros
=
uniform_sequence_gen_t
<
AWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
// read A warp tensor from A block tensor
AWarpTensor
a_warp_tensor
;
a_warp_tensor
.
get_thread_buffer
()
=
a_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
kIter
>
{},
a_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
a_warp_y_lengths
));
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read B warp tensor from B Block window
const
auto
b_warp_tensor
=
load_tile
(
b_warp_windows
(
nIter
)(
kIter
));
// read C warp tensor from C block tensor
CWarpTensor
c_warp_tensor
;
c_warp_tensor
.
get_thread_buffer
()
=
c_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// warp GEMM
WG
{}(
c_warp_tensor
,
a_warp_tensor
,
b_warp_tensor
);
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
c_warp_tensor
.
get_thread_buffer
());
});
});
});
}
// C = A * B
template
<
typename
ABlockTensorTmp
,
typename
BBlockWindowTmp
>
CK_TILE_DEVICE
auto
operator
()(
const
ABlockTensorTmp
&
a_block_tensor_tmp
,
const
BBlockWindowTmp
&
b_block_window_tmp
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cv_t
<
typename
ABlockTensorTmp
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cv_t
<
typename
BBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
constexpr
index_t
MPerBlock
=
ABlockTensorTmp
{}.
get_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockTensorTmp
{}.
get_lengths
()[
number
<
1
>
{}];
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
constexpr
index_t
NPerBlockPerIter
=
NPerBlock
/
NIterPerWarp
;
constexpr
index_t
KPerBlockPerIter
=
KPerBlock
/
KIterPerWarp
;
const
index_t
iNWarp
=
get_warp_id
()
%
NWarp
;
constexpr
auto
a_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
a_block_dstr
=
make_static_tile_distribution
(
a_block_dstr_encode
);
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto
a_block_tensor
=
make_static_distributed_tensor
<
typename
ABlockTensorTmp
::
DataType
>
(
a_block_dstr
);
a_block_tensor
.
get_thread_buffer
()
=
a_block_tensor_tmp
.
get_thread_buffer
();
// construct B-warp-window
auto
b_warp_window_tmp
=
make_tile_window
(
b_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WG
::
kN
>
{},
number
<
WG
::
kK
>
{}),
b_block_window_tmp
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
WG
::
kN
,
0
},
make_static_tile_distribution
(
typename
WG
::
BWarpDstrEncoding
{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array
<
statically_indexed_array
<
decltype
(
b_warp_window_tmp
),
KIterPerWarp
>
,
NIterPerWarp
>
b_warp_windows
;
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
b_warp_windows
(
nIter
)(
kIter
)
=
b_warp_window_tmp
;
move_tile_window
(
b_warp_windows
(
nIter
)(
kIter
),
{
nIter
*
NPerBlockPerIter
,
kIter
*
KPerBlockPerIter
});
});
});
#endif
// Construct C-Block-HostTensor
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
using
AWarpDstr
=
typename
WG
::
AWarpDstr
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
using
AWarpTensor
=
typename
WG
::
AWarpTensor
;
using
CWarpTensor
=
typename
WG
::
CWarpTensor
;
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
a_warp_y_index_zeros
=
uniform_sequence_gen_t
<
AWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
// read A warp tensor from A block tensor
AWarpTensor
a_warp_tensor
;
a_warp_tensor
.
get_thread_buffer
()
=
a_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
kIter
>
{},
a_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
a_warp_y_lengths
));
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read B warp tensor from B Block window
const
auto
b_warp_tensor
=
load_tile
(
b_warp_windows
(
nIter
)(
kIter
));
// read C warp tensor from C block tensor
CWarpTensor
c_warp_tensor
;
c_warp_tensor
.
get_thread_buffer
()
=
c_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// warp GEMM
WG
{}(
c_warp_tensor
,
a_warp_tensor
,
b_warp_tensor
);
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
c_warp_tensor
.
get_thread_buffer
());
});
});
});
return
c_block_tensor
;
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
AType_
,
typename
BType_
,
typename
CType_
,
typename
BlockWarps_
,
typename
WarpGemm_
>
struct
BlockGemmARegBSmemCRegV1CustomPolicy
{
using
AType
=
remove_cvref_t
<
AType_
>
;
using
BType
=
remove_cvref_t
<
BType_
>
;
using
CType
=
remove_cvref_t
<
CType_
>
;
using
BlockWarps
=
remove_cvref_t
<
BlockWarps_
>
;
static
constexpr
index_t
kMWarps
=
BlockWarps
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kNWarps
=
BlockWarps
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kKWarps
=
BlockWarps
::
at
(
number
<
2
>
{});
using
WarpGemm
=
remove_cvref_t
<
WarpGemm_
>
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetWarpGemmMWarpNWarp
()
{
return
make_tuple
(
WarpGemm
{},
kMWarps
,
kNWarps
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace
ck_tile
{
// Default policy for BlockGemmARegBSmemCRegV1
// Default policy class should not be templated, put template on member functions instead
struct
BlockGemmARegBSmemCRegV1DefaultPolicy
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetWarpGemmMWarpNWarp
()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
BDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
CDataType
,
float
>
)
{
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
constexpr index_t NumWarp = kBlockSize / get_warp_size();
// FIXME
if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
}
else
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
}
#else
return
make_tuple
(
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
{},
4
,
1
);
#endif
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
BDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
CDataType
,
float
>
)
{
return
make_tuple
(
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
{},
4
,
1
);
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
namespace
ck_tile
{
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
template
<
typename
Problem_
,
typename
Policy_
=
BlockGemmARegBSmemCRegV2DefaultPolicy
>
struct
BlockGemmARegBSmemCRegV2
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
// C += A * B
template
<
typename
CBlockTensor
,
typename
ABlockTensorTmp
,
typename
BBlockWindowTmp
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ABlockTensorTmp
&
a_block_tensor_tmp
,
const
BBlockWindowTmp
&
b_block_window_tmp
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cv_t
<
typename
ABlockTensorTmp
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cv_t
<
typename
BBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
constexpr
index_t
MPerBlock
=
ABlockTensorTmp
{}.
get_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockTensorTmp
{}.
get_lengths
()[
number
<
1
>
{}];
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
constexpr
index_t
NPerBlockPerIter
=
NPerBlock
/
NIterPerWarp
;
constexpr
index_t
KPerBlockPerIter
=
KPerBlock
/
KIterPerWarp
;
const
index_t
iNWarp
=
get_warp_id
()
%
NWarp
;
constexpr
auto
a_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
a_block_dstr
=
make_static_tile_distribution
(
a_block_dstr_encode
);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto
a_block_tensor
=
make_static_distributed_tensor
<
typename
ABlockTensorTmp
::
DataType
>
(
a_block_dstr
);
a_block_tensor
.
get_thread_buffer
()
=
a_block_tensor_tmp
.
get_thread_buffer
();
// construct B-warp-window
auto
b_warp_window_tmp
=
make_tile_window
(
b_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WG
::
kN
>
{},
number
<
WG
::
kK
>
{}),
b_block_window_tmp
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
WG
::
kN
,
0
},
make_static_tile_distribution
(
typename
WG
::
BWarpDstrEncoding
{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array
<
statically_indexed_array
<
decltype
(
b_warp_window_tmp
),
KIterPerWarp
>
,
NIterPerWarp
>
b_warp_windows
;
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
b_warp_windows
(
nIter
)(
kIter
)
=
b_warp_window_tmp
;
move_tile_window
(
b_warp_windows
(
nIter
)(
kIter
),
{
nIter
*
NPerBlockPerIter
,
kIter
*
KPerBlockPerIter
});
});
});
#endif
// check C-block-distribution
static_assert
(
std
::
is_same_v
<
remove_cvref_t
<
decltype
(
c_block_dstr_encode
)
>
,
remove_cvref_t
<
decltype
(
CBlockTensor
::
get_tile_distribution
()
.
get_static_tile_distribution_encoding
())
>>
,
"wrong!"
);
using
AWarpDstr
=
typename
WG
::
AWarpDstr
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
using
AWarpTensor
=
typename
WG
::
AWarpTensor
;
using
CWarpTensor
=
typename
WG
::
CWarpTensor
;
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
a_warp_y_index_zeros
=
uniform_sequence_gen_t
<
AWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read B warp tensor from B Block window
const
auto
b_warp_tensor
=
load_tile
(
b_warp_windows
(
nIter
)(
kIter
));
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
// read A warp tensor from A block tensor
AWarpTensor
a_warp_tensor
;
a_warp_tensor
.
get_thread_buffer
()
=
a_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
kIter
>
{},
a_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
a_warp_y_lengths
));
// read C warp tensor from C block tensor
CWarpTensor
c_warp_tensor
;
c_warp_tensor
.
get_thread_buffer
()
=
c_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// warp GEMM
WG
{}(
c_warp_tensor
,
a_warp_tensor
,
b_warp_tensor
);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
c_warp_tensor
.
get_thread_buffer
());
});
});
});
}
CK_TILE_DEVICE
constexpr
auto
MakeCBlockTile
()
const
{
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
}
// C = A * B
template
<
typename
ABlockTensorTmp
,
typename
BBlockWindowTmp
>
CK_TILE_DEVICE
auto
operator
()(
const
ABlockTensorTmp
&
a_block_tensor_tmp
,
const
BBlockWindowTmp
&
b_block_window_tmp
)
const
{
auto
c_block_tensor
=
MakeCBlockTile
();
operator
()(
c_block_tensor
,
a_block_tensor_tmp
,
b_block_window_tmp
);
return
c_block_tensor
;
}
};
}
// namespace ck_tile
Prev
1
…
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment