Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a4522ae3
"docs/basic_usage/openai_api_completions.ipynb" did not exist on "6b0aeb58fd530f88defd5c7862c74a5c7c1a5dba"
Commit
a4522ae3
authored
Nov 06, 2024
by
illsilin
Browse files
sync from public repo
parents
1f127242
e0594d08
Changes
425
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2184 additions
and
252 deletions
+2184
-252
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp
...ine/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp
+95
-0
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp
.../pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp
+142
-0
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp
...t/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp
+41
-0
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp
...ipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp
+266
-0
include/ck_tile/ops/common.hpp
include/ck_tile/ops/common.hpp
+1
-0
include/ck_tile/ops/common/generic_2d_block_shape.hpp
include/ck_tile/ops/common/generic_2d_block_shape.hpp
+77
-0
include/ck_tile/ops/elementwise.hpp
include/ck_tile/ops/elementwise.hpp
+8
-0
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
.../ck_tile/ops/elementwise/unary_element_wise_operation.hpp
+1163
-0
include/ck_tile/ops/epilogue.hpp
include/ck_tile/ops/epilogue.hpp
+2
-0
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
+17
-11
include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp
include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp
+188
-0
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+1
-0
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+8
-6
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+9
-7
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
...ile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
+5
-4
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
...a/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
+12
-15
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
...eline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
+12
-15
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+73
-164
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
...fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
+18
-9
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
...lock_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
+46
-21
No files found.
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp
0 → 100644
View file @
a4522ae3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
namespace
ck_tile
{
struct
AddRmsnorm2dRdquantFwdPipelineDefaultPolicy
{
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeABXBlockTileDistribution
()
{
using
S
=
typename
Problem
::
BlockShape
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
S
::
Repeat_M
,
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
,
S
::
Vector_M
>
,
sequence
<
S
::
Repeat_N
,
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
,
S
::
Vector_N
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>
,
sequence
<
2
,
2
>>
,
sequence
<
1
,
1
,
2
,
2
>
,
sequence
<
0
,
3
,
0
,
3
>>
{});
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeGammaBlockTileDistribution
()
{
using
S
=
typename
Problem
::
BlockShape
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
>
,
tuple
<
sequence
<
S
::
Repeat_N
,
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
,
S
::
Vector_N
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
,
1
>
,
sequence
<
0
,
3
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2d
()
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockReduce2d
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dSync
()
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockReduce2dSync
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dCrossWarpSync
()
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockReduce2dCrossWarpSync
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
using
block_reduce2d
=
BlockReduce2d
<
P_
>
;
using
x_block_tile
=
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
ComputeDataType
>
(
MakeABXBlockTileDistribution
<
Problem
>
()));
using
y_block_tile
=
decltype
(
block_reduce2d
::
template
MakeYBlockTile
<
x_block_tile
>());
return
GetBlockReduce2dCrossWarpSync
<
Problem
>
().
template
GetSmemSize
<
y_block_tile
>();
}
else
{
return
1
;
// zero size arrays are an extension
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp
0 → 100644
View file @
a4522ae3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
AddRmsnorm2dRdquantFwdPipelineDefaultPolicy
>
struct
AddRmsnorm2dRdquantFwdPipelineOnePass
{
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
GammaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
YScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
QYDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveX
=
Problem
::
kSaveX
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
return
"bpr_op"
;
// block per row
else
return
"wpr_op"
;
// warp per row
}();
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
AWindow
,
typename
BWindow
,
typename
GammaWindow
,
typename
XWindow
,
typename
YScaleWindow
,
typename
QYWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
AWindow
&
a_window_
,
const
BWindow
&
b_window_
,
const
GammaWindow
&
gamma_window_
,
XWindow
&
x_window
,
YScaleWindow
&
yscale_window
,
QYWindow
&
qy_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
{
const
auto
a_window
=
make_tile_window
(
a_window_
,
Policy
::
template
MakeABXBlockTileDistribution
<
Problem
>());
const
auto
b_window
=
make_tile_window
(
b_window_
,
Policy
::
template
MakeABXBlockTileDistribution
<
Problem
>());
const
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
Policy
::
template
MakeGammaBlockTileDistribution
<
Problem
>());
auto
reduce_square_sum_func
=
ReduceOp
::
SquareAdd
{};
auto
reduce_sum_func
=
ReduceOp
::
Add
{};
auto
reduce_absmax_func
=
ReduceOp
::
AbsMax
{};
auto
reduce_max_func
=
ReduceOp
::
Max
{};
auto
block_reduce2d
=
Policy
::
template
GetBlockReduce2d
<
Problem
>();
auto
block_reduce2d_sync
=
Policy
::
template
GetBlockReduce2dSync
<
Problem
>();
auto
block_reduce2d_cross_warp_sync
=
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
const
auto
a
=
load_tile
(
a_window
);
const
auto
b
=
load_tile
(
b_window
);
const
auto
gamma
=
load_tile
(
gamma_window
);
auto
x
=
tile_elementwise_in
(
[
&
](
const
auto
&
a_
,
const
auto
&
b_
)
{
return
type_convert
<
ComputeDataType
>
(
a_
)
+
type_convert
<
ComputeDataType
>
(
b_
);
},
a
,
b
);
if
constexpr
(
kSaveX
)
store_tile
(
x_window
,
cast_tile
<
XDataType
>
(
x
));
// compute mean square, each-thread->cross-lane->cross-warp
auto
square_sum
=
block_reduce2d
(
x
,
reduce_square_sum_func
.
GetIdentityValue
<
ComputeDataType
>
(),
reduce_square_sum_func
);
block_reduce2d_sync
(
square_sum
,
reduce_sum_func
);
block_reduce2d_cross_warp_sync
(
square_sum
,
smem
,
reduce_sum_func
);
auto
inv_rms
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
/
row_size
+
epsilon
));
},
square_sum
);
// rmsnorm computation
auto
y
=
make_static_distributed_tensor
<
ComputeDataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
y_
=
x_
*
inv_rms_
[
i_idx
]
*
gamma_
;
y
(
idx
)
=
type_convert
<
ComputeDataType
>
(
y_
);
});
// compute absmax, each-thread->cross-lane->cross-warp
auto
absmax
=
block_reduce2d
(
y
,
reduce_absmax_func
.
GetIdentityValue
<
ComputeDataType
>
(),
reduce_absmax_func
);
block_reduce2d_sync
(
absmax
,
reduce_max_func
);
block_reduce2d_cross_warp_sync
(
absmax
,
smem
,
reduce_max_func
);
// ex: yscale = absmax / 127 if int8
auto
yscale
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
return
v_
/
type_convert
<
ComputeDataType
>
(
numeric
<
QYDataType
>::
max
());
},
absmax
);
store_tile
(
yscale_window
,
cast_tile
<
YScaleDataType
>
(
yscale
));
// quantize y to qy
auto
qy
=
make_static_distributed_tensor
<
QYDataType
>
(
y
.
get_tile_distribution
());
sweep_tile
(
qy
,
[
&
,
yscale_
=
yscale
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
auto
qy_
=
y
[
idx
]
/
yscale_
[
i_idx
];
qy
(
idx
)
=
saturates
<
QYDataType
>
{}(
qy_
);
});
store_tile
(
qy_window
,
qy
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp
0 → 100644
View file @
a4522ae3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
// X = A + B, Y = Rmsnorm2d(X), QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
GammaDataType_
,
typename
ComputeDataType_
,
typename
XDataType_
,
typename
YScaleDataType_
,
typename
QYDataType_
,
typename
BlockShape_
,
bool
kPadN_
,
bool
kSaveX_
,
bool
kThreePass_
>
struct
AddRmsnorm2dRdquantFwdPipelineProblem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
QYDataType
=
remove_cvref_t
<
QYDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveX
=
kSaveX_
;
static
constexpr
bool
kThreePass
=
kThreePass_
;
};
}
// namespace ck_tile
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp
0 → 100644
View file @
a4522ae3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
AddRmsnorm2dRdquantFwdPipelineDefaultPolicy
>
struct
AddRmsnorm2dRdquantFwdPipelineThreePass
{
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
GammaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
YScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
QYDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveX
=
Problem
::
kSaveX
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
return
"bpr_tp"
;
// block per row
else
return
"wpr_tp"
;
// warp per row
}();
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
AWindow
,
typename
BWindow
,
typename
GammaWindow
,
typename
XWindow
,
typename
YScaleWindow
,
typename
QYWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
AWindow
&
a_window_
,
const
BWindow
&
b_window_
,
const
GammaWindow
&
gamma_window_
,
XWindow
&
x_window_
,
YScaleWindow
&
yscale_window
,
QYWindow
&
qy_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
{
auto
a_window
=
make_tile_window
(
a_window_
,
Policy
::
template
MakeABXBlockTileDistribution
<
Problem
>());
auto
b_window
=
make_tile_window
(
b_window_
,
Policy
::
template
MakeABXBlockTileDistribution
<
Problem
>());
auto
x_window
=
[
&
]()
{
if
constexpr
(
kSaveX
)
return
make_tile_window
(
x_window_
,
Policy
::
template
MakeABXBlockTileDistribution
<
Problem
>());
else
return
x_window_
;
}();
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
Policy
::
template
MakeGammaBlockTileDistribution
<
Problem
>());
auto
reduce_square_sum_func
=
ReduceOp
::
SquareAdd
{};
auto
reduce_sum_func
=
ReduceOp
::
Add
{};
auto
reduce_absmax_func
=
ReduceOp
::
AbsMax
{};
auto
reduce_max_func
=
ReduceOp
::
Max
{};
auto
block_reduce2d
=
Policy
::
template
GetBlockReduce2d
<
Problem
>();
auto
block_reduce2d_sync
=
Policy
::
template
GetBlockReduce2dSync
<
Problem
>();
auto
block_reduce2d_cross_warp_sync
=
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
index_t
num_n_tile_iteration
=
__builtin_amdgcn_readfirstlane
(
integer_divide_ceil
(
row_size
,
Block_N
));
using
XTensorType
=
decltype
(
cast_tile
<
ComputeDataType
>
(
load_tile
(
a_window
)));
auto
square_sum
=
block_reduce2d
.
template
MakeYBlockTile
<
XTensorType
>();
set_tile
(
square_sum
,
reduce_square_sum_func
.
GetIdentityValue
<
ComputeDataType
>
());
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
const
auto
a
=
load_tile
(
a_window
);
const
auto
b
=
load_tile
(
b_window
);
auto
x
=
tile_elementwise_in
(
[
&
](
const
auto
&
a_
,
const
auto
&
b_
)
{
return
type_convert
<
ComputeDataType
>
(
a_
)
+
type_convert
<
ComputeDataType
>
(
b_
);
},
a
,
b
);
if
constexpr
(
kSaveX
)
store_tile
(
x_window
,
cast_tile
<
XDataType
>
(
x
));
block_reduce2d
(
x
,
square_sum
,
reduce_square_sum_func
);
move_tile_window
(
x_window
,
{
0
,
Block_N
});
move_tile_window
(
a_window
,
{
0
,
Block_N
});
move_tile_window
(
b_window
,
{
0
,
Block_N
});
}
block_reduce2d_sync
(
square_sum
,
reduce_sum_func
);
block_reduce2d_cross_warp_sync
(
square_sum
,
smem
,
reduce_sum_func
);
auto
inv_rms
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
/
row_size
+
epsilon
));
},
square_sum
);
// reverse read x to reuse cache
ck_tile
::
index_t
stride_to_right_most_window
=
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
if
constexpr
(
kSaveX
)
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
else
{
move_tile_window
(
a_window
,
{
0
,
-
Block_N
});
move_tile_window
(
b_window
,
{
0
,
-
Block_N
});
}
move_tile_window
(
gamma_window
,
{
stride_to_right_most_window
});
using
YTensorType
=
XTensorType
;
auto
absmax
=
block_reduce2d
.
template
MakeYBlockTile
<
YTensorType
>();
set_tile
(
absmax
,
reduce_absmax_func
.
GetIdentityValue
<
ComputeDataType
>
());
// rmsnorm computation + absmax(threadwise reduce)
if
constexpr
(
kSaveX
)
__syncthreads
();
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
auto
x
=
[
&
]()
{
if
constexpr
(
kSaveX
)
{
return
load_tile
(
x_window
);
}
else
{
const
auto
a
=
load_tile
(
a_window
);
const
auto
b
=
load_tile
(
b_window
);
return
tile_elementwise_in
(
[
&
](
const
auto
&
a_
,
const
auto
&
b_
)
{
return
type_convert
<
ComputeDataType
>
(
a_
)
+
type_convert
<
ComputeDataType
>
(
b_
);
},
a
,
b
);
}
}();
auto
gamma
=
load_tile
(
gamma_window
);
auto
y
=
make_static_distributed_tensor
<
ComputeDataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
y_
=
x_
*
inv_rms
[
i_idx
]
*
gamma_
;
y
(
idx
)
=
type_convert
<
ComputeDataType
>
(
y_
);
});
block_reduce2d
(
y
,
absmax
,
reduce_absmax_func
);
if
constexpr
(
kSaveX
)
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
else
{
move_tile_window
(
a_window
,
{
0
,
-
Block_N
});
move_tile_window
(
b_window
,
{
0
,
-
Block_N
});
}
move_tile_window
(
gamma_window
,
{
-
Block_N
});
}
// compute absmax, cross-lane->cross-warp
block_reduce2d_sync
(
absmax
,
reduce_max_func
);
block_reduce2d_cross_warp_sync
(
absmax
,
smem
,
reduce_max_func
);
// ex: yscale = absmax / 127 if int8
auto
yscale
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
return
v_
/
type_convert
<
ComputeDataType
>
(
numeric
<
QYDataType
>::
max
());
},
absmax
);
store_tile
(
yscale_window
,
cast_tile
<
YScaleDataType
>
(
yscale
));
// quantize y to qy
// recompute rmsnorm, try to save y in the future
if
constexpr
(
kSaveX
)
move_tile_window
(
x_window
,
{
0
,
Block_N
});
else
{
move_tile_window
(
a_window
,
{
0
,
Block_N
});
move_tile_window
(
b_window
,
{
0
,
Block_N
});
}
move_tile_window
(
gamma_window
,
{
Block_N
});
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
auto
x
=
[
&
]()
{
if
constexpr
(
kSaveX
)
{
return
load_tile
(
x_window
);
}
else
{
const
auto
a
=
load_tile
(
a_window
);
const
auto
b
=
load_tile
(
b_window
);
return
tile_elementwise_in
(
[
&
](
const
auto
&
a_
,
const
auto
&
b_
)
{
return
type_convert
<
ComputeDataType
>
(
a_
)
+
type_convert
<
ComputeDataType
>
(
b_
);
},
a
,
b
);
}
}();
auto
gamma
=
load_tile
(
gamma_window
);
auto
y
=
make_static_distributed_tensor
<
ComputeDataType
>
(
x
.
get_tile_distribution
());
auto
qy
=
make_static_distributed_tensor
<
QYDataType
>
(
y
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
y_
=
x_
*
inv_rms
[
i_idx
]
*
gamma_
;
auto
qy_
=
y_
/
yscale
[
i_idx
];
qy
(
idx
)
=
saturates
<
QYDataType
>
{}(
qy_
);
});
store_tile
(
qy_window
,
qy
);
if
constexpr
(
kSaveX
)
move_tile_window
(
x_window
,
{
0
,
Block_N
});
else
{
move_tile_window
(
a_window
,
{
0
,
Block_N
});
move_tile_window
(
b_window
,
{
0
,
Block_N
});
}
move_tile_window
(
gamma_window
,
{
Block_N
});
move_tile_window
(
qy_window
,
{
0
,
Block_N
});
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/common.hpp
View file @
a4522ae3
...
...
@@ -3,4 +3,5 @@
#pragma once
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/common/generic_2d_block_shape.hpp
0 → 100644
View file @
a4522ae3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace
ck_tile
{
/*
// clang-format off
4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector
Block_N (Warp_N * WarpPerBlock_N * Repeat_N )
+<----------------------< Repeat_N(2)>--------------------->+
| |
+<-- <WarpPerBlock_N(2)> -->+
Warp_N
+--------------+--------------+--------------+--------------+----+----------------+
Warp_M | wrap_0 | wrap_1 | | ^ ^
+--------------+--------------+ | <WarpPerBlock_M(2)> |
| wrap_2 | wrap_3 | | v
+--------------+--------------+--------------+--------------+----+ Block_M
| | |
+ + |
| | | v
+--------------+--------------+--------------+--------------+ +
each Warp-tile (e.g 16 thrd per row)
Vector_N (contiguous pixels each thrd holds along N, or vector size)
+-----------+-----------+-----------+-----------+-----------+
| thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M
+-----------+-----------+-----------+-----------+-----------+
| thrd_16 | thrd_17 | thrd_18 | thrd_19 | ...
+-----------+-----------+-----------+-----------+-----------+
// clang-format on
*/
template
<
typename
BlockTile_
,
// block size, seq<M, N>
typename
WarpPerBlock_
,
// num warps along seq<M, N>
typename
WarpTile_
,
// warp size, seq<M, N>
typename
Vector_
,
// contiguous pixels(vector size) along seq<M, N>
index_t
BlockSize_
=
warpSize
*
reduce_on_sequence
(
WarpPerBlock_
{}
,
multiplies
{}
,
number
<
1
>{})
>
struct
Generic2dBlockShape
{
// block size
static
constexpr
index_t
Block_M
=
BlockTile_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Block_N
=
BlockTile_
::
at
(
number
<
1
>
{});
// num warps along seq<M, N>, within each block
static
constexpr
index_t
WarpPerBlock_M
=
WarpPerBlock_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
WarpPerBlock_N
=
WarpPerBlock_
::
at
(
number
<
1
>
{});
// warp size
static
constexpr
index_t
Warp_M
=
WarpTile_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warp_N
=
WarpTile_
::
at
(
number
<
1
>
{});
static_assert
(
Block_M
%
(
WarpPerBlock_M
*
Warp_M
)
==
0
);
static_assert
(
Block_N
%
(
WarpPerBlock_N
*
Warp_N
)
==
0
);
// repeat of each thread along seq<M, N>
static
constexpr
index_t
Repeat_M
=
Block_M
/
(
WarpPerBlock_M
*
Warp_M
);
static
constexpr
index_t
Repeat_N
=
Block_N
/
(
WarpPerBlock_N
*
Warp_N
);
// vector size along seq<M, N>
static
constexpr
index_t
Vector_M
=
Vector_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Vector_N
=
Vector_
::
at
(
number
<
1
>
{});
static_assert
(
Warp_M
%
Vector_M
==
0
);
static_assert
(
Warp_N
%
Vector_N
==
0
);
// num of threads along seq<M, N>, within each warp
static
constexpr
index_t
ThreadPerWarp_M
=
Warp_M
/
Vector_M
;
static
constexpr
index_t
ThreadPerWarp_N
=
Warp_N
/
Vector_N
;
static
constexpr
index_t
BlockSize
=
BlockSize_
;
};
}
// namespace ck_tile
include/ck_tile/ops/elementwise.hpp
0 → 100644
View file @
a4522ae3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
0 → 100644
View file @
a4522ae3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <type_traits>
namespace
ck_tile
{
namespace
element_wise
{
#if 0
struct PassThroughPack2
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const
{
auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t);
}
constexpr const static bool is_pack2_invocable = true;
};
#endif
struct
PassThrough
{
template
<
typename
Y
,
typename
X
>
CK_TILE_HOST_DEVICE
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
double
,
double
>
(
double
&
y
,
const
double
&
x
)
const
{
y
=
x
;
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
float
,
double
>
(
float
&
y
,
const
double
&
x
)
const
{
y
=
type_convert
<
float
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
double
,
float
>
(
double
&
y
,
const
float
&
x
)
const
{
y
=
type_convert
<
double
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
y
=
x
;
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
>
(
ck_tile
::
fp16_t
&
y
,
const
ck_tile
::
fp16_t
&
x
)
const
{
y
=
x
;
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
fp16_t
,
float
>
(
ck_tile
::
fp16_t
&
y
,
const
float
&
x
)
const
{
y
=
type_convert
<
ck_tile
::
fp16_t
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
>
(
ck_tile
::
bf16_t
&
y
,
const
ck_tile
::
bf16_t
&
x
)
const
{
y
=
x
;
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
int32_t
,
int32_t
>
(
int32_t
&
y
,
const
int32_t
&
x
)
const
{
y
=
x
;
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
bf16_t
,
float
>
(
ck_tile
::
bf16_t
&
y
,
const
float
&
x
)
const
{
y
=
type_convert
<
ck_tile
::
bf16_t
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
float
,
ck_tile
::
bf16_t
>
(
float
&
y
,
const
ck_tile
::
bf16_t
&
x
)
const
{
y
=
type_convert
<
float
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
bf16_t
,
ck_tile
::
fp16_t
>
(
ck_tile
::
bf16_t
&
y
,
const
ck_tile
::
fp16_t
&
x
)
const
{
y
=
type_convert
<
ck_tile
::
bf16_t
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
float
,
ck_tile
::
fp16_t
>
(
float
&
y
,
const
ck_tile
::
fp16_t
&
x
)
const
{
y
=
type_convert
<
float
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
int8_t
,
int8_t
>
(
int8_t
&
y
,
const
int8_t
&
x
)
const
{
y
=
x
;
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
fp16_t
,
int8_t
>
(
ck_tile
::
fp16_t
&
y
,
const
int8_t
&
x
)
const
{
y
=
type_convert
<
ck_tile
::
fp16_t
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
bf16_t
,
int8_t
>
(
ck_tile
::
bf16_t
&
y
,
const
int8_t
&
x
)
const
{
y
=
type_convert
<
ck_tile
::
bf16_t
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
uint8_t
,
uint8_t
>
(
uint8_t
&
y
,
const
uint8_t
&
x
)
const
{
y
=
x
;
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
int8_t
,
int32_t
>
(
int8_t
&
y
,
const
int32_t
&
x
)
const
{
y
=
type_convert
<
int8_t
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
int32_t
,
int8_t
>
(
int32_t
&
y
,
const
int8_t
&
x
)
const
{
y
=
type_convert
<
int32_t
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
int8_t
,
float
>
(
int8_t
&
y
,
const
float
&
x
)
const
{
y
=
type_convert
<
int8_t
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
float
,
int8_t
>
(
float
&
y
,
const
int8_t
&
x
)
const
{
y
=
type_convert
<
float
>
(
x
);
}
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
int4_t
,
int4_t
>
(
int4_t
&
y
,
const
int4_t
&
x
)
const
{
y
=
x
;
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
int4_t
,
int
>
(
int4_t
&
y
,
const
int
&
x
)
const
{
y
=
type_convert
<
int4_t
>
(
x
);
}
#endif
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
>
(
ck_tile
::
fp8_t
&
y
,
const
ck_tile
::
fp8_t
&
x
)
const
{
y
=
x
;
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
float
,
ck_tile
::
fp8_t
>
(
float
&
y
,
const
ck_tile
::
fp8_t
&
x
)
const
{
y
=
type_convert
<
float
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
fp8_t
,
float
>
(
ck_tile
::
fp8_t
&
y
,
const
float
&
x
)
const
{
y
=
type_convert
<
ck_tile
::
fp8_t
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
>
(
ck_tile
::
fp16_t
&
y
,
const
ck_tile
::
fp8_t
&
x
)
const
{
y
=
type_convert
<
ck_tile
::
fp16_t
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
fp8_t
,
ck_tile
::
fp16_t
>
(
ck_tile
::
fp8_t
&
y
,
const
ck_tile
::
fp16_t
&
x
)
const
{
y
=
type_convert
<
ck_tile
::
fp8_t
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
bf8_t
,
ck_tile
::
bf8_t
>
(
ck_tile
::
bf8_t
&
y
,
const
ck_tile
::
bf8_t
&
x
)
const
{
y
=
x
;
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
float
,
ck_tile
::
bf8_t
>
(
float
&
y
,
const
ck_tile
::
bf8_t
&
x
)
const
{
y
=
type_convert
<
float
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
bf8_t
,
float
>
(
ck_tile
::
bf8_t
&
y
,
const
float
&
x
)
const
{
y
=
type_convert
<
ck_tile
::
bf8_t
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
fp16_t
,
ck_tile
::
bf8_t
>
(
ck_tile
::
fp16_t
&
y
,
const
ck_tile
::
bf8_t
&
x
)
const
{
y
=
type_convert
<
ck_tile
::
fp16_t
>
(
x
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
bf8_t
,
ck_tile
::
fp16_t
>
(
ck_tile
::
bf8_t
&
y
,
const
ck_tile
::
fp16_t
&
x
)
const
{
y
=
ck_tile
::
type_convert
<
ck_tile
::
bf8_t
>
(
x
);
}
};
#if 0
struct UnaryConvert
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
y = type_convert<Y>(x);
}
};
struct ConvertBF16RTN
{
// convert to bf16 using round to nearest (rtn)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::bf16_t>, "Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = bf16_convert_rtn<Y>(x);
}
};
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_sr<Y>(x);
}
};
struct ConvertF8RNE
{
// convert to fp8 using rounding to nearest even
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_rne<Y>(x);
}
};
#endif
struct
Scale
{
CK_TILE_HOST_DEVICE
Scale
(
float
scale
=
1.
f
)
:
scale_
(
scale
)
{}
template
<
typename
Y
,
typename
X
>
CK_TILE_HOST_DEVICE
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
y
=
ck_tile
::
type_convert
<
Y
>
(
ck_tile
::
type_convert
<
float
>
(
x
)
*
scale_
);
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
>
(
ck_tile
::
fp16_t
&
y
,
const
ck_tile
::
fp16_t
&
x
)
const
{
y
=
ck_tile
::
type_convert
<
ck_tile
::
fp16_t
>
(
scale_
)
*
x
;
};
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
>
(
ck_tile
::
bf16_t
&
y
,
const
ck_tile
::
bf16_t
&
x
)
const
{
const
float
x_tmp
=
ck_tile
::
type_convert
<
float
>
(
x
);
const
float
y_tmp
=
scale_
*
x_tmp
;
y
=
ck_tile
::
type_convert
<
ck_tile
::
bf16_t
>
(
y_tmp
);
};
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
y
=
scale_
*
x
;
};
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
double
,
double
>
(
double
&
y
,
const
double
&
x
)
const
{
y
=
scale_
*
x
;
};
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
int8_t
,
int8_t
>
(
int8_t
&
y
,
const
int8_t
&
x
)
const
{
y
=
ck_tile
::
type_convert
<
int8_t
>
(
scale_
*
ck_tile
::
type_convert
<
float
>
(
x
));
};
float
scale_
;
};
struct
ScaleAndResetNaNToMinusInfinity
{
CK_TILE_HOST_DEVICE
ScaleAndResetNaNToMinusInfinity
(
float
scale
)
:
scale_
(
scale
)
{}
template
<
typename
Y
,
typename
X
>
CK_TILE_HOST_DEVICE
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
y
=
ck_tile
::
isnan
(
x
)
?
-
numeric
<
float
>::
infinity
()
:
scale_
*
x
;
};
float
scale_
;
};
struct
UnaryDivide
{
CK_TILE_HOST_DEVICE
UnaryDivide
(
const
int32_t
divider
=
1
)
:
divider_
(
divider
)
{}
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
x
/
type_convert
<
T
>
(
divider_
);
};
int32_t
divider_
=
1
;
};
struct
UnarySquare
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
||
std
::
is_same_v
<
T
,
int4_t
>
#endif
,
"Data type is not supported by this operation!"
);
y
=
x
*
x
;
};
};
struct
UnaryAbs
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
abs
(
x
);
};
};
struct
UnarySqrt
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
sqrt
(
x
);
};
};
struct
Relu
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
,
"Data type is not supported by this operation!"
);
y
=
x
>
0
?
x
:
0
;
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()(
ck_tile
::
bf16_t
&
y
,
const
ck_tile
::
bf16_t
&
x
)
const
{
float
x_f32
=
ck_tile
::
type_convert
<
float
>
(
x
);
float
y_f32
=
x_f32
>
0
?
x_f32
:
0
;
y
=
ck_tile
::
type_convert
<
ck_tile
::
bf16_t
>
(
y_f32
);
}
};
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code use higher accuracy "exp" and "div"
// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function
struct
FastGelu
{
template
<
typename
Y
,
typename
X
>
CK_TILE_HOST
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
template
<
typename
Y
,
typename
X
>
CK_TILE_DEVICE
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
template
<
>
CK_TILE_HOST
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const
float
c1
=
-
2.0
*
0.035677
f
;
const
float
c2
=
-
2.0
*
0.797885
f
;
const
float
u
=
x
*
(
c1
*
x
*
x
+
c2
);
const
float
emu
=
exp
(
u
);
y
=
x
/
(
1.
f
+
emu
);
}
// device code, use lower precision "__ocml_exp_f32" and "rcp"
template
<
>
CK_TILE_DEVICE
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
// const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const
float
c1
=
-
2.0
*
0.035677
f
;
const
float
c2
=
-
2.0
*
0.797885
f
;
const
float
u
=
x
*
(
c1
*
x
*
x
+
c2
);
const
float
emu
=
__ocml_exp_f32
(
u
);
y
=
x
*
ck_tile
::
rcp
(
1.
f
+
emu
);
}
template
<
>
CK_TILE_HOST
void
operator
()
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
>
(
ck_tile
::
fp16_t
&
y
,
const
ck_tile
::
fp16_t
&
x
)
const
{
float
y_f
;
this
->
operator
()
<
float
,
float
>
(
y_f
,
type_convert
<
float
>
(
x
));
y
=
type_convert
<
ck_tile
::
fp16_t
>
(
y_f
);
}
template
<
>
CK_TILE_DEVICE
void
operator
()
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
>
(
ck_tile
::
fp16_t
&
y
,
const
ck_tile
::
fp16_t
&
x
)
const
{
float
y_f
;
this
->
operator
()
<
float
,
float
>
(
y_f
,
type_convert
<
float
>
(
x
));
y
=
type_convert
<
ck_tile
::
fp16_t
>
(
y_f
);
}
template
<
>
CK_TILE_HOST
void
operator
()
<
ck_tile
::
fp16_t
,
float
>
(
ck_tile
::
fp16_t
&
y
,
const
float
&
x
)
const
{
float
y_f
;
this
->
operator
()
<
float
,
float
>
(
y_f
,
x
);
y
=
type_convert
<
ck_tile
::
fp16_t
>
(
y_f
);
}
template
<
>
CK_TILE_DEVICE
void
operator
()
<
ck_tile
::
fp16_t
,
float
>
(
ck_tile
::
fp16_t
&
y
,
const
float
&
x
)
const
{
float
y_f
;
this
->
operator
()
<
float
,
float
>
(
y_f
,
x
);
y
=
type_convert
<
ck_tile
::
fp16_t
>
(
y_f
);
}
template
<
>
CK_TILE_HOST
void
operator
()
<
ck_tile
::
bf16_t
,
float
>
(
ck_tile
::
bf16_t
&
y
,
const
float
&
x
)
const
{
float
y_f
;
this
->
operator
()
<
float
,
float
>
(
y_f
,
x
);
y
=
type_convert
<
ck_tile
::
bf16_t
>
(
y_f
);
}
template
<
>
CK_TILE_DEVICE
void
operator
()
<
ck_tile
::
bf16_t
,
float
>
(
ck_tile
::
bf16_t
&
y
,
const
float
&
x
)
const
{
float
y_f
;
this
->
operator
()
<
float
,
float
>
(
y_f
,
x
);
y
=
type_convert
<
ck_tile
::
bf16_t
>
(
y_f
);
}
template
<
>
CK_TILE_DEVICE
void
operator
()
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
>
(
ck_tile
::
bf16_t
&
y
,
const
ck_tile
::
bf16_t
&
x
)
const
{
float
y_f
;
this
->
operator
()
<
float
,
float
>
(
y_f
,
type_convert
<
float
>
(
x
));
y
=
type_convert
<
ck_tile
::
bf16_t
>
(
y_f
);
}
template
<
>
CK_TILE_HOST
void
operator
()
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
>
(
ck_tile
::
bf16_t
&
y
,
const
ck_tile
::
bf16_t
&
x
)
const
{
float
y_f
;
this
->
operator
()
<
float
,
float
>
(
y_f
,
type_convert
<
float
>
(
x
));
y
=
type_convert
<
ck_tile
::
bf16_t
>
(
y_f
);
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct
Gelu
{
template
<
typename
Y
,
typename
X
>
CK_TILE_HOST_DEVICE
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
y
=
0.5
f
*
x
*
(
1.
f
+
erf
(
float
(
0.70710678118
f
*
x
)));
}
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
>
(
ck_tile
::
fp16_t
&
y
,
const
ck_tile
::
fp16_t
&
x
)
const
{
y
=
ck_tile
::
fp16_t
(
0.5
)
*
x
*
(
ck_tile
::
fp16_t
(
1
)
+
ck_tile
::
fp16_t
(
erf
(
float
(
0.70710678118
f
*
x
))));
}
};
struct
Sigmoid
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
one
/
(
one
+
ck_tile
::
exp
(
-
x
));
};
};
struct
Silu
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
x
*
(
one
/
(
one
+
ck_tile
::
exp
(
-
x
)));
};
};
struct
TanH
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
tanh
(
x
);
};
};
struct
ACos
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
acos
(
x
);
};
};
struct
Neg
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
neg
(
x
);
};
};
struct
ATan
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
atan
(
x
);
};
};
struct
Sin
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
sin
(
x
);
};
};
struct
ASinH
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
asinh
(
x
);
};
};
struct
Cos
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
cos
(
x
);
};
};
struct
ACosH
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
acosh
(
x
);
};
};
struct
Tan
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
tan
(
x
);
};
};
struct
ATanH
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
atanh
(
x
);
};
};
struct
SinH
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
sinh
(
x
);
};
};
struct
Ceil
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
ceil
(
x
);
};
};
struct
Exp
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
exp
(
x
);
};
};
struct
CosH
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
cosh
(
x
);
};
};
struct
Floor
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
floor
(
x
);
};
};
struct
Log
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
log
(
x
);
};
};
struct
ASin
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
asin
(
x
);
};
};
struct
Rcp
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
rcp
(
x
);
};
};
struct
Swish
{
Swish
(
float
beta
=
1.0
f
)
:
beta_
(
beta
)
{}
template
<
typename
Y
,
typename
X
>
CK_TILE_HOST_DEVICE
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
X
,
float
>
||
std
::
is_same_v
<
X
,
double
>
||
std
::
is_same_v
<
X
,
ck_tile
::
fp16_t
>
,
"Data type is not supported by this operation!"
);
static_assert
(
std
::
is_same_v
<
Y
,
float
>
||
std
::
is_same_v
<
Y
,
double
>
||
std
::
is_same_v
<
Y
,
ck_tile
::
fp16_t
>
,
"Data type is not supported by this operation!"
);
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
Y
>
(
x
/
(
1.
f
+
ck_tile
::
exp
(
bx
)));
};
const
float
beta_
;
};
struct
SoftRelu
{
SoftRelu
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
){};
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
ck_tile
::
log
(
one
+
ck_tile
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
const
float
alpha_
;
};
struct
Power
{
Power
(
float
alpha
=
0.
f
,
float
beta
=
1.
f
,
float
gamma
=
2.
f
)
:
alpha_
(
alpha
),
beta_
(
beta
),
gamma_
(
gamma
){};
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
T
casted_gamma
=
type_convert
<
T
>
(
gamma_
);
T
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck_tile
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
const
float
alpha_
;
const
float
beta_
;
const
float
gamma_
;
};
struct
ClippedRelu
{
ClippedRelu
(
float
alpha
=
0.
f
,
float
beta
=
1.
f
)
:
alpha_
(
alpha
),
beta_
(
beta
){};
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
y
=
ck_tile
::
min
(
casted_beta
,
ck_tile
::
max
(
casted_alpha
,
x
));
}
const
float
alpha_
;
const
float
beta_
;
};
struct
LeakyRelu
{
LeakyRelu
(
float
alpha
=
0.01
f
)
:
alpha_
(
alpha
){};
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
}
const
float
alpha_
;
};
struct
Elu
{
Elu
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
){};
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck_tile
::
expm1
(
x
);
}
const
float
alpha_
;
};
struct
Logistic
{
Logistic
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
){};
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
casted_alpha
/
(
one
+
ck_tile
::
exp
(
-
x
)
*
casted_alpha
);
}
const
float
alpha_
;
};
struct
ConvInvscale
{
CK_TILE_HOST_DEVICE
ConvInvscale
(
float
scale_in
=
1.
f
,
float
scale_wei
=
1.
f
,
float
scale_out
=
1.
f
)
:
scale_in_
(
scale_in
),
scale_wei_
(
scale_wei
),
scale_out_
(
scale_out
)
{
}
template
<
typename
E
,
typename
C
>
CK_TILE_HOST_DEVICE
void
operator
()(
E
&
e
,
const
C
&
c
)
const
;
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
fp8_t
,
float
>
(
ck_tile
::
fp8_t
&
e
,
const
float
&
c
)
const
{
e
=
type_convert
<
ck_tile
::
fp8_t
>
(
c
/
scale_in_
/
scale_wei_
/
scale_out_
);
};
float
scale_in_
;
float
scale_wei_
;
float
scale_out_
;
};
struct
ConvScale
{
CK_TILE_HOST_DEVICE
ConvScale
(
float
scale_in
=
1.
f
,
float
scale_wei
=
1.
f
,
float
scale_out
=
1.
f
)
:
scale_in_
(
scale_in
),
scale_wei_
(
scale_wei
),
scale_out_
(
scale_out
)
{
}
template
<
typename
E
,
typename
C
>
CK_TILE_HOST_DEVICE
void
operator
()(
E
&
e
,
const
C
&
c
)
const
;
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
fp8_t
,
float
>
(
ck_tile
::
fp8_t
&
e
,
const
float
&
c
)
const
{
e
=
type_convert
<
ck_tile
::
fp8_t
>
(
c
*
scale_in_
*
scale_wei_
*
scale_out_
);
};
float
scale_in_
;
float
scale_wei_
;
float
scale_out_
;
};
struct
ConvScaleRelu
{
CK_TILE_HOST_DEVICE
ConvScaleRelu
(
float
scale_in
=
1.
f
,
float
scale_wei
=
1.
f
,
float
scale_out
=
1.
f
)
:
scale_in_
(
scale_in
),
scale_wei_
(
scale_wei
),
scale_out_
(
scale_out
)
{
}
template
<
typename
E
,
typename
C
>
CK_TILE_HOST_DEVICE
void
operator
()(
E
&
e
,
const
C
&
c
)
const
;
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
ck_tile
::
fp8_t
,
float
>
(
ck_tile
::
fp8_t
&
e
,
const
float
&
c
)
const
{
float
x
;
Relu
{}.
template
operator
()
<
float
>(
x
,
c
*
scale_in_
*
scale_wei_
);
e
=
type_convert
<
ck_tile
::
fp8_t
>
(
x
*
scale_out_
);
};
float
scale_in_
;
float
scale_wei_
;
float
scale_out_
;
};
template
<
typename
DstType
,
typename
SrcType
>
struct
Cast
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
DstType
&
y
,
const
SrcType
&
x
)
const
{
y
=
ck_tile
::
type_convert
<
DstType
>
(
x
);
};
};
// support fastconvert of int8 to fp16
#if 0
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
struct FastNumericArrayConverter
{
};
template <>
struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4>
{
using InputArray = vector_type<uint8_t, 4>;
using OutputArray = vector_type<ck_tile::fp16_t, 4>;
CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
{
OutputArray Output;
uint32_t* half_2 = reinterpret_cast<uint32_t*>(&Output);
uint32_t const uint8_4 = reinterpret_cast<uint32_t const&>(Input);
static constexpr uint32_t byte_selector_01 = 0x05010500;
static constexpr uint32_t byte_selector_23 = 0x05030502;
static constexpr uint32_t fp16_adder = 0x64646464;
half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[0])
: "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[1])
: "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));
return Output;
}
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
template <index_t N>
struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, N>
{
static constexpr int VEC_WIDTH = 4;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
using InputArray = vector_type<uint8_t, N>;
using OutputArray = vector_type<ck_tile::fp16_t, N>;
CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
{
FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4> converter;
OutputArray Output;
using Vec_InputArray = vector_type<uint8_t, 4>;
using Vec_OutputArray = vector_type<ck_tile::fp16_t, 4>;
Vec_OutputArray* half_4_ptr = reinterpret_cast<Vec_OutputArray*>(&Output);
Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
static_for<0, N / VEC_WIDTH, 1>{}(
[&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); });
return Output;
}
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
#endif
}
// namespace element_wise
}
// namespace ck_tile
include/ck_tile/ops/epilogue.hpp
View file @
a4522ae3
...
...
@@ -5,4 +5,6 @@
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
View file @
a4522ae3
...
...
@@ -9,23 +9,29 @@ namespace ck_tile {
// this epilogue just store out a M*N matrix, row major
template
<
typename
AccDataType_
,
typename
ODataType_
,
bool
kPadM_
,
bool
kPadN_
>
template
<
typename
AccDataType_
,
typename
ODataType_
,
bool
kPadM_
,
bool
kPadN_
,
bool
UseRawStore_
=
true
>
struct
Default2DEpilogueProblem
{
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
UseRawStore
=
UseRawStore_
;
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
Default2DEpilogue
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
UseRawStore
=
Problem
::
UseRawStore
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
...
...
@@ -36,7 +42,7 @@ struct Default2DEpilogue
{
// TODO: this is ugly
if
constexpr
(
kPadM
||
kPadN
)
if
constexpr
(
UseRawStore
&&
(
kPadM
||
kPadN
)
)
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
buffer_store_fence
();
...
...
include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp
0 → 100644
View file @
a4522ae3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
namespace
ck_tile
{
template
<
bool
kPadM_
,
bool
kPadN_
,
bool
UseSmoothInputScale_
,
bool
UseRawStore_
=
true
,
bool
UseMax3_
=
false
>
struct
DynamicQuantEpilogueTraits
{
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
UseSmoothInputScale
=
UseSmoothInputScale_
;
static
constexpr
bool
UseRawStore
=
UseRawStore_
;
static
constexpr
bool
UseMax3
=
UseMax3_
;
};
// this epilogue just store out a M*N matrix, row major
template
<
typename
AccDataType_
,
typename
XScaleDataType_
,
typename
YScaleDataType_
,
typename
ODataType_
,
typename
BlockShape_
,
typename
Traits_
>
struct
DynamicQuantEpilogueProblem
{
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
XScaleDataType
=
remove_cvref_t
<
XScaleDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
// can consum generic 2d shape
using
Traits
=
remove_cvref_t
<
Traits_
>
;
};
// TODO: we should put descriptor creation function into policy
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
DynamicQuantEpilogue
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
XScaleDataType
=
remove_cvref_t
<
typename
Problem
::
XScaleDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
BlockShape
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
static
constexpr
bool
kPadM
=
Problem
::
Traits
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
UseRawStore
=
Problem
::
Traits
::
UseRawStore
;
static
constexpr
bool
UseMax3
=
Problem
::
Traits
::
UseMax3
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2d
()
{
using
P_
=
BlockReduce2dProblem
<
AccDataType
,
AccDataType
,
BlockShape
>
;
return
BlockReduce2d
<
P_
>
{};
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dSync
()
{
using
P_
=
BlockReduce2dProblem
<
AccDataType
,
AccDataType
,
BlockShape
>
;
return
BlockReduce2dSync
<
P_
>
{};
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dCrossWarpSync
()
{
using
P_
=
BlockReduce2dProblem
<
AccDataType
,
AccDataType
,
BlockShape
>
;
return
BlockReduce2dCrossWarpSync
<
P_
>
{};
}
CK_TILE_DEVICE
static
constexpr
auto
MakeSmoothInputScaleTileDistribution
()
{
using
S
=
BlockShape
;
#if 0
// don't remove this
// Note that if we set encoding purposely like this, you will result in compile fail
// TODO: x_scale create local-scratch to accept arbitrary acc input (with same length)
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<0, 1, 1>,
sequence<0, 0, 3>>{});
#else
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
>
,
tuple
<
sequence
<
S
::
Repeat_N
,
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
,
S
::
Vector_N
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
,
1
>
,
sequence
<
0
,
3
>>
{});
#endif
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
auto
reduce_crosswarp_sync
=
GetBlockReduce2dCrossWarpSync
();
return
reduce_crosswarp_sync
.
GetSmemSize
();
}
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
template
<
typename
ODramWindowTmp
,
typename
XScaleWindow
,
typename
YScaleWindow
,
typename
OAccTile
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
const
XScaleWindow
&
x_scale_window_
,
YScaleWindow
&
y_scale_window
,
const
OAccTile
&
o_acc_tile
,
void
*
smem
)
{
auto
reduce
=
GetBlockReduce2d
();
auto
reduce_sync
=
GetBlockReduce2dSync
();
auto
reduce_crosswarp_sync
=
GetBlockReduce2dCrossWarpSync
();
const
auto
x_scale_window
=
make_tile_window
(
x_scale_window_
,
MakeSmoothInputScaleTileDistribution
());
auto
x_scale
=
load_tile
(
x_scale_window
);
auto
o_acc_tmp
=
o_acc_tile
;
sweep_tile
(
o_acc_tmp
,
[
&
](
auto
idx
)
{
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
xs_
=
type_convert
<
AccDataType
>
(
x_scale
[
j_idx
]);
o_acc_tmp
(
idx
)
=
o_acc_tmp
(
idx
)
*
xs_
;
});
const
auto
f_absmax
=
[](
auto
acc_
,
auto
v_0_
)
{
return
max
(
acc_
,
abs
(
v_0_
));
};
auto
row_absmax
=
[
&
]()
{
constexpr
auto
y_size_per_row
=
OAccTile
{}.
get_tile_distribution
().
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
1
>
{});
if
constexpr
(
UseMax3
&&
std
::
is_same_v
<
AccDataType
,
float
>
&&
y_size_per_row
%
2
==
0
)
{
// fast max3+abs implementation
const
auto
f_max3
=
[](
auto
acc_
,
auto
v_0_
,
auto
v_1_
)
{
float
rtn
;
asm
volatile
(
"v_max3_f32 %0, %1, abs(%2), abs(%3)"
:
"=v"
(
rtn
)
:
"v"
(
acc_
),
"v"
(
v_0_
),
"v"
(
v_1_
));
return
rtn
;
};
return
reduce
(
o_acc_tmp
,
type_convert
<
AccDataType
>
(
0
),
f_max3
,
sequence
<
1
,
2
>
{});
}
else
{
return
reduce
(
o_acc_tmp
,
type_convert
<
AccDataType
>
(
0
),
f_absmax
);
}
}();
reduce_sync
(
row_absmax
,
f_absmax
);
reduce_crosswarp_sync
(
row_absmax
,
smem
,
f_absmax
);
// here y_scale is Acc TYpe, need convert to YScale type later
auto
y_scale
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
return
v_
/
type_convert
<
AccDataType
>
(
numeric
<
ODataType
>::
max
());
},
row_absmax
);
store_tile
(
y_scale_window
,
cast_tile
<
YScaleDataType
>
(
y_scale
));
sweep_tile
(
o_acc_tmp
,
[
&
](
auto
idx
)
{
constexpr
auto
row_id
=
make_tuple
(
idx
[
number
<
0
>
{}]);
o_acc_tmp
(
idx
)
=
o_acc_tmp
[
idx
]
/
y_scale
(
row_id
);
});
// TODO: this is ugly
if
constexpr
(
UseRawStore
&&
(
kPadM
||
kPadN
))
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tmp
));
buffer_store_fence
();
}
else
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tmp
));
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha.hpp
View file @
a4522ae3
...
...
@@ -43,4 +43,5 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
a4522ae3
...
...
@@ -69,7 +69,8 @@ struct FmhaFwdKernel
// sync with generate.py
// clang-format off
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
gbr
=
typename
bfs
::
Gemm0BlockWarps
;
using
g0br
=
typename
bfs
::
Gemm0BlockWarps
;
using
g1br
=
typename
bfs
::
Gemm1BlockWarps
;
using
gwt
=
typename
bfs
::
Gemm0WarpTile
;
#define _SS_ std::string
#define _TS_ std::to_string
...
...
@@ -81,11 +82,12 @@ struct FmhaFwdKernel
if
(
kPadHeadDimV
)
n
+=
"dv"
;
return
n
.
empty
()
?
n
:
std
::
string
(
"p"
)
+
n
;
}();
return
_SS_
(
"fmha_fwd_d"
)
+
_TS_
(
bfs
::
k
K0BlockLength
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
_SS_
(
"fmha_fwd_d"
)
+
_TS_
(
bfs
::
k
QKHeaddim
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
+
_SS_
(
TilePartitioner
::
name
)
+
"_"
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
_TS_
(
bfs
::
kN1
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kK0BlockLength
)
+
"_"
+
"r"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
_TS_
(
bfs
::
kN1
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
"r"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
...
...
@@ -655,7 +657,7 @@ struct FmhaFwdKernel
{
return
pad_tensor_view
(
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
k
K0BlockLength
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
k
SubQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}
else
...
...
@@ -722,7 +724,7 @@ struct FmhaFwdKernel
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
k
K0BlockLength
>
{});
number
<
FmhaPipeline
::
k
SubQKHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{});
}(),
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
a4522ae3
...
...
@@ -65,7 +65,8 @@ struct FmhaFwdSplitKVKernel
// sync with generate.py
// clang-format off
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
gbr
=
typename
bfs
::
Gemm0BlockWarps
;
using
g0br
=
typename
bfs
::
Gemm0BlockWarps
;
using
g1br
=
typename
bfs
::
Gemm1BlockWarps
;
using
gwt
=
typename
bfs
::
Gemm0WarpTile
;
#define _SS_ std::string
#define _TS_ std::to_string
...
...
@@ -77,11 +78,12 @@ struct FmhaFwdSplitKVKernel
if
(
kPadHeadDimV
)
n
+=
"dv"
;
return
n
.
empty
()
?
n
:
std
::
string
(
"p"
)
+
n
;
}();
return
_SS_
(
"fmha_fwd_splitkv_d"
)
+
_TS_
(
bfs
::
k
K0BlockLength
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
_SS_
(
"fmha_fwd_splitkv_d"
)
+
_TS_
(
bfs
::
k
QKHeaddim
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
_TS_
(
bfs
::
kN1
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kK0BlockLength
)
+
"_"
+
"r"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
_TS_
(
bfs
::
kN1
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
"r"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
...
...
@@ -584,7 +586,7 @@ struct FmhaFwdSplitKVKernel
{
return
pad_tensor_view
(
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
k
K0BlockLength
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
k
SubQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}
else
...
...
@@ -733,7 +735,7 @@ struct FmhaFwdSplitKVKernel
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
k
K0BlockLength
>
{});
number
<
FmhaPipeline
::
k
SubQKHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{});
}(),
...
...
@@ -894,7 +896,7 @@ struct FmhaFwdSplitKVKernel
o_acc_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
stride_o_acc
,
1
),
number
<
1
>
{},
number
<
FmhaPipeline
::
kAlignmentOacc
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
View file @
a4522ae3
...
...
@@ -26,8 +26,8 @@ struct FmhaFwdSplitKVTilePartitioner
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_seqlen_q
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
),
nhead
*
num_splits
,
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
)
*
num_splits
,
nhead
,
batch_size
);
}
...
...
@@ -42,8 +42,9 @@ struct FmhaFwdSplitKVTilePartitioner
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
blockIdx
.
x
,
num_tile_n1
);
const
auto
[
i_nhead
,
i_split
]
=
f
(
blockIdx
.
y
,
num_splits
);
const
auto
[
mn
,
i_split
]
=
f
(
blockIdx
.
x
,
num_splits
);
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
mn
,
num_tile_n1
);
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_split
,
i_nhead
,
i_batch
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
View file @
a4522ae3
...
...
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
auto
k_lds_write_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
k_lds_read_window
=
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
k_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeKReg
Slice
BlockDescriptor
<
Problem
>());
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
...
...
@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
auto
v_lds_write_window
=
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
v_lds_read_window
=
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
v_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeVRegSliceBlockDescriptor
<
Problem
>());
auto
v_reg_tensor
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
...
...
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_k_lds_write_window
=
make_tile_window
(
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
...
...
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
block_sync_lds
();
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
auto
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
block_sync_lds
();
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
...
...
@@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
q_lds_read_window
=
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
...
...
@@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_q_lds_write_window
=
make_tile_window
(
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
...
...
@@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
do_lds_read_window
=
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
...
...
@@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_do_lds_write_window
=
make_tile_window
(
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
...
...
@@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
index_t
i_total_loops
=
0
;
index_t
seqlen_q_step
=
seqlen_q_start
;
static_assert
(
kQKHeaddim
=
=
kK0
,
"kQKHeaddim should equal
t
o kK0"
);
static_assert
(
kQKHeaddim
>
=
kK0
,
"kQKHeaddim should
be
equal o
r greater than
kK0"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kVHeaddim
=
=
kK2
,
"kVHeaddim should equal
t
o kK2"
);
static_assert
(
kVHeaddim
>
=
kK2
,
"kVHeaddim should
be
equal o
r greater than
kK2"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
View file @
a4522ae3
...
...
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
auto
k_lds_write_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
k_lds_read_window
=
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
k_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeKReg
Slice
BlockDescriptor
<
Problem
>());
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
...
...
@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
auto
v_lds_write_window
=
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
v_lds_read_window
=
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
v_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeVRegSliceBlockDescriptor
<
Problem
>());
auto
v_reg_tensor
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
...
...
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_k_lds_write_window
=
make_tile_window
(
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
...
...
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
block_sync_lds
();
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
auto
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
auto
q_dram_window
=
...
...
@@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
q_lds_read_window
=
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
...
...
@@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_q_lds_write_window
=
make_tile_window
(
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
...
...
@@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
do_lds_read_window
=
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
...
...
@@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_do_lds_write_window
=
make_tile_window
(
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
...
...
@@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
index_t
i_total_loops
=
0
;
index_t
seqlen_q_step
=
seqlen_q_start
;
static_assert
(
kQKHeaddim
=
=
kK0
,
"kQKHeaddim should equal
t
o kK0"
);
static_assert
(
kQKHeaddim
>
=
kK0
,
"kQKHeaddim should
be
equal o
r greater than
kK0"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kVHeaddim
=
=
kK2
,
"kVHeaddim should equal
t
o kK2"
);
static_assert
(
kVHeaddim
>
=
kK2
,
"kVHeaddim should
be
equal o
r greater than
kK2"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
a4522ae3
...
...
@@ -5,9 +5,8 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/
pipeline/gemm_pipeline
_problem.hpp"
#include "ck_tile/ops/gemm/
block/block_gemm
_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
...
...
@@ -27,20 +26,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
QDataType
,
...
...
@@ -66,20 +60,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadHeadDimV
,
Problem
::
kPadHeadDimV
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
...
@@ -104,20 +93,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
OGradDataType
,
...
...
@@ -143,20 +127,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
Problem
::
kPadSeqLenK
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
...
@@ -181,20 +160,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadHeadDimQ
,
Problem
::
kPadSeqLenK
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
...
@@ -222,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
QDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
QDataType
);
...
...
@@ -241,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
KDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
KDataType
);
...
...
@@ -260,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
VDataType
);
constexpr
index_t
total_pixels
=
kMNPerBlock
*
kKPerBlock
/
kBlockSize
;
...
...
@@ -280,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
OGradDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
OGradDataType
);
...
...
@@ -341,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
...
...
@@ -353,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
return
total_pixels
/
GetAlignmentK
<
Problem
>
();
...
...
@@ -364,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
...
...
@@ -402,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -425,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
K1
=
GetAlignmentV
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -448,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -471,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -842,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsWriteBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
k_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
k_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
k_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
k_block_dstr
=
make_static_tile_distribution
(
k_block_dstr_encode
);
return
k_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKRegBlockDescriptor
()
{
...
...
@@ -891,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
...
...
@@ -916,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVLdsWriteBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kVPack
=
GetSmemKPackV
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kVPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetOGradVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
v_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
v_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
v_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
v_block_dstr
=
make_static_tile_distribution
(
v_block_dstr_encode
);
return
v_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVRegBlockDescriptor
()
{
...
...
@@ -966,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
...
...
@@ -992,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -1074,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQLdsBlockDescriptor
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
...
...
@@ -1118,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -1281,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
// Hold full block data
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
...
...
@@ -1325,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -1885,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static
constexpr
index_t
kN0
=
Problem
::
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kQKHeaddim
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kVHeaddim
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
static
constexpr
index_t
kK0
=
Problem
::
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kK2
=
Problem
::
BlockFmhaShape
::
kK2
;
static
constexpr
index_t
kK4
=
Problem
::
BlockFmhaShape
::
kK4
;
static
constexpr
index_t
WarpGemmM
=
...
...
@@ -1899,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
// Compute
static
constexpr
index_t
Gemm0MFMA
=
kM0
*
kN0
*
kQKHeaddim
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
kM0
*
kN0
*
kK0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm1MFMA
=
kM0
*
kN0
*
kVHeaddim
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm2MFMA
=
kN0
*
kVHeaddim
*
kM0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm2MFMA
=
kM0
*
kN0
*
kK2
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm3MFMA
=
kN0
*
kQKHeaddim
*
kM0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
...
...
@@ -1929,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
kM0
*
kQKHeaddim
/
get_warp_size
()
/
GetTransposedAlignmentQ
<
Problem
>
();
static
constexpr
index_t
SGradT_LDS_READ_P1
=
kM0
*
kK4
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
GetSmemKPackSGrad
<
Problem
>
();
static
constexpr
index_t
Q_LDS_READ
=
kM0
*
kQKHeaddim
/
kBlockSize
/
GetAlignmentQ
<
Problem
>
();
static
constexpr
index_t
Q_LDS_READ
=
kM0
*
kK0
/
kBlockSize
/
GetAlignmentQ
<
Problem
>
();
static
constexpr
index_t
LSE_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
static
constexpr
index_t
SGradT_LDS_READ_P2
=
kM0
*
(
kN0
-
kK4
)
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
GetSmemKPackSGrad
<
Problem
>
();
static
constexpr
index_t
OGrad_LDS_READ
=
kM0
*
k
VHeaddim
/
kBlockSize
/
GetAlignmentOGrad
<
Problem
>
();
kM0
*
k
K2
/
kBlockSize
/
GetAlignmentOGrad
<
Problem
>
();
static
constexpr
index_t
D_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
// LDS Write
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
View file @
a4522ae3
...
...
@@ -12,6 +12,16 @@ namespace detail {
template
<
index_t
N
>
struct
log2
;
template
<
>
struct
log2
<
4
>
:
std
::
integral_constant
<
index_t
,
2
>
{
};
template
<
>
struct
log2
<
8
>
:
std
::
integral_constant
<
index_t
,
3
>
{
};
template
<
>
struct
log2
<
16
>
:
std
::
integral_constant
<
index_t
,
4
>
{
...
...
@@ -72,18 +82,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline
{
if
constexpr
(
kHeadDimV
<=
32
)
{
constexpr
std
::
array
<
int
,
4
>
occupancy
{
3
,
3
,
3
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
4
];
constexpr
std
::
array
occupancy
{
3
,
3
,
3
,
3
,
3
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
2
];
}
else
if
constexpr
(
kHeadDimV
<=
128
)
{
constexpr
std
::
array
<
int
,
4
>
occupancy
{
3
,
3
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
4
];
constexpr
std
::
array
occupancy
{
3
,
3
,
3
,
3
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
2
];
}
else
if
constexpr
(
kHeadDimV
<=
256
)
{
constexpr
std
::
array
<
int
,
4
>
occupancy
{
2
,
2
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
4
];
constexpr
std
::
array
occupancy
{
2
,
2
,
2
,
2
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
2
];
}
}
}();
...
...
@@ -138,9 +148,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto
lse_accum
=
make_static_distributed_tensor
<
LSEDataType
>
(
Policy
::
template
MakeLSEaccRegTileDistribution
<
Problem
>());
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, max(kMaxSplits, warp_size)])
// this will extend the distributed tensor width so that each thread in wave have data to
// reduce.
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
// and fill up -INF values outside the [kM0, num_splits] region.
{
constexpr
auto
spans
=
decltype
(
lse_accum
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
View file @
a4522ae3
...
...
@@ -10,11 +10,26 @@ namespace ck_tile {
struct
BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
template
<
index_t
BlockSize
,
index_t
M
,
index_t
N
,
typename
DataType
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeForTile
()
{
constexpr
index_t
PixelsPerThread
=
(
M
*
N
)
/
BlockSize
;
static_assert
(
0
<
PixelsPerThread
);
constexpr
index_t
MaxNPerThread
=
16
/
sizeof
(
DataType
);
constexpr
index_t
NPerThread
=
min
(
MaxNPerThread
,
PixelsPerThread
);
return
NPerThread
;
}
// alignment for dram lse tile (shape=[kMaxSplits, kM0])
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentLSE
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
return
16
/
sizeof
(
LSEDataType
);
return
GetVectorSizeForTile
<
Problem
::
kBlockSize
,
Problem
::
kMaxSplits
,
Problem
::
kM0
,
typename
Problem
::
LSEDataType
>
();
}
template
<
typename
Problem
>
...
...
@@ -47,29 +62,31 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
MakeLSEaccLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}
// shape=[kMaxSplits, kM0]
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccDramTileDistribution
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNumWarps
=
Problem
::
kNumWarps
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
NPerThread
=
16
/
sizeof
(
LSEDataType
);
constexpr
index_t
NThreads
=
kNPerBlock
/
NPerThread
;
constexpr
index_t
NPerThread
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
index_t
NThreads
=
kNPerBlock
/
NPerThread
;
constexpr
index_t
MThreadsPerWarp
=
get_warp_size
()
/
NThreads
;
constexpr
index_t
TotalWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
TotalWarps
*
MThreadsPerWarp
);
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
kNumWarps
*
MThreadsPerWarp
);
static_assert
(
NThreads
*
NPerThread
==
kNPerBlock
);
static_assert
(
MPerThread
*
Total
Warps
*
MThreadsPerWarp
==
kMPerBlock
);
static_assert
(
MPerThread
*
kNum
Warps
*
MThreadsPerWarp
==
kMPerBlock
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
MPerThread
,
Total
Warps
,
MThreadsPerWarp
>
,
tuple
<
sequence
<
MPerThread
,
kNum
Warps
,
MThreadsPerWarp
>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
...
...
@@ -77,15 +94,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
sequence
<
0
,
1
>>
{});
}
// 3d + padding, [kMaxSplits, kM0]
// 3d + padding,
shape=
[kMaxSplits, kM0]
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccLdsStoreBlockDescriptor
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
NPack
=
16
/
sizeof
(
LSEDataType
);
constexpr
index_t
NPack
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
...
...
@@ -103,15 +123,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
return
lse_acc_lds_block_desc
;
}
// 3d + padding, [kM0, kMaxSplits]
// 3d + padding,
shape=
[kM0, kMaxSplits]
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccLdsBlockDescriptor
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
NPack
=
16
/
sizeof
(
LSEDataType
);
constexpr
index_t
NPack
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
...
...
@@ -134,26 +157,28 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
max
(
Problem
::
kMaxSplits
,
get_warp_size
())
;
constexpr
index_t
kNPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
NThreads
=
get_warp_size
()
;
constexpr
index_t
NThreads
=
4
;
constexpr
index_t
NPerThread
=
kNPerBlock
/
NThreads
;
constexpr
index_t
MThreads
=
kBlockSize
/
NThreads
;
constexpr
index_t
MPerThread
=
kMPerBlock
/
MThreads
;
constexpr
index_t
MThreads
=
kBlockSize
/
NThreads
;
constexpr
index_t
MPerThread
=
kMPerBlock
/
MThreads
;
constexpr
index_t
MWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MThreadPerWarp
=
get_warp_size
()
/
NThreads
;
static_assert
(
NThreads
*
NPerThread
==
kNPerBlock
);
static_assert
(
M
Threads
*
MPerThread
==
kMPerBlock
);
static_assert
(
M
Warps
*
MThreadPerWarp
*
MPerThread
==
kMPerBlock
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M
Threads
,
MPerThread
>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
0
>>
,
tuple
<
sequence
<
M
Warps
,
MThreadPerWarp
,
MPerThread
>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
sequence
<
2
,
1
>>
{});
}
template
<
typename
Problem
>
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment