gaoqiong / composable_kernel_ROCM · Commit e6bb1dd7 (unverified)

Merge branch 'develop' into feature/check-window-lengths

Authored Jul 19, 2024 by Po Yen Chen; committed by GitHub, Jul 19, 2024.
Parents: 9d6a3704, ab250afd
Changes: 317 files total; showing 20 changed files with 5073 additions and 45 deletions (+5073 −45).
include/ck_tile/host/reference/reference_layernorm2d.hpp (+69, −0)
include/ck_tile/host/stream_config.hpp (+17, −0)
include/ck_tile/host/timer.hpp (+79, −0)
include/ck_tile/ops/fmha.hpp (+26, −0)
include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp (+37, −0)
include/ck_tile/ops/fmha/block/block_dropout.hpp (+364, −0)
include/ck_tile/ops/fmha/block/block_masking.hpp (+84, −7)
include/ck_tile/ops/fmha/block/block_position_encoding.hpp (+189, −0)
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp (+1421, −0)
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp (+54, −0)
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp (+224, −33)
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp (+455, −0)
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp (+49, −0)
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp (+913, −0)
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp (+53, −0)
include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp (+56, −5)
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp (+95, −0)
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp (+20, −0)
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp (+848, −0)
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp (+20, −0)
include/ck_tile/host/reference/reference_layernorm2d.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename ComputeDataType,
          typename YDataType,
          typename MeanDataType,
          typename InvStdDataType>
void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
                               const HostTensor<GammaDataType>& gamma_n,
                               const HostTensor<BetaDataType>& beta_n,
                               HostTensor<YDataType>& y_m_n,
                               HostTensor<MeanDataType>& mean_m,
                               HostTensor<InvStdDataType>& invStd_m,
                               ComputeDataType epsilon)
{
    auto layernorm2d_fwd_func = [&](auto m) {
        const int N = x_m_n.mDesc.get_lengths()[1];

        int count                = 0;
        ComputeDataType mean     = 0;
        ComputeDataType variance = 0;
        ComputeDataType divisor  = 0;

        for(int n = 0; n < N; ++n)
        {
            ++count;
            ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
            ComputeDataType delta = x - mean;
            mean += delta / count;
            ComputeDataType delta2 = x - mean;
            variance += delta * delta2;
        }

        // actual variance
        variance = variance / count;
        divisor  = ck_tile::type_convert<ComputeDataType>(1) / ck_tile::sqrt(variance + epsilon);

        if constexpr(!std::is_same_v<MeanDataType, ck_tile::null_type>)
            mean_m(m) = ck_tile::type_convert<MeanDataType>(mean);

        if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>)
            invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor);

        for(int n = 0; n < N; ++n)
        {
            ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
            ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
            ComputeDataType beta  = ck_tile::type_convert<ComputeDataType>(beta_n(n));

            auto y = (x - mean) * divisor;
            y      = y * gamma + beta;

            y_m_n(m, n) = ck_tile::type_convert<YDataType>(y);
        }
    };

    make_ParallelTensorFunctor(layernorm2d_fwd_func, mean_m.mDesc.get_lengths()[0])(
        std::thread::hardware_concurrency());
}
} // namespace ck_tile
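Note (reading aid, not part of the diff): the single-pass mean/variance loop above is Welford's online algorithm. Writing $\mu_n$ for the running mean after $n$ elements and $M_n$ for the running sum of squared deviations, each step computes

$$\delta = x_n - \mu_{n-1}, \qquad \mu_n = \mu_{n-1} + \frac{\delta}{n}, \qquad M_n = M_{n-1} + \delta\,(x_n - \mu_n),$$

so after the loop $\mathrm{variance} = M_N / N$ and $\mathrm{divisor} = 1/\sqrt{\mathrm{variance} + \epsilon}$, matching the code.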
include/ck_tile/host/stream_config.hpp
...
@@ -6,6 +6,22 @@
 #include <hip/hip_runtime.h>

 namespace ck_tile {

+/*
+ * construct this structure with behavior as:
+ *
+ * // create stream config with default stream(NULL), and not timing the kernel
+ * stream_config s = stream_config{};
+ *
+ * // create stream config with _some_stream_id_, and not timing the kernel
+ * stream_config s = stream_config{_some_stream_id_};
+ *
+ * // create stream config with _some_stream_id_, and benchmark with warmup/repeat as default
+ * stream_config s = stream_config{_some_stream_id_, true};
+ *
+ * // create stream config with _some_stream_id_, and benchmark using cpu timer
+ * stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, false};
+ **/
 struct stream_config
 {
     hipStream_t stream_id_ = nullptr;
...
@@ -13,5 +29,6 @@ struct stream_config
     int log_level_   = 0;
     int cold_niters_ = 3;
     int nrepeat_     = 10;
+    bool is_gpu_timer_ = true; // keep compatible
 };
 } // namespace ck_tile
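Note (reading aid, not part of the diff): a minimal sketch spelling out the last usage line in the comment above, field by field. The flag between stream_id_ and log_level_ sits in the collapsed context of this hunk, so its name here (time_kernel_) is an assumption:

#include "ck_tile/host/stream_config.hpp"

// Hypothetical helper: benchmark with the CPU timer (3 warmup iters, 10 timed repeats).
// Field names in the comments follow the struct above; time_kernel_ is assumed.
inline ck_tile::stream_config make_cpu_timed_config(hipStream_t s)
{
    return ck_tile::stream_config{s,      // stream_id_
                                  true,   // time_kernel_ (assumed name; enables benchmarking)
                                  0,      // log_level_
                                  3,      // cold_niters_
                                  10,     // nrepeat_
                                  false}; // is_gpu_timer_ = false -> use the cpu timer
}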
include/ck_tile/host/timer.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <hip/hip_runtime.h>
#include <cstddef>
#include <chrono>

namespace ck_tile {

struct gpu_timer
{
    CK_TILE_HOST gpu_timer()
    {
        HIP_CHECK_ERROR(hipEventCreate(&start_evt));
        HIP_CHECK_ERROR(hipEventCreate(&stop_evt));
    }

    CK_TILE_HOST ~gpu_timer() noexcept(false)
    {
        HIP_CHECK_ERROR(hipEventDestroy(start_evt));
        HIP_CHECK_ERROR(hipEventDestroy(stop_evt));
    }

    CK_TILE_HOST void start(const hipStream_t& s)
    {
        HIP_CHECK_ERROR(hipStreamSynchronize(s));
        HIP_CHECK_ERROR(hipEventRecord(start_evt, s));
    }

    CK_TILE_HOST void stop(const hipStream_t& s)
    {
        HIP_CHECK_ERROR(hipEventRecord(stop_evt, s));
        HIP_CHECK_ERROR(hipEventSynchronize(stop_evt));
    }

    // return in ms
    CK_TILE_HOST float duration() const
    {
        float ms = 0;
        HIP_CHECK_ERROR(hipEventElapsedTime(&ms, start_evt, stop_evt));
        return ms;
    }

    private:
    hipEvent_t start_evt, stop_evt;
};

struct cpu_timer
{
    // torch.utils.benchmark.Timer(), there is a sync inside each timer callback
    CK_TILE_HOST void start(const hipStream_t& s)
    {
        HIP_CHECK_ERROR(hipStreamSynchronize(s));
        start_tick = std::chrono::high_resolution_clock::now();
    }

    // torch.utils.benchmark.Timer(), there is a sync inside each timer callback
    CK_TILE_HOST void stop(const hipStream_t& s)
    {
        HIP_CHECK_ERROR(hipStreamSynchronize(s));
        stop_tick = std::chrono::high_resolution_clock::now();
    }

    // return in ms
    CK_TILE_HOST float duration() const
    {
        double sec =
            std::chrono::duration_cast<std::chrono::duration<double>>(stop_tick - start_tick)
                .count();
        return static_cast<float>(sec * 1e3);
    }

    private:
    std::chrono::time_point<std::chrono::high_resolution_clock> start_tick;
    std::chrono::time_point<std::chrono::high_resolution_clock> stop_tick;
};

} // namespace ck_tile
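Note (reading aid, not part of the diff): a minimal sketch of how gpu_timer above would be used to time stream work; launch_my_kernel is a hypothetical stand-in for any HIP launch on the same stream.

#include "ck_tile/host/timer.hpp"

void launch_my_kernel(hipStream_t s); // hypothetical: enqueues work on s

// Returns elapsed milliseconds between the two recorded events.
inline float time_gpu_work(hipStream_t stream)
{
    ck_tile::gpu_timer timer; // creates the start/stop hipEvent_t pair
    timer.start(stream);      // syncs the stream, then records the start event
    launch_my_kernel(stream);
    timer.stop(stream);       // records the stop event and waits on it
    return timer.duration();  // hipEventElapsedTime, in ms
}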
include/ck_tile/ops/fmha.hpp
...
@@ -3,9 +3,35 @@
#pragma once
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
...
include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>

namespace ck_tile {

// This class is used for codegen pattern matching
enum class BlockAttentionBiasEnum
{
    NO_BIAS          = 0,
    ELEMENTWISE_BIAS = 1, // attention bias, each element is added to the result of Q*K (after scale)
    ALIBI            = 2, // bias computed with position encoding, applied after scale
};

template <BlockAttentionBiasEnum>
struct BlockAttentionBiasEnumToStr;

template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::NO_BIAS>
{
    static constexpr const char* name = "";
};

template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ELEMENTWISE_BIAS>
{
    static constexpr const char* name = "bias";
};

template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ALIBI>
{
    static constexpr const char* name = "alibi";
};

} // namespace ck_tile
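Note (reading aid, not part of the diff): the trait is consumed at compile time when generated kernel names are built (see GetName() in fmha_bwd_kernel.hpp below); a minimal check using only this header:

#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string_view>

// Each enum value maps to the token spliced into generated kernel names.
static_assert(std::string_view{ck_tile::BlockAttentionBiasEnumToStr<
                  ck_tile::BlockAttentionBiasEnum::ALIBI>::name} == "alibi");
static_assert(std::string_view{ck_tile::BlockAttentionBiasEnumToStr<
                  ck_tile::BlockAttentionBiasEnum::NO_BIAS>::name}.empty());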
include/ck_tile/ops/fmha/block/block_dropout.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"

namespace ck_tile {

struct NullBlockDropout
{
    template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
    __host__ __device__ static constexpr auto
    MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
                          index_t seqlen_qk_start)
    {
        (void)randval_dram_block_window_tmp;
        (void)seqlen_qk_start;
        return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
    }
};

struct BlockDropout
{
    CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch,
                                     index_t i_head,
                                     index_t nheads,
                                     unsigned long long seed,
                                     unsigned long long offset,
                                     float rp_undrop_,
                                     uint8_t p_undrop_in_uint8_t_,
                                     bool is_store_randval_)
        : ph(seed, offset + (i_batch * nheads + i_head) * get_warp_size() + get_lane_id()),
          rp_undrop(rp_undrop_),
          p_undrop_in_uint8_t(p_undrop_in_uint8_t_),
          is_store_randval(is_store_randval_)
    {
    }

    template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE static constexpr auto
    MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
                          index_t seqlen_qk_start)
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        using WG                = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t kMPerStep = MWarp * WG::kM;
        constexpr index_t kNPerStep = NWarp * WG::kN;

        const auto block_origin  = randval_dram_block_window_tmp.get_window_origin();
        auto randval_dram_window = [&]() {
            if constexpr(IsFwd)
            {
                return make_tile_window(
                    randval_dram_block_window_tmp.get_bottom_tensor_view(),
                    ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
                    {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
            }
            else
            {
                return make_tile_window(
                    randval_dram_block_window_tmp.get_bottom_tensor_view(),
                    ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
                    {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
            }
        }();

        return randval_dram_window;
    }

    template <typename BlockGemm>
    CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor()
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        using WG                = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp = config.template at<1>();

        constexpr index_t kMPerStep = MWarp * WG::kM;
        constexpr index_t kNPerStep = WG::kN;
        constexpr index_t kN1       = 8;
        constexpr index_t kN0       = kNPerStep / kN1;

        constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
            ck_tile::make_tuple(number<kN0>{}, number<kMPerStep>{}, number<kN1>{}),
            ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
            number<kN1>{},
            number<1>{});

        constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
            randval_lds_block_desc_0,
            ck_tile::make_tuple(
                make_pass_through_transform(number<kMPerStep>{}),
                make_merge_transform(ck_tile::make_tuple(number<kN0>{}, number<kN1>{}))),
            ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}),
            ck_tile::make_tuple(sequence<0>{}, sequence<1>{}));

        return randval_lds_block_desc;
    }

    template <typename BlockGemm>
    CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t MIterPerWarp = 1;
        constexpr index_t NIterPerWarp = 1;

        constexpr auto randval_block_outer_part_dstr_encoding =
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<MIterPerWarp, MWarp>,
                                             sequence<NIterPerWarp, NWarp>>,
                                       tuple<sequence<1, 2>>,
                                       tuple<sequence<1, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        // Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
        constexpr auto randval_block_inner_part_dstr_encoding = []() {
            if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
                         std::is_same_v<typename BlockGemm::BDataType, half_t> &&
                         std::is_same_v<typename BlockGemm::CDataType, float>)
            {
                return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
            }
            else
            {
                return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
            }
        }();

        constexpr auto randval_block_part_dstr_encode =
            detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
                                                          randval_block_inner_part_dstr_encoding);

        return make_static_tile_distribution(randval_block_part_dstr_encode);
    }

    template <typename BlockGemm>
    CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution()
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        using WG                = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t MIterPerWarp = 1;
        constexpr index_t NIterPerWarp = 1;

        constexpr auto randval_block_outer_part_dstr_encoding =
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<MIterPerWarp, MWarp>,
                                             sequence<NIterPerWarp, NWarp>>,
                                       tuple<sequence<1, 2>>,
                                       tuple<sequence<1, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto randval_block_part_dstr_encode =
            detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
                                                          typename WG::CWarpDstrEncoding{});

        return make_static_tile_distribution(randval_block_part_dstr_encode);
    }

    template <typename BlockGemm,
              typename PComputeDataType,
              typename RandValOutputDataType,
              typename PComputeWindow,
              typename RandValDramWindow>
    CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
                                 const index_t start_n0_idx,
                                 PComputeWindow& p_compute,
                                 RandValDramWindow& randval_dram_window) const
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        using WG                = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;

        constexpr index_t kMPerBlock = BlockGemmShape::kM;
        constexpr index_t kNPerBlock = BlockGemmShape::kN;
        constexpr index_t kMPerStep  = MWarp * WG::kM;
        constexpr index_t kNPerStep  = NWarp * WG::kN;

        // randval tile in LDS
        auto randval_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());

        auto randval_lds_window = make_tile_window(
            randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});

        // register distribute
        auto randval_dist_generated =
            make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
        static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);

        auto randval_lds_read_window =
            make_tile_window(randval_lds_window.get_bottom_tensor_view(),
                             randval_lds_window.get_window_lengths(),
                             randval_lds_window.get_window_origin(),
                             MakeRandValLdsShuffleTileDistribution<BlockGemm>());

        const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{});
        if(is_store_randval)
        {
            static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
                static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
                    int block_row_start =
                        (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
                    int block_col_start = (start_n0_idx / WG::kN) + i_n0;
                    uint2 rowcol        = make_uint2(block_row_start, block_col_start);

                    // generate random number
                    uint8_t random_uint8_t[16];
                    ph.get_random_16x8(random_uint8_t,
                                       reinterpret_cast<unsigned long long&>(rowcol));

                    constexpr auto randval_dist_generated_spans =
                        decltype(randval_dist_generated)::get_distributed_spans();
                    int i_random_idx = 0;
                    sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
                        sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
                            constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
                            randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
                        });
                    });

                    // save to LDS
                    store_tile(randval_lds_window, randval_dist_generated);
                    block_sync_lds();

                    // read from LDS to register
                    auto randval = load_tile(randval_lds_read_window);

                    // save to Global
                    const auto randval_store = cast_tile<RandValOutputDataType>(randval);
                    store_tile(randval_dram_window, randval_store);

                    move_tile_window(randval_dram_window, {0, kNPerStep});
                });
                move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
            });
            move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
        };

        static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
            static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
                int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
                int block_col_start = (start_n0_idx / WG::kN) + i_n0;
                uint2 rowcol        = make_uint2(block_row_start, block_col_start);

                // generate random number
                uint8_t random_uint8_t[16];
                ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));

                constexpr auto randval_dist_generated_spans =
                    decltype(randval_dist_generated)::get_distributed_spans();
                int i_random_idx = 0;
                sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
                        constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
                        randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
                    });
                });

                // save to LDS
                store_tile(randval_lds_window, randval_dist_generated);
                block_sync_lds();

                // read from LDS to register
                auto randval = load_tile(randval_lds_read_window);

                constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
                sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
                        constexpr auto p_idx0 = tile_distributed_index<i_m0>{};
                        constexpr auto p_idx1 =
                            tile_distributed_index<i_n0, idx1.impl_.at(1), idx1.impl_.at(2)>{};
                        constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
                        constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
                        p_compute(p_idx)     = randval[r_idx] <= p_undrop_in_uint8_t
                                                   ? p_compute[p_idx] * rp_undrop
                                                   : PComputeDataType(0);
                    });
                });
            });
        });
    }

    template <typename BlockGemm,
              typename RandValOutputDataType,
              typename PComputeWindow,
              typename RandValDramWindow>
    CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
                                 PComputeWindow& p_compute,
                                 RandValDramWindow& randval_dram_window) const
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        using WG                = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;

        constexpr index_t kMPerBlock = BlockGemmShape::kM;
        constexpr index_t kNPerBlock = BlockGemmShape::kN;
        constexpr index_t kMPerStep  = MWarp * WG::kM;
        constexpr index_t kNPerStep  = NWarp * WG::kN;

        // register distribute
        auto randval =
            make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
        static_assert(randval.kThreadElementSpaceSize == 16);

        const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{});
        static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
            static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
                int block_row_start = (start_m0_idx / WG::kM) + i_m0;
                int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
                uint2 rowcol        = make_uint2(block_row_start, block_col_start);

                // generate random number
                uint8_t random_uint8_t[16];
                ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));

                constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
                int i_random_idx             = 0;
                sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
                        constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
                        randval(r_idx)       = random_uint8_t[i_random_idx++];
                        constexpr auto p_idx0 =
                            tile_distributed_index<i_m0, idx0.impl_.at(1), idx0.impl_.at(2)>{};
                        constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
                        constexpr auto p_idx  = ck_tile::make_tuple(p_idx0, p_idx1);
                        p_compute(p_idx)      = randval[r_idx] <= p_undrop_in_uint8_t
                                                    ? p_compute[p_idx]
                                                    : -p_compute[p_idx];
                    });
                });

                // save to Global
                if(is_store_randval)
                {
                    const auto randval_store = cast_tile<RandValOutputDataType>(randval);
                    store_tile(randval_dram_window, randval_store);
                    move_tile_window(randval_dram_window, {kMPerStep, 0});
                }
            });
            if(is_store_randval)
            {
                move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
            }
        });
        if(is_store_randval)
        {
            move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
        }
    }

    ck_tile::philox ph;
    const float rp_undrop;
    const uint8_t p_undrop_in_uint8_t;
    const bool is_store_randval;
};

} // namespace ck_tile
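Note (reading aid, not part of the diff): per element, the first Run() overload above reduces to ordinary inverted dropout with the keep probability quantized into a uint8_t threshold (the second, forward overload instead flips the sign of dropped values and defers the rescale to the pipeline). A scalar sketch of the same rule:

#include <cmath>
#include <cstdint>
#include <limits>

// Keep an element iff its uint8 random draw is <= the quantized keep-probability;
// survivors are rescaled by 1/p_undrop so the expectation is unchanged.
inline float dropout_one_element(float p, uint8_t randval, float p_drop)
{
    const float   p_undrop  = 1.0f - p_drop;
    const uint8_t threshold =
        static_cast<uint8_t>(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
    const float rp_undrop = 1.0f / p_undrop;
    return randval <= threshold ? p * rp_undrop : 0.0f;
}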
include/ck_tile/ops/fmha/block/block_masking.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -141,6 +141,36 @@ struct GenericAttentionMask
         }
     }

+    // to get the loop length along Y axis, return index:[start, end), end-start=length
+    // use this if need loop over Y axis tile by tile (like q-seqlen loopover)
+    // TODO: y_end still could be negative, so end-start could be negative (need check)
+    template <index_t YTile, index_t XTile>
+    CK_TILE_HOST_DEVICE constexpr auto
+    GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
+    {
+        if constexpr(!IsMasking)
+        {
+            return ck_tile::make_tuple(0, y_total);
+        }
+        else
+        {
+            // get the tile start/end range assuming we loop over along Y tile by tile
+            index_t y_start = [&]() {
+                index_t tmp = max(-x + i_x + 1, 0);
+                return (tmp / YTile) * YTile; // round to tile aligned
+            }();
+
+            // TODO: end could be negative, we ignore clamp here, and let caller check
+            //       ... in which case end-start is negative
+            index_t y_end = [&]() {
+                index_t tmp = min(i_x + XTile - 1 + y, y_total);
+                return ((tmp + YTile - 1) / YTile) * YTile;
+            }();
+
+            return ck_tile::make_tuple(y_start, y_end);
+        }
+    }

     // per-pixel check if out-of-bound, if true, need mask a value (like -INF)
     CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
     {
...
@@ -160,14 +190,14 @@ struct GenericAttentionMask
         }
         else
         {
-            return i_x >= x_end;
+            return i_x >= x_end || i_y >= y_total;
         }
     }
 }

 // if current tile is at the edge, means need per-pixel mask check.
 // otherwise no need to check per-pixel
-// Attention! assume the index passed in this function is within range of GetTileRangeAlongX()
+// Attention! assume the index passed in this function is within range of GetTileRangeAlongX/Y()
 // can be used as a fast-path to decide if do per-pixel check or not
 template <index_t TileHeight, index_t TileWidth>
 CK_TILE_HOST_DEVICE constexpr auto
...
@@ -269,6 +299,53 @@ struct SimplifiedGenericAttentionMask
         }
     }

+    template <index_t TileHeight, index_t TileWidth>
+    CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y,
+                                                          number<TileHeight> height,
+                                                          number<TileWidth> width,
+                                                          index_t num_splits,
+                                                          index_t i_split) const
+    {
+        auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
+
+        const index_t x_per_split = ck_tile::max(1, x_total / num_splits);
+        const index_t split_start = x_per_split * i_split;
+        const index_t split_end =
+            (i_split == num_splits - 1 ? x_total : split_start + x_per_split);
+
+        return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
+                                   ck_tile::min(origin_end, split_end));
+    }
+
+    // to get the loop length along Y axis, return index:[start, end), end-start=length
+    // use this if need loop over Y axis tile by tile (like q-seqlen loopover)
+    // TODO: y_end still could be negative, so end-start could be negative (need check)
+    template <index_t YTile, index_t XTile>
+    CK_TILE_HOST_DEVICE constexpr auto
+    GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
+    {
+        if constexpr(!IsMasking)
+        {
+            return ck_tile::make_tuple(0, y_total);
+        }
+        else
+        {
+            // get the tile start/end range assuming we loop over along Y tile by tile
+            index_t y_start = [&]() {
+                index_t tmp = max(-x + i_x + 1, 0);
+                return (tmp / YTile) * YTile; // round to tile aligned
+            }();
+
+            // TODO: end could be negative, we ignore clamp here, and let caller check
+            //       ... in which case end-start is negative
+            index_t y_end = [&]() {
+                index_t tmp = min(i_x + XTile - 1 + y, y_total);
+                return ((tmp + YTile - 1) / YTile) * YTile;
+            }();
+
+            return ck_tile::make_tuple(y_start, y_end);
+        }
+    }

     // per-pixel check if out-of-bound, if true, need mask a value (like -INF)
     CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
     {
...
@@ -283,13 +360,13 @@ struct SimplifiedGenericAttentionMask
         index_t x_start = -y + i_y + 1;          // this could be negative, but it's fine
         index_t x_end   = min(i_y + x, x_total); // need min in case x is padded

-        return i_x < x_start || i_x >= x_end;
+        return i_x < x_start || i_x >= x_end || i_y >= y_total;
     }
 }

 // if current tile is at the edge, means need per-pixel mask check.
 // otherwise no need to check per-pixel
-// Attention! assume the index passed in this function is within range of GetTileRangeAlongX()
+// Attention! assume the index passed in this function is within range of GetTileRangeAlongX/Y()
 // can be used as a fast-path to decide if do per-pixel check or not
 template <index_t TileHeight, index_t TileWidth>
 CK_TILE_HOST_DEVICE constexpr auto
...
@@ -312,7 +389,7 @@ struct SimplifiedGenericAttentionMask
         // index_t x_end = min(i_y + x, x_total);
         bool top_right_edge   = i_x_end > min(i_y + x, x_total); // consider right pad
-        bool bottom_left_edge = i_y_end > (i_x + y);
+        bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad
         // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now
         return top_right_edge || bottom_left_edge;
...
@@ -361,6 +438,6 @@ make_generic_attention_mask_from_lr_window(index_t left_size,
 {
     auto r = make_generic_attention_mask_coordinates_from_lr_window(
         left_size, right_size, y_total, x_total, is_top_left);
-    return MaskType{r.at(ck_tile::number<0>{}), r.at(ck_tile::number<1>{}), y_total, x_total};
+    return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
 }
 } // namespace ck_tile
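Note (reading aid, not part of the diff): GetTileRangeAlongY only rounds the raw window bounds outward to YTile boundaries; a scalar sketch of the arithmetic, with (y, x) the mask's window coordinates as in the code above:

#include <algorithm>
#include <utility>

// Returns [y_start, y_end) for one x-position i_x; the range may come out empty
// (end <= start), which the caller is expected to check.
inline std::pair<int, int> tile_range_along_y(int i_x, int XTile, int YTile,
                                              int y, int x, int y_total)
{
    int y_start = std::max(-x + i_x + 1, 0);
    y_start     = (y_start / YTile) * YTile;             // round down to a tile boundary
    int y_end   = std::min(i_x + XTile - 1 + y, y_total);
    y_end       = ((y_end + YTile - 1) / YTile) * YTile; // round up to a tile boundary
    return {y_start, y_end};
}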
include/ck_tile/ops/fmha/block/block_position_encoding.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include <cmath>
#include <vector>

namespace ck_tile {

enum struct PositionEncodingEnum
{
    NO    = 0,
    ALIBI = 1,
};

/*
VERTICAL:
[0] 1 2 3 4 5
[0] 1 2 3 4 5
[0] 1 2 3 4 5
[0] 1 2 3 4 5

TOP_LEFT(but negative):
[0] 1 2 3 4 5
 1 [0] 1 2 3 4
 2  1 [0] 1 2 3
 3  2  1 [0] 1 2

FROM_BOTTOM_RIGHT(but negative):
 2  1 [0] 1 2 3
 3  2  1 [0] 1 2
 4  3  2  1 [0] 1
 5  4  3  2  1 [0]
*/
enum struct AlibiMode
{
    VERTICAL          = 0,
    FROM_TOP_LEFT     = 1, // keep sync with mask enum
    FROM_BOTTOM_RIGHT = 2,
};

template <typename DataType, bool RowMajor = true>
struct Alibi
{
    // RowMajor here means if pixels within the same thread are along the row, or col.
    // this may impact the performance of update(), while the results are the same.
    // e.g. fwd prefers RowMajor=true, bwd in some cases prefers RowMajor=false
    CK_TILE_HOST_DEVICE Alibi(DataType slope_,
                              index_t y_total_,
                              index_t x_total_,
                              AlibiMode mode_ = AlibiMode::VERTICAL)
    {
        slope         = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_;
        shift_left_up = [&]() {
            if(RowMajor)
            {
                return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
            }
            else
            {
                return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
            }
        }();
        shift_right_down = [&]() {
            if(RowMajor)
            {
                return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
            }
            else
            {
                return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
            }
        }();
        mode = mode_;
    }

    CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
    {
        if constexpr(RowMajor)
        {
            // at least 3 instructions per row
            index_t current_zero_point =
                mode == AlibiMode::VERTICAL ? shift_right_down : row_idx + shift_right_down;
            // for every thread, most of the pixels are along the row; the operation below
            // should be the main hot spot.
            auto position =
                type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
                                           bit_cast<uint32_t>(col_idx + shift_left_up),
                                           0));
            pixel += slope * position;
        }
        else
        {
            // at least 3 instructions per col
            index_t current_zero_point = mode == AlibiMode::VERTICAL
                                             ? row_idx + col_idx + shift_right_down
                                             : col_idx + shift_right_down;
            // for every thread, most of the pixels are along the col; the operation below
            // should be the main hot spot.
            auto position =
                type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
                                           bit_cast<uint32_t>(row_idx + shift_left_up),
                                           0));
            pixel += slope * position;
        }
    }

    DataType slope;           // float?
    index_t shift_left_up;    // always positive
    index_t shift_right_down; // always positive
    AlibiMode mode;
};

template <typename DataType>
struct EmptyPositionEncoding
{
    CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/)
    {
    }
};

// can convert from the FA style left/right to our generic coordinate
// if left_size < 0 && right_size = 0, it is normal causal mask
// local is left_size >= 0 or right_size >= 0
template <typename DataType, bool RowMajor = true>
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
                                                 index_t window_left_size,
                                                 index_t window_right_size,
                                                 index_t y_total,
                                                 index_t x_total,
                                                 GenericAttentionMaskEnum mask_enum)
{
    // assume mask_enum will never be NO_MASK, since if we do not have mask, it's
    // totally OK to use constexpr
    bool is_causal = window_left_size < 0 && window_right_size == 0;
    AlibiMode alibi_mode =
        is_causal ? AlibiMode::VERTICAL
                  : static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
    return Alibi<DataType, RowMajor>{slope, y_total, x_total, alibi_mode};
}

// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
// Do we need a device version?
template <typename DataType>
CK_TILE_HOST std::vector<DataType> get_alibi_slopes(ck_tile::index_t nheads)
{
    auto get_slopes_power_of_2 = [](ck_tile::index_t n) {
        float start = std::powf(
            static_cast<float>(2),
            -std::powf(static_cast<float>(2),
                       -static_cast<float>((integer_log2_floor(n) - 3))));
        std::vector<DataType> rtn;
        for(auto i = 0; i < n; i++)
        {
            rtn.push_back(static_cast<DataType>(start * std::powf(start, i)));
        }
        return rtn;
    };

    if(is_power_of_two_integer(nheads))
    {
        // power of 2 calculation
        return get_slopes_power_of_2(nheads);
    }
    else
    {
        ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads);
        auto v0 = get_slopes_power_of_2(closest_power_of_2);
        auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2);
        auto v1_sliced = [&](auto vec, ck_tile::index_t rem) {
            std::vector<DataType> sliced;
            for(ck_tile::index_t i = 0; i < static_cast<ck_tile::index_t>(vec.size()); i++)
            {
                if(i % 2 == 0)
                    sliced.push_back(vec[i]);
            }
            std::vector<DataType> sliced_2(sliced.begin(), sliced.begin() + rem);
            return sliced_2;
        }(v1, nheads - closest_power_of_2);
        v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end());
        return v0;
    }
}

} // namespace ck_tile
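Note (reading aid, not part of the diff): for a power-of-two head count $n$, get_alibi_slopes above reproduces the geometric sequence from the ALiBi reference implementation linked in the comment:

$$\text{start} = 2^{-2^{-(\log_2 n - 3)}} = 2^{-8/n}, \qquad \text{slope}_i = \text{start}\cdot\text{start}^{\,i} = 2^{-8(i+1)/n}, \quad i = 0,\dots,n-1,$$

e.g. $n = 8$ gives $1/2, 1/4, \dots, 1/256$. For non-power-of-two $n$, the code takes the slopes of the next power of two and splices in every other entry of the doubled sequence, as in the linked fairseq code.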
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <type_traits>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
// dV[seqlen_k, hdim_v] = P^T[seqlen_k, seqlen_q] @ dO^T[hdim_v, seqlen_q]
// dP[seqlen_q, seqlen_k] = dO[seqlen_q, hdim_v] @ V[seqlen_k, hdim_v]
// D[seqlen_q] = rowsum(dO[seqlen_q, hdim_v] * O[seqlen_q, hdim_v])
// dS''[seqlen_q, seqlen_k] = P[seqlen_q, seqlen_k] * (dP[seqlen_q, seqlen_k] - D[seqlen_q])
// dBias[seqlen_q, seqlen_k] = dS'[seqlen_q, seqlen_k] = dS''[seqlen_q, seqlen_k]
// dK[seqlen_k, hdim_q] = dS'^T[seqlen_k, seqlen_q] @ Q^T[hdim_q, seqlen_q] * Scale[1]
// dQ[seqlen_q, hdim_q] = dS'[seqlen_q, seqlen_k] @ K^T[hdim_q, seqlen_k] * Scale[1]
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
KGradEpiloguePipeline_
,
typename
VGradEpiloguePipeline_
>
struct
FmhaBwdDQDKDVKernel
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaPipeline
=
ck_tile
::
remove_cvref_t
<
FmhaPipeline_
>
;
using
KGradEpiloguePipeline
=
ck_tile
::
remove_cvref_t
<
KGradEpiloguePipeline_
>
;
using
VGradEpiloguePipeline
=
ck_tile
::
remove_cvref_t
<
VGradEpiloguePipeline_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaPipeline
::
kBlockSize
;
static
constexpr
ck_tile
::
index_t
kBlockPerCu
=
FmhaPipeline
::
kBlockPerCu
;
using
QDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
QDataType
>
;
using
KDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
KDataType
>
;
using
VDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
VDataType
>
;
using
BiasDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
BiasDataType
>
;
using
GemmDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
GemmDataType
>
;
using
LSEDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
LSEDataType
>
;
using
AccDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
AccDataType
>
;
using
DDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
DDataType
>
;
using
RandValOutputDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
RandValOutputDataType
>
;
using
OGradDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
OGradDataType
>
;
using
QGradDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
QGradDataType
>
;
using
KGradDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
KGradDataType
>
;
using
VGradDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
VGradDataType
>
;
using
BiasGradDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
BiasGradDataType
>
;
static
constexpr
bool
kIsGroupMode
=
FmhaPipeline
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
FmhaPipeline
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
FmhaPipeline
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
FmhaPipeline
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
FmhaPipeline
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
FmhaPipeline
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
FmhaPipeline
::
kHasBiasGrad
;
static
constexpr
bool
kHasDropout
=
FmhaPipeline
::
kHasDropout
;
using
FmhaMask
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
FmhaMask
>
;
static
constexpr
bool
kHasMask
=
FmhaMask
::
IsMasking
;
// clang-format off
template
<
typename
T
>
struct
t2s
;
template
<
>
struct
t2s
<
ck_tile
::
fp16_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
// clang-format on
CK_TILE_HOST
static
std
::
string
GetName
()
{
// sync with generate.py
// clang-format off
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
gbr
=
typename
bfs
::
Gemm0BlockWarps
;
using
gwt
=
typename
bfs
::
Gemm0WarpTile
;
#define _SS_ std::string
#define _TS_ std::to_string
auto
pn
=
[
&
]
()
{
std
::
string
n
;
if
(
kPadSeqLenQ
)
n
+=
"s"
;
if
(
kPadSeqLenK
)
n
+=
"sk"
;
if
(
kPadHeadDimQ
)
n
+=
"d"
;
if
(
kPadHeadDimV
)
n
+=
"dv"
;
return
n
.
empty
()
?
n
:
std
::
string
(
"p"
)
+
n
;
}();
return
_SS_
(
"fmha_bwd_d"
)
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
+
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"x"
+
_TS_
(
bfs
::
kVHeaddim
)
+
"_"
+
"r"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
)
+
_SS_
(
FmhaPipeline
::
name
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
kHasBiasGrad
?
"_dbias"
:
""
)
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
);
#undef _SS_
#undef _TS_
// clang-format on
}
template
<
ck_tile
::
index_t
I
>
// to avoid duplicated base class prblem, introduce an template
// arg
struct
FmhaBwdEmptyKargs
{
};
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
struct
FmhaBwdCommonKargs
{
const
void
*
q_ptr
;
const
void
*
k_ptr
;
const
void
*
v_ptr
;
const
void
*
lse_ptr
;
const
void
*
do_ptr
;
const
void
*
d_ptr
;
void
*
dq_ptr
;
void
*
dk_ptr
;
void
*
dv_ptr
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_k
;
ck_tile
::
index_t
hdim_q
;
ck_tile
::
index_t
hdim_v
;
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
// if this param is larger than 1, indicate MQA/GQA case
ck_tile
::
index_t
num_head_q
;
ck_tile
::
index_t
nhead_ratio_qk
;
float
raw_scale
;
#if CK_TILE_FMHA_FWD_FAST_EXP2
float
scale
;
#endif
ck_tile
::
index_t
stride_q
;
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_v
;
ck_tile
::
index_t
stride_do
;
ck_tile
::
index_t
stride_dk
;
ck_tile
::
index_t
stride_dv
;
ck_tile
::
index_t
nhead_stride_q
;
ck_tile
::
index_t
nhead_stride_k
;
ck_tile
::
index_t
nhead_stride_v
;
ck_tile
::
index_t
nhead_stride_do
;
ck_tile
::
index_t
nhead_stride_lsed
;
ck_tile
::
index_t
batch_stride_lsed
;
};
struct
FmhaBwdCommonBiasKargs
{
const
void
*
bias_ptr
=
nullptr
;
ck_tile
::
index_t
stride_bias
=
0
;
ck_tile
::
index_t
nhead_stride_bias
=
0
;
};
struct
FmhaBwdBatchModeBiasKargs
:
FmhaBwdCommonBiasKargs
{
ck_tile
::
index_t
batch_stride_bias
=
0
;
};
struct
FmhaBwdAlibiKargs
{
// alibi is batch*nhead*1, no matter in batch/group mode, they are the same
const
void
*
alibi_slope_ptr
;
ck_tile
::
index_t
alibi_slope_stride
;
// stride in batch, or 0 for all batch share same slope
};
struct
FmhaBwdCommonBiasGradKargs
{
void
*
dbias_ptr
=
nullptr
;
ck_tile
::
index_t
stride_dbias
=
0
;
ck_tile
::
index_t
nhead_stride_dbias
=
0
;
};
struct
FmhaBwdBatchModeBiasGradKargs
:
FmhaBwdCommonBiasGradKargs
{
ck_tile
::
index_t
batch_stride_dbias
=
0
;
};
struct
FmhaBwdMaskKargs
{
ck_tile
::
index_t
window_size_left
,
window_size_right
;
ck_tile
::
GenericAttentionMaskEnum
mask_type
;
};
struct
FmhaBwdCommonDropoutKargs
{
void
init_dropout
(
const
float
p_drop
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
,
const
float
raw_scale
)
{
float
p_undrop
=
1.0
-
p_drop
;
p_undrop_in_uint8_t
=
uint8_t
(
std
::
floor
(
p_undrop
*
std
::
numeric_limits
<
uint8_t
>::
max
()));
rp_undrop
=
1.0
/
p_undrop
;
scale_rp_undrop
=
rp_undrop
*
raw_scale
;
drop_seed
=
std
::
get
<
0
>
(
drop_seed_offset
);
drop_offset
=
std
::
get
<
1
>
(
drop_seed_offset
);
}
float
rp_undrop
=
1
;
float
scale_rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
bool
is_store_randval
=
false
;
uint64_t
drop_seed
=
1
;
uint64_t
drop_offset
=
0
;
void
*
rand_val_ptr
=
nullptr
;
ck_tile
::
index_t
stride_randval
=
0
;
ck_tile
::
index_t
nhead_stride_randval
=
0
;
};
struct
FmhaBwdBatchModeDropoutKargs
:
FmhaBwdCommonDropoutKargs
{
ck_tile
::
index_t
batch_stride_randval
=
0
;
};
struct
FmhaBwdBatchModeKargs
:
FmhaBwdCommonKargs
,
std
::
conditional_t
<
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
,
FmhaBwdBatchModeBiasKargs
,
std
::
conditional_t
<
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
,
FmhaBwdAlibiKargs
,
FmhaBwdEmptyKargs
<
0
>>>
,
std
::
conditional_t
<
kHasBiasGrad
,
FmhaBwdBatchModeBiasGradKargs
,
FmhaBwdEmptyKargs
<
1
>>
,
std
::
conditional_t
<
kHasMask
,
FmhaBwdMaskKargs
,
FmhaBwdEmptyKargs
<
2
>>
,
std
::
conditional_t
<
kHasDropout
,
FmhaBwdBatchModeDropoutKargs
,
FmhaBwdEmptyKargs
<
3
>>
{
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_do
;
ck_tile
::
index_t
batch_stride_dk
;
ck_tile
::
index_t
batch_stride_dv
;
};
struct
FmhaBwdGroupModeKargs
:
FmhaBwdCommonKargs
,
std
::
conditional_t
<
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
,
FmhaBwdCommonBiasKargs
,
std
::
conditional_t
<
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
,
FmhaBwdAlibiKargs
,
FmhaBwdEmptyKargs
<
0
>>>
,
std
::
conditional_t
<
kHasBiasGrad
,
FmhaBwdCommonBiasGradKargs
,
FmhaBwdEmptyKargs
<
1
>>
,
std
::
conditional_t
<
kHasMask
,
FmhaBwdMaskKargs
,
FmhaBwdEmptyKargs
<
2
>>
,
std
::
conditional_t
<
kHasDropout
,
FmhaBwdCommonDropoutKargs
,
FmhaBwdEmptyKargs
<
3
>>
{
const
int32_t
*
seqstart_q_ptr
;
const
int32_t
*
seqstart_k_ptr
;
const
int32_t
*
seqlen_k_ptr
;
};
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
FmhaBwdGroupModeKargs
,
FmhaBwdBatchModeKargs
>
;
template
<
bool
Cond
=
!
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
const
void
*
lse_ptr
,
const
void
*
do_ptr
,
const
void
*
d_ptr
,
void
*
rand_val_ptr
,
void
*
dq_ptr
,
void
*
dk_ptr
,
void
*
dv_ptr
,
void
*
dbias_ptr
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_k
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_head_q
,
ck_tile
::
index_t
nhead_ratio_qk
,
float
scale
,
ck_tile
::
index_t
stride_q
,
ck_tile
::
index_t
stride_k
,
ck_tile
::
index_t
stride_v
,
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_do
,
ck_tile
::
index_t
stride_dk
,
ck_tile
::
index_t
stride_dv
,
ck_tile
::
index_t
stride_dbias
,
ck_tile
::
index_t
nhead_stride_q
,
ck_tile
::
index_t
nhead_stride_k
,
ck_tile
::
index_t
nhead_stride_v
,
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
batch_stride_q
,
ck_tile
::
index_t
batch_stride_k
,
ck_tile
::
index_t
batch_stride_v
,
ck_tile
::
index_t
batch_stride_bias
,
ck_tile
::
index_t
batch_stride_randval
,
ck_tile
::
index_t
batch_stride_do
,
ck_tile
::
index_t
batch_stride_lsed
,
ck_tile
::
index_t
batch_stride_dk
,
ck_tile
::
index_t
batch_stride_dv
,
ck_tile
::
index_t
batch_stride_dbias
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
Kargs
kargs
{{
q_ptr
,
k_ptr
,
v_ptr
,
lse_ptr
,
do_ptr
,
d_ptr
,
dq_ptr
,
dk_ptr
,
dv_ptr
,
seqlen_q
,
seqlen_k
,
hdim_q
,
hdim_v
,
num_head_q
,
nhead_ratio_qk
,
scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast
<
float
>
(
scale
*
ck_tile
::
log2e_v
<>
),
#endif
stride_q
,
stride_k
,
stride_v
,
stride_do
,
stride_dk
,
stride_dv
,
nhead_stride_q
,
nhead_stride_k
,
nhead_stride_v
,
nhead_stride_do
,
nhead_stride_lsed
,
batch_stride_lsed
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for dbias
{},
// placeholder for mask
{},
// placeholder for dropout
batch_stride_q
,
batch_stride_k
,
batch_stride_v
,
batch_stride_do
,
batch_stride_dk
,
batch_stride_dv
};
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
kargs
.
bias_ptr
=
bias_ptr
;
kargs
.
stride_bias
=
stride_bias
;
kargs
.
nhead_stride_bias
=
nhead_stride_bias
;
kargs
.
batch_stride_bias
=
batch_stride_bias
;
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
kargs
.
alibi_slope_ptr
=
bias_ptr
;
kargs
.
alibi_slope_stride
=
stride_bias
;
}
if
constexpr
(
kHasBiasGrad
)
{
kargs
.
dbias_ptr
=
dbias_ptr
;
kargs
.
stride_dbias
=
stride_dbias
;
kargs
.
nhead_stride_dbias
=
nhead_stride_dbias
;
kargs
.
batch_stride_dbias
=
batch_stride_dbias
;
}
if
constexpr
(
kHasMask
)
{
kargs
.
window_size_left
=
window_size_left
;
kargs
.
window_size_right
=
window_size_right
;
kargs
.
mask_type
=
static_cast
<
ck_tile
::
GenericAttentionMaskEnum
>
(
mask_type
);
}
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
,
scale
);
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
batch_stride_randval
=
batch_stride_randval
;
kargs
.
is_store_randval
=
s_randval
;
}
return
kargs
;
}
template
<
bool
Cond
=
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
const
void
*
lse_ptr
,
const
void
*
do_ptr
,
const
void
*
d_ptr
,
void
*
rand_val_ptr
,
void
*
dq_ptr
,
void
*
dk_ptr
,
void
*
dv_ptr
,
void
*
dbias_ptr
,
const
void
*
seqstart_q_ptr
,
const
void
*
seqstart_k_ptr
,
const
void
*
seqlen_k_ptr
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_head_q
,
ck_tile
::
index_t
nhead_ratio_qk
,
float
scale
,
ck_tile
::
index_t
stride_q
,
ck_tile
::
index_t
stride_k
,
ck_tile
::
index_t
stride_v
,
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_do
,
ck_tile
::
index_t
stride_dk
,
ck_tile
::
index_t
stride_dv
,
ck_tile
::
index_t
stride_dbias
,
ck_tile
::
index_t
nhead_stride_q
,
ck_tile
::
index_t
nhead_stride_k
,
ck_tile
::
index_t
nhead_stride_v
,
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
batch_stride_lsed
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
Kargs
kargs
{{
q_ptr
,
k_ptr
,
v_ptr
,
lse_ptr
,
do_ptr
,
d_ptr
,
dq_ptr
,
dk_ptr
,
dv_ptr
,
-
1
,
// seqlen will be updated by another pointer
-
1
,
//
hdim_q
,
hdim_v
,
num_head_q
,
nhead_ratio_qk
,
scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast
<
float
>
(
scale
*
ck_tile
::
log2e_v
<>
),
#endif
stride_q
,
stride_k
,
stride_v
,
stride_do
,
stride_dk
,
stride_dv
,
nhead_stride_q
,
nhead_stride_k
,
nhead_stride_v
,
nhead_stride_do
,
nhead_stride_lsed
,
batch_stride_lsed
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for dbias
{},
// placeholder for mask
{},
// placeholder for dropout
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqstart_k_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
)};
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
kargs
.
bias_ptr
=
bias_ptr
;
kargs
.
stride_bias
=
stride_bias
;
kargs
.
nhead_stride_bias
=
nhead_stride_bias
;
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
kargs
.
alibi_slope_ptr
=
bias_ptr
;
kargs
.
alibi_slope_stride
=
stride_bias
;
}
if
constexpr
(
kHasBiasGrad
)
{
kargs
.
dbias_ptr
=
dbias_ptr
;
kargs
.
stride_dbias
=
stride_dbias
;
kargs
.
nhead_stride_dbias
=
nhead_stride_dbias
;
}
if
constexpr
(
kHasMask
)
{
kargs
.
window_size_left
=
window_size_left
;
kargs
.
window_size_right
=
window_size_right
;
kargs
.
mask_type
=
static_cast
<
ck_tile
::
GenericAttentionMaskEnum
>
(
mask_type
);
}
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
,
scale
);
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
is_store_randval
=
s_randval
;
}
return
kargs
;
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_k_
)
{
return
TilePartitioner
::
GridSize
(
batch_size_
,
nhead_
,
seqlen_k_
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
ck_tile
::
max
(
FmhaPipeline
::
GetSmemSize
(),
KGradEpiloguePipeline
::
GetSmemSize
(),
VGradEpiloguePipeline
::
GetSmemSize
());
}
    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        // divide problem
        const auto [i_tile_n, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_k);

        const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0);

        long_index_t batch_offset_q       = 0;
        long_index_t batch_offset_k       = 0;
        long_index_t batch_offset_v       = 0;
        long_index_t batch_offset_bias    = 0;
        long_index_t batch_offset_randval = 0;
        long_index_t batch_offset_do      = 0;
        long_index_t batch_offset_lsed    = 0;
        long_index_t batch_offset_dk      = 0;
        long_index_t batch_offset_dv      = 0;
        long_index_t batch_offset_dbias   = 0;

        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
            const long_index_t key_start   = kargs.seqstart_k_ptr[i_batch];

            batch_offset_q    = query_start * kargs.stride_q;
            batch_offset_k    = key_start * kargs.stride_k;
            batch_offset_v    = key_start * kargs.stride_v;
            batch_offset_do   = query_start * kargs.stride_do;
            batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
            batch_offset_dk   = key_start * kargs.stride_dk;
            batch_offset_dv   = key_start * kargs.stride_dv;
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                batch_offset_bias = query_start * kargs.stride_bias;
            }
            if constexpr(kHasBiasGrad)
            {
                batch_offset_dbias = query_start * kargs.stride_dbias;
            }
            else
            {
                batch_offset_dbias = key_start;
            }
            if constexpr(kHasDropout)
            {
                batch_offset_randval = query_start * kargs.stride_randval;
            }

            // get real # queries & # keys under group mode
            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
            if(kargs.seqlen_k_ptr != nullptr)
            {
                kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
            }
            else
            {
                const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
                kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
            }

            // # of required blocks is different in each group, terminate unnecessary blocks
            // earlier
            if(kargs.seqlen_k <= i_n0)
            {
                return;
            }
        }
        else
        {
            batch_offset_q    = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
            batch_offset_k    = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
            batch_offset_v    = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
            batch_offset_do   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
            batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
            batch_offset_dk   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
            batch_offset_dv   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                batch_offset_bias =
                    static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
            }
            if constexpr(kHasBiasGrad)
            {
                batch_offset_dbias =
                    static_cast<long_index_t>(i_batch) * kargs.batch_stride_dbias;
            }
            if constexpr(kHasDropout)
            {
                batch_offset_randval =
                    static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
            }
        }

        // for simplicity, batch strides are handled by simply offsetting the pointers
        const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
                                 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
                                 batch_offset_q;
        const KDataType* k_ptr =
            reinterpret_cast<const KDataType*>(kargs.k_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
            batch_offset_k;
        const VDataType* v_ptr =
            reinterpret_cast<const VDataType*>(kargs.v_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
            batch_offset_v;
        const LSEDataType* lse_ptr =
            reinterpret_cast<const LSEDataType*>(kargs.lse_ptr) +
            static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed + batch_offset_lsed;
        const DDataType* d_ptr =
            reinterpret_cast<const DDataType*>(kargs.d_ptr) +
            static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed + batch_offset_lsed;
        const OGradDataType* do_ptr =
            reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
            static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do + batch_offset_do;
        QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
                                batch_offset_q;
        KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_k +
                                batch_offset_dk;
        VGradDataType* dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_v +
                                batch_offset_dv;

        // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
        const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            q_ptr,
            make_tuple(kargs.seqlen_q, kargs.hdim_q),
            make_tuple(kargs.stride_q, 1),
            number<FmhaPipeline::kAlignmentQ>{},
            number<1>{});
        const auto q_dram = [&]() {
            if constexpr(FmhaPipeline::kQLoadOnce)
            {
                return pad_tensor_view(
                    q_dram_naive,
                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
            }
            else
            {
                return pad_tensor_view(
                    q_dram_naive,
                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
            }
        }();
        const auto qt_dram_naive =
            transform_tensor_view(q_dram_naive,
                                  make_tuple(make_pass_through_transform(kargs.hdim_q),
                                             make_pass_through_transform(kargs.seqlen_q)),
                                  make_tuple(sequence<1>{}, sequence<0>{}),
                                  make_tuple(sequence<0>{}, sequence<1>{}));
        const auto qt_dram = [&]() {
            if constexpr(FmhaPipeline::kQTLoadOnce)
            {
                return pad_tensor_view(
                    qt_dram_naive,
                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kM0>{}),
                    sequence<kPadHeadDimQ, kPadSeqLenQ>{});
            }
            else
            {
                return pad_tensor_view(
                    qt_dram_naive,
                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK3>{}),
                    sequence<kPadHeadDimQ, kPadSeqLenQ>{});
            }
        }();
        const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            k_ptr,
            make_tuple(kargs.seqlen_k, kargs.hdim_q),
            make_tuple(kargs.stride_k, 1),
            number<FmhaPipeline::kAlignmentK>{},
            number<1>{});
        const auto k_dram = [&]() {
            if constexpr(FmhaPipeline::kKLoadOnce)
            {
                return pad_tensor_view(
                    k_dram_naive,
                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
                    sequence<kPadSeqLenK, kPadHeadDimQ>{});
            }
            else
            {
                return pad_tensor_view(
                    k_dram_naive,
                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
                    sequence<kPadSeqLenK, kPadHeadDimQ>{});
            }
        }();
        const auto kt_dram_naive =
            transform_tensor_view(k_dram_naive,
                                  make_tuple(make_pass_through_transform(kargs.hdim_q),
                                             make_pass_through_transform(kargs.seqlen_k)),
                                  make_tuple(sequence<1>{}, sequence<0>{}),
                                  make_tuple(sequence<0>{}, sequence<1>{}));
        const auto kt_dram = [&]() {
            if constexpr(FmhaPipeline::kKTLoadOnce)
            {
                return pad_tensor_view(
                    kt_dram_naive,
                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kN0>{}),
                    sequence<kPadHeadDimQ, kPadSeqLenK>{});
            }
            else
            {
                return pad_tensor_view(
                    kt_dram_naive,
                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK4>{}),
                    sequence<kPadHeadDimQ, kPadSeqLenK>{});
            }
        }();
        const auto v_dram = [&]() {
            const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                v_ptr,
                make_tuple(kargs.seqlen_k, kargs.hdim_v),
                make_tuple(kargs.stride_v, 1),
                number<FmhaPipeline::kAlignmentV>{},
                number<1>{});
            if constexpr(FmhaPipeline::kVLoadOnce)
            {
                return pad_tensor_view(
                    v_dram_naive,
                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
                    sequence<kPadSeqLenK, kPadHeadDimV>{});
            }
            else
            {
                return pad_tensor_view(
                    v_dram_naive,
                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK2>{}),
                    sequence<kPadSeqLenK, kPadHeadDimV>{});
            }
        }();
        const auto lse_dram = [&]() {
            const auto lse_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
                lse_ptr, make_tuple(kargs.seqlen_q), number<1>{});
            return pad_tensor_view(
                lse_dram_naive, make_tuple(number<FmhaPipeline::kM0>{}), sequence<kPadSeqLenQ>{});
        }();
        const auto d_dram = [&]() {
            const auto d_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
                d_ptr, make_tuple(kargs.seqlen_q), number<1>{});
            return pad_tensor_view(
                d_dram_naive, make_tuple(number<FmhaPipeline::kM0>{}), sequence<kPadSeqLenQ>{});
        }();
        const auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            do_ptr,
            make_tuple(kargs.seqlen_q, kargs.hdim_v),
            make_tuple(kargs.stride_do, 1),
            number<FmhaPipeline::kAlignmentOGrad>{},
            number<1>{});
        const auto do_dram = [&]() {
            if constexpr(FmhaPipeline::kOGradLoadOnce)
            {
                return pad_tensor_view(
                    do_dram_naive,
                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
                    sequence<kPadSeqLenQ, kPadHeadDimV>{});
            }
            else
            {
                return pad_tensor_view(
                    do_dram_naive,
                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK2>{}),
                    sequence<kPadSeqLenQ, kPadHeadDimV>{});
            }
        }();
        const auto dot_dram_naive =
            transform_tensor_view(do_dram_naive,
                                  make_tuple(make_pass_through_transform(kargs.hdim_v),
                                             make_pass_through_transform(kargs.seqlen_q)),
                                  make_tuple(sequence<1>{}, sequence<0>{}),
                                  make_tuple(sequence<0>{}, sequence<1>{}));
        const auto dot_dram = [&]() {
            if constexpr(FmhaPipeline::kOGradTLoadOnce)
            {
                return pad_tensor_view(
                    dot_dram_naive,
                    make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kM0>{}),
                    sequence<kPadHeadDimV, kPadSeqLenQ>{});
            }
            else
            {
                return pad_tensor_view(
                    dot_dram_naive,
                    make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kK1>{}),
                    sequence<kPadHeadDimV, kPadSeqLenQ>{});
            }
        }();
        auto dq_dram = [&]() {
            const auto dq_dram_naive = make_naive_tensor_view<address_space_enum::global,
                                                              memory_operation_enum::atomic_add>(
                dq_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_q),
                make_tuple(kargs.stride_q, 1),
                number<FmhaPipeline::kAlignmentQGrad>{},
                number<1>{});
            return pad_tensor_view(
                dq_dram_naive,
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
                sequence<kPadSeqLenQ, kPadHeadDimQ>{});
        }();

        auto q_dram_window = make_tile_window(
            q_dram,
            [&]() {
                if constexpr(FmhaPipeline::kQLoadOnce)
                    return make_tuple(number<FmhaPipeline::kM0>{},
                                      number<FmhaPipeline::kQKHeaddim>{});
                else
                    return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
            }(),
            {0, 0});

        auto qt_dram_window = make_tile_window(
            qt_dram,
            [&]() {
                if constexpr(FmhaPipeline::kQTLoadOnce)
                    return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
                                      number<FmhaPipeline::kM0>{});
                else
                    return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
                                      number<FmhaPipeline::kK3>{});
            }(),
            {0, 0});

        auto k_dram_window = make_tile_window(
            k_dram,
            [&]() {
                if constexpr(FmhaPipeline::kKLoadOnce)
                    return make_tuple(number<FmhaPipeline::kN0>{},
                                      number<FmhaPipeline::kQKHeaddim>{});
                else
                    return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
            }(),
            {i_n0, 0});

        auto kt_dram_window = make_tile_window(
            kt_dram,
            [&]() {
                if constexpr(FmhaPipeline::kKTLoadOnce)
                    return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
                                      number<FmhaPipeline::kN0>{});
                else
                    return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
                                      number<FmhaPipeline::kK4>{});
            }(),
            {0, i_n0});

        auto v_dram_window = make_tile_window(
            v_dram,
            [&]() {
                if constexpr(FmhaPipeline::kVLoadOnce)
                    return make_tuple(number<FmhaPipeline::kN0>{},
                                      number<FmhaPipeline::kVHeaddim>{});
                else
                    return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK2>{});
            }(),
            {i_n0, 0});

        auto do_dram_window = make_tile_window(
            do_dram,
            [&]() {
                if constexpr(FmhaPipeline::kOGradLoadOnce)
                    return make_tuple(number<FmhaPipeline::kM0>{},
                                      number<FmhaPipeline::kVHeaddim>{});
                else
                    return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK2>{});
            }(),
            {0, 0});

        auto dot_dram_window = make_tile_window(
            dot_dram,
            [&]() {
                if constexpr(FmhaPipeline::kOGradTLoadOnce)
                    return make_tuple(number<FmhaPipeline::kVHeaddim>{},
                                      number<FmhaPipeline::kM0>{});
                else
                    return make_tuple(number<FmhaPipeline::kVHeaddim>{},
                                      number<FmhaPipeline::kK1>{});
            }(),
            {0, 0});

        auto dq_dram_window = make_tile_window(
            dq_dram,
            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
            {0, 0});

        auto lse_dram_window =
            make_tile_window(lse_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});

        auto d_dram_window =
            make_tile_window(d_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});

        /// FIXME: Before C++20, capturing structured binding variables is not supported. Remove
        /// the following copy capture of 'i_nhead' once C++20 is available.
        constexpr auto bias_dram_window_lengths =
            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
        const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                const BiasDataType* bias_ptr =
                    reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
                    batch_offset_bias;
                const auto bias_dram = [&]() {
                    const auto bias_dram_naive =
                        make_naive_tensor_view<address_space_enum::global>(
                            bias_ptr,
                            make_tuple(kargs.seqlen_q, kargs.seqlen_k),
                            make_tuple(kargs.stride_bias, 1),
                            number<FmhaPipeline::kAlignmentBias>{},
                            number<1>{});
                    return pad_tensor_view(bias_dram_naive,
                                           bias_dram_window_lengths,
                                           sequence<kPadSeqLenQ, kPadSeqLenK>{});
                }();
                return make_tile_window(bias_dram, bias_dram_window_lengths, {0, i_n0});
            }
            else
            {
                return make_null_tile_window(bias_dram_window_lengths);
            }
        }();

        auto dbias_dram_window = [&, i_nhead_ = i_nhead]() {
            if constexpr(kHasBiasGrad)
            {
                BiasGradDataType* dbias_ptr =
                    reinterpret_cast<BiasGradDataType*>(kargs.dbias_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dbias +
                    batch_offset_dbias;
                auto dbias_dram = [&]() {
                    const auto dbias_dram_naive =
                        make_naive_tensor_view<address_space_enum::global>(
                            dbias_ptr,
                            make_tuple(kargs.seqlen_q, kargs.seqlen_k),
                            make_tuple(kargs.stride_dbias, 1),
                            number<FmhaPipeline::kAlignmentBias>{},
                            number<1>{});
                    return pad_tensor_view(dbias_dram_naive,
                                           bias_dram_window_lengths,
                                           sequence<kPadSeqLenQ, kPadSeqLenK>{});
                }();
                return make_tile_window(dbias_dram, bias_dram_window_lengths, {0, i_n0});
            }
            else
            {
                return make_null_tile_window(bias_dram_window_lengths);
            }
        }();

        // WA: capturing the structured bindings i_batch/i_nhead is not allowed before C++20
        auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                // data loading, shared by entire wg
                // TODO: how to use s_read?
                AccDataType slope =
                    *(reinterpret_cast<const AccDataType*>(kargs.alibi_slope_ptr) +
                      i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                slope *= ck_tile::log2e_v<>;
#endif
                if constexpr(kHasMask)
                {
                    return make_alibi_from_lr_mask<AccDataType, false>(slope,
                                                                       kargs.window_size_left,
                                                                       kargs.window_size_right,
                                                                       kargs.seqlen_q,
                                                                       kargs.seqlen_k,
                                                                       kargs.mask_type);
                }
                else
                {
                    return Alibi<AccDataType, false>{
                        slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
                }
            }
            else
            {
                return EmptyPositionEncoding<AccDataType>{};
            }
        }();

        // dropout
        float rp_undrop             = 1;
        float scale_rp_undrop       = 1;
        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
        uint64_t drop_seed          = 0;
        uint64_t drop_offset        = 0;
        bool is_store_randval       = false;

        if constexpr(kHasDropout)
        {
            rp_undrop           = kargs.rp_undrop;
            scale_rp_undrop     = kargs.scale_rp_undrop;
            p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
            drop_seed           = kargs.drop_seed;
            drop_offset         = kargs.drop_offset;
            is_store_randval    = kargs.is_store_randval;
        }
        BlockDropout dropout(i_batch,
                             i_nhead,
                             kargs.num_head_q,
                             drop_seed,
                             drop_offset,
                             rp_undrop,
                             p_undrop_in_uint8_t,
                             is_store_randval);

        auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto randval_dram_window_lengths =
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
            if constexpr(kHasDropout)
            {
                RandValOutputDataType* rand_val_ptr =
                    reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
                    batch_offset_randval;
                const auto randval_dram = [&]() {
                    const auto randval_dram_naive =
                        make_naive_tensor_view<address_space_enum::global>(
                            rand_val_ptr,
                            make_tuple(kargs.seqlen_q, kargs.seqlen_k),
                            make_tuple(kargs.stride_randval, 1),
                            number<1>{},
                            number<1>{});
                    return pad_tensor_view(randval_dram_naive,
                                           randval_dram_window_lengths,
                                           sequence<kPadSeqLenQ, kPadSeqLenK>{});
                }();
                return make_tile_window(randval_dram, randval_dram_window_lengths, {0, i_n0});
            }
            else
            {
                return make_null_tile_window(randval_dram_window_lengths);
            }
        }();

        FmhaMask mask = [&]() {
            if constexpr(kHasMask)
                return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
                    kargs.window_size_left,
                    kargs.window_size_right,
                    kargs.seqlen_q,
                    kargs.seqlen_k,
                    kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
            else
                return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
        }();

        auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
                                                         qt_dram_window,
                                                         k_dram_window,
                                                         kt_dram_window,
                                                         v_dram_window,
                                                         bias_dram_window,
                                                         randval_dram_window,
                                                         do_dram_window,
                                                         dot_dram_window,
                                                         lse_dram_window,
                                                         d_dram_window,
                                                         dq_dram_window,
                                                         dbias_dram_window,
                                                         mask,
                                                         position_encoding,
                                                         kargs.raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
                                                         kargs.scale,
#endif
                                                         rp_undrop,
                                                         scale_rp_undrop,
                                                         smem_ptr,
                                                         dropout);

        auto dk_dram = [&]() {
            const auto dk_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                dk_ptr,
                make_tuple(kargs.seqlen_k, kargs.hdim_q),
                make_tuple(kargs.stride_dk, 1),
                number<FmhaPipeline::kAlignmentKGrad>{},
                number<1>{});
            return pad_tensor_view(
                dk_dram_naive,
                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
                sequence<kPadSeqLenK, kPadHeadDimQ>{});
        }();
        auto dv_dram = [&]() {
            const auto dv_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                dv_ptr,
                make_tuple(kargs.seqlen_k, kargs.hdim_v),
                make_tuple(kargs.stride_dv, 1),
                number<FmhaPipeline::kAlignmentVGrad>{},
                number<1>{});
            return pad_tensor_view(
                dv_dram_naive,
                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
                sequence<kPadSeqLenK, kPadHeadDimV>{});
        }();

        auto dk_dram_window = make_tile_window(
            dk_dram,
            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
            {i_n0, 0});
        auto dv_dram_window = make_tile_window(
            dv_dram,
            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
            {i_n0, 0});

        KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
        VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile);
    }
};
template <typename TilePartitioner_, typename FmhaBwdOGradDotO_>
struct FmhaBwdOGradDotOKernel
{
    using TilePartitioner  = ck_tile::remove_cvref_t<TilePartitioner_>;
    using FmhaBwdOGradDotO = ck_tile::remove_cvref_t<FmhaBwdOGradDotO_>;
    static constexpr ck_tile::index_t kBlockSize  = FmhaBwdOGradDotO::kBlockSize;
    static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu;
    static constexpr ck_tile::index_t kM0         = kBlockSize;
    static constexpr ck_tile::index_t kVHeaddim   = FmhaBwdOGradDotO::kVHeaddim;

    using DDataType     = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::DDataType>;
    using ODataType     = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::ODataType>;
    using OGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::OGradDataType>;

    static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode;
    static constexpr bool kPadSeqLenQ  = FmhaBwdOGradDotO::kPadSeqLenQ;
    static constexpr bool kPadHeadDimV = FmhaBwdOGradDotO::kPadHeadDimV;

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    // clang-format on
    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        #define _SS_  std::string
        #define _TS_  std::to_string
        auto pn = [&] () {
            std::string n;
            if (kPadSeqLenQ) n += "s";
            if (kPadHeadDimV) n += "dv";
            return n.empty() ? n : std::string("p") + n; }();
        return
            _SS_("fmha_bwd_dot_do_o_d") + _TS_(kVHeaddim) + "_" + _SS_(t2s<ODataType>::name) +
            "_" + (kIsGroupMode ? "group" : "batch") + "_" + ("o" + _TS_(kBlockPerCu)) +
            (pn.empty() ? "" : "_" + pn);
        #undef _SS_
        #undef _TS_
        // clang-format on
    }
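    // Example (illustrative): for kVHeaddim = 128, fp16 I/O, batch mode, kBlockPerCu = 2
    // and no padding flags set, GetName() evaluates to
    // "fmha_bwd_dot_do_o_d128_fp16_batch_o2".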
    // kargs use aggregate initializer, so no constructor will be provided
    // use inheritance to minimize karg size
    // users need to use the MakeKargs() function to create kargs.
    struct FmhaBwdOGradDotOCommonKargs
    {
        const void* o_ptr;
        const void* do_ptr;
        void* d_ptr;

        float p_undrop;

        ck_tile::index_t seqlen_q;
        ck_tile::index_t hdim_v;

        ck_tile::index_t stride_do;
        ck_tile::index_t stride_o;

        ck_tile::index_t nhead_stride_do;
        ck_tile::index_t nhead_stride_o;
        ck_tile::index_t nhead_stride_d;
        ck_tile::index_t batch_stride_d;
    };
    struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs
    {
        ck_tile::index_t batch_stride_do;
        ck_tile::index_t batch_stride_o;
    };
    struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
    {
        const int32_t* seqstart_q_ptr;
    };
    using Kargs = std::conditional_t<kIsGroupMode,
                                     FmhaBwdOGradDotOGroupModeKargs,
                                     FmhaBwdOGradDotOBatchModeKargs>;

    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* o_ptr,
              const void* do_ptr,
              void* d_ptr,
              float p_undrop,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t stride_do,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_do,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t nhead_stride_d,
              ck_tile::index_t batch_stride_do,
              ck_tile::index_t batch_stride_o,
              ck_tile::index_t batch_stride_d)
    {
        Kargs kargs{{o_ptr,
                     do_ptr,
                     d_ptr,
                     p_undrop,
                     seqlen_q,
                     hdim_v,
                     stride_do,
                     stride_o,
                     nhead_stride_do,
                     nhead_stride_o,
                     nhead_stride_d,
                     batch_stride_d},
                    batch_stride_do,
                    batch_stride_o};

        return kargs;
    }

    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* o_ptr,
              const void* do_ptr,
              void* d_ptr,
              float p_undrop,
              const void* seqstart_q_ptr,
              ck_tile::index_t hdim_v,
              ck_tile::index_t stride_do,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_do,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t nhead_stride_d,
              ck_tile::index_t batch_stride_d)
    {
        Kargs kargs{{o_ptr,
                     do_ptr,
                     d_ptr,
                     p_undrop,
                     -1, // seqlen will be updated by another pointer
                     hdim_v,
                     stride_do,
                     stride_o,
                     nhead_stride_do,
                     nhead_stride_o,
                     nhead_stride_d,
                     batch_stride_d},
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr)};

        return kargs;
    }
    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
    {
        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // divide problem
        const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q);

        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);

        long_index_t batch_offset_o  = 0;
        long_index_t batch_offset_do = 0;
        long_index_t batch_offset_d  = 0;
        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];

            batch_offset_o  = query_start * kargs.stride_o;
            batch_offset_do = query_start * kargs.stride_do;
            batch_offset_d  = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;

            // get real # queries & # keys under group mode
            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];

            // # of required blocks is different in each group, terminate unnecessary blocks
            // earlier
            if(kargs.seqlen_q <= i_m0)
            {
                return;
            }
        }
        else
        {
            batch_offset_o  = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
            batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
            batch_offset_d  = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;
        }

        // for simplicity, batch strides are handled by simply offsetting the pointers
        const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr) +
                                 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
                                 batch_offset_o;
        const OGradDataType* do_ptr =
            reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
            static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do + batch_offset_do;
        DDataType* d_ptr = reinterpret_cast<DDataType*>(kargs.d_ptr) +
                           static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_d +
                           batch_offset_d;

        // O/dO/D DRAM and DRAM window
        const auto o_dram = [&]() {
            auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                o_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_v),
                make_tuple(kargs.stride_o, 1),
                number<FmhaBwdOGradDotO::kAlignmentO>{},
                number<1>{});
            return pad_tensor_view(o_dram_naive,
                                   make_tuple(number<kM0>{}, number<kVHeaddim>{}),
                                   sequence<kPadSeqLenQ, kPadHeadDimV>{});
        }();
        const auto do_dram = [&]() {
            auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                do_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_v),
                make_tuple(kargs.stride_do, 1),
                number<FmhaBwdOGradDotO::kAlignmentOGrad>{},
                number<1>{});
            return pad_tensor_view(do_dram_naive,
                                   make_tuple(number<kM0>{}, number<kVHeaddim>{}),
                                   sequence<kPadSeqLenQ, kPadHeadDimV>{});
        }();
        auto d_dram = [&]() {
            const auto d_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
                d_ptr, make_tuple(kargs.seqlen_q), number<1>{});
            return pad_tensor_view(
                d_dram_naive, make_tuple(number<kM0>{}), sequence<kPadSeqLenQ>{});
        }();

        auto o_dram_window =
            make_tile_window(o_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
        auto do_dram_window =
            make_tile_window(do_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
        auto d_dram_window = make_tile_window(d_dram, make_tuple(number<kM0>{}), {i_m0});

        FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop);
    }
};
} // namespace ck_tile
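The three host-side hooks on each kernel (GridSize, BlockSize, GetSmemSize) are all a caller
needs to launch it. A minimal sketch of that wiring is below; the `fmha_bwd_entry` wrapper and
`launch_fmha_bwd` helper are illustrative assumptions of mine, not the library's API -- ck_tile
ships its own host-side launch/stream helpers (see include/ck_tile/host), which real callers
should prefer.

// Hypothetical launch wrapper around a ck_tile kernel functor (illustrative sketch only).
#include <hip/hip_runtime.h>

template <typename Kernel>
__global__ void fmha_bwd_entry(typename Kernel::Kargs kargs)
{
    Kernel{}(kargs); // the functor decodes blockIdx via its TilePartitioner
}

template <typename Kernel>
void launch_fmha_bwd(const typename Kernel::Kargs& kargs,
                     int batch, int nhead, int seqlen, hipStream_t stream)
{
    const dim3 grid  = Kernel::GridSize(batch, nhead, seqlen);
    const dim3 block = Kernel::BlockSize();
    // no dynamic LDS needed: these kernels declare static __shared__ storage
    // sized by GetSmemSize() inside operator()
    fmha_bwd_entry<Kernel><<<grid, block, 0, stream>>>(kargs);
}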
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename BlockFmhaShape_>
struct FmhaBwdTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;

    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }
};
template <ck_tile::index_t kBlockSize>
struct FmhaBwdOGradDotOTilePartitioner
{
    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }
};

} // namespace ck_tile
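Both partitioners map the HIP grid one-to-one onto (sequence tile, head, batch). A quick
worked example of the mapping, under an assumed 128-wide kN0 tile (all numbers illustrative):

// seqlen_k = 1000, kN0 = 128  ->  integer_divide_ceil(1000, 128) = 8 tiles along x
// nhead = 16, batch = 4       ->  grid = dim3(8, 16, 4), i.e. 512 workgroups
// inside the kernel, blockIdx = (5, 3, 1) decodes to
//   i_block = 5 (K rows [640, 768)), i_nhead = 3, i_batch = 1

The last tile only covers rows [896, 1000); the pad_tensor_view/masking logic in the kernels
above is what makes such partial tiles safe to process.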
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"

 #include <string>
 #include <type_traits>

-// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] * K[seqlen_k, hdim_q]
+// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
-// P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k])
-// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k]
+// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
+// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]

 namespace ck_tile {
...
...
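The comment block above pins down the exact dataflow the kernel implements. For reference, a
scalar version of those five equations (shapes follow the comment; this is an illustrative
host-side sketch of my own, not the library's reference implementation) looks like:

// naive S = Q@K^T, S' = S*scale, S'' = S'+bias, P = softmax(S''), O = P@V
#include <cmath>
#include <vector>

void attention_ref(const float* q, const float* k, const float* v, const float* bias,
                   float* o, int seqlen_q, int seqlen_k, int hdim_q, int hdim_v, float scale)
{
    std::vector<float> s(seqlen_k);
    for(int i = 0; i < seqlen_q; ++i)
    {
        float row_max = -INFINITY;
        for(int j = 0; j < seqlen_k; ++j)
        {
            float acc = 0.f;
            for(int d = 0; d < hdim_q; ++d)
                acc += q[i * hdim_q + d] * k[j * hdim_q + d]; // S
            s[j]    = acc * scale + bias[i * seqlen_k + j];   // S', S''
            row_max = std::fmax(row_max, s[j]);
        }
        float sum = 0.f;
        for(int j = 0; j < seqlen_k; ++j)
            sum += (s[j] = std::exp(s[j] - row_max)); // numerically-stable softmax
        for(int d = 0; d < hdim_v; ++d)
        {
            float acc = 0.f;
            for(int j = 0; j < seqlen_k; ++j)
                acc += (s[j] / sum) * v[j * hdim_v + d]; // O, with V stored [seqlen_k, hdim_v]
            o[i * hdim_v + d] = acc;
        }
    }
}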
@@ -31,8 +32,11 @@ struct FmhaFwdKernel
     using KDataType             = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
     using VDataType             = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
     using BiasDataType          = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
+    using RandValOutputDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
     using LSEDataType           = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
     using ODataType             = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
+    using SaccDataType          = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
     using VLayout               = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
...
...
@@ -41,8 +45,9 @@ struct FmhaFwdKernel
     static constexpr bool kPadSeqLenK       = FmhaPipeline::kPadSeqLenK;
     static constexpr bool kPadHeadDimQ      = FmhaPipeline::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV      = FmhaPipeline::kPadHeadDimV;
-    static constexpr bool kHasBias          = FmhaPipeline::kHasBias;
+    static constexpr auto BiasEnum          = FmhaPipeline::BiasEnum;
     static constexpr bool kStoreLSE         = FmhaPipeline::kStoreLSE;
+    static constexpr bool kHasDropout       = FmhaPipeline::kHasDropout;
     static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
     using FmhaMask                          = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
     static constexpr bool kHasMask          = FmhaMask::IsMasking;
...
@@ -74,14 +79,15 @@ struct FmhaFwdKernel
             return n.empty() ? n : std::string("p") + n; }();
         return
             _SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
-            "_" + (kIsGroupMode ? "group" : "batch") + "_" +
+            "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
             "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
             _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
             "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" +
             _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
             "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" +
             _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
             (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
             _SS_(FmhaPipeline::name) + "_" +
             "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") +
             (pn.empty() ? "" : "_" + pn) +
-            (kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") +
-            (kStoreLSE ? "_lse" : "") + (kDoFp8StaticQuant ? "_squant" : "");
+            (BiasEnum == BlockAttentionBiasEnum::NO_BIAS
+                 ? _SS_("")
+                 : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
+            (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "") +
+            (kHasDropout ? "_dropout" : "") + (kDoFp8StaticQuant ? "_squant" : "");
     #undef _SS_
     #undef _TS_
         // clang-format on
...
...
@@ -108,6 +114,7 @@ struct FmhaFwdKernel
         ck_tile::index_t hdim_q;
         ck_tile::index_t hdim_v;
+        ck_tile::index_t num_head_q;
         // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k;
         // if it is larger than 1, it indicates the MQA/GQA case
         ck_tile::index_t nhead_ratio_qk;
...
...
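As a concrete MQA/GQA example (numbers illustrative, not from the source): with
num_head_q = 32 and nhead_ratio_qk = 4 there are 8 K/V heads, and each group of four query
heads reads the same K/V head via integer division:

// query heads 0..3 -> K/V head 0, query heads 4..7 -> K/V head 1, ...
// const auto i_kv_head = i_nhead / kargs.nhead_ratio_qk;

which is exactly the `i_nhead / kargs.nhead_ratio_qk` factor applied to nhead_stride_k and
nhead_stride_v when the kernels compute their K/V base pointers.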
@@ -136,6 +143,13 @@ struct FmhaFwdKernel
         ck_tile::index_t batch_stride_bias = 0;
     };
+    struct FmhaFwdAlibiKargs
+    {
+        // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
+        const void* alibi_slope_ptr;
+        ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
+    };
     struct FmhaFwdMaskKargs
     {
         // ck_tile::index_t window_size_left, window_size_right;
...
...
@@ -153,19 +167,48 @@ struct FmhaFwdKernel
     {
         void* lse_ptr                     = nullptr;
         ck_tile::index_t nhead_stride_lse = 0;
+        ck_tile::index_t batch_stride_lse = 0;
     };
-    struct FmhaFwdBatchModeLSEKargs : FmhaFwdCommonLSEKargs
+    struct FmhaFwdCommonDropoutKargs
     {
-        ck_tile::index_t batch_stride_lse = 0;
-    };
+        void init_dropout(const float p_drop,
+                          const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+        {
+            float p_undrop = 1.0 - p_drop;
+            p_undrop_in_uint8_t =
+                uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
+            rp_undrop = 1.0 / p_undrop;
+
+            drop_seed   = std::get<0>(drop_seed_offset);
+            drop_offset = std::get<1>(drop_seed_offset);
+        }
+        float rp_undrop             = 1;
+        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
+        bool is_store_randval       = false;
+        uint64_t drop_seed          = 1;
+        uint64_t drop_offset        = 0;
+        void* rand_val_ptr          = nullptr;
+
+        ck_tile::index_t stride_randval       = 0;
+        ck_tile::index_t nhead_stride_randval = 0;
+    };
+    struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
+    {
+        ck_tile::index_t batch_stride_randval = 0;
+    };

     struct FmhaFwdBatchModeKargs
         : FmhaFwdCommonKargs,
-          std::conditional_t<kHasBias, FmhaFwdBatchModeBiasKargs, FmhaFwdEmptyKargs<0>>,
+          std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
+                             FmhaFwdBatchModeBiasKargs,
+                             std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
+                                                FmhaFwdAlibiKargs,
+                                                FmhaFwdEmptyKargs<0>>>,
           std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
-          std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>,
-          std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
+          std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
+          std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
+          std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>
     {
         ck_tile::index_t batch_stride_q;
         ck_tile::index_t batch_stride_k;
...
...
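To make init_dropout() above concrete, a worked example (plain arithmetic, values not from
the source): for p_drop = 0.1,

// p_undrop            = 1.0 - 0.1 = 0.9
// p_undrop_in_uint8_t = uint8_t(std::floor(0.9 * 255)) = 229
// rp_undrop           = 1.0 / 0.9 ≈ 1.1111

The 8-bit threshold is what each generated random byte is compared against to decide whether
an element survives, and rp_undrop is the usual inverted-dropout rescale that keeps the
expected value of the surviving activations unchanged.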
@@ -175,10 +218,15 @@ struct FmhaFwdKernel
     struct FmhaFwdGroupModeKargs
         : FmhaFwdCommonKargs,
-          std::conditional_t<kHasBias, FmhaFwdCommonBiasKargs, FmhaFwdEmptyKargs<0>>,
+          std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
+                             FmhaFwdCommonBiasKargs,
+                             std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
+                                                FmhaFwdAlibiKargs,
+                                                FmhaFwdEmptyKargs<0>>>,
           std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
           std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
-          std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
+          std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
+          std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>
     {
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
...
...
@@ -193,12 +241,14 @@ struct FmhaFwdKernel
               const void* k_ptr,
               const void* v_ptr,
               const void* bias_ptr,
+              void* rand_val_ptr,
               void* lse_ptr,
               void* o_ptr,
               ck_tile::index_t seqlen_q,
               ck_tile::index_t seqlen_k,
               ck_tile::index_t hdim_q,
               ck_tile::index_t hdim_v,
+              ck_tile::index_t num_head_q,
               ck_tile::index_t nhead_ratio_qk,
               float scale_s,
               float scale_p,
...
...
@@ -207,22 +257,28 @@ struct FmhaFwdKernel
               ck_tile::index_t stride_k,
               ck_tile::index_t stride_v,
               ck_tile::index_t stride_bias,
+              ck_tile::index_t stride_randval,
               ck_tile::index_t stride_o,
               ck_tile::index_t nhead_stride_q,
               ck_tile::index_t nhead_stride_k,
               ck_tile::index_t nhead_stride_v,
               ck_tile::index_t nhead_stride_bias,
+              ck_tile::index_t nhead_stride_randval,
               ck_tile::index_t nhead_stride_lse,
               ck_tile::index_t nhead_stride_o,
               ck_tile::index_t batch_stride_q,
               ck_tile::index_t batch_stride_k,
               ck_tile::index_t batch_stride_v,
               ck_tile::index_t batch_stride_bias,
+              ck_tile::index_t batch_stride_randval,
               ck_tile::index_t batch_stride_lse,
               ck_tile::index_t batch_stride_o,
               ck_tile::index_t window_size_left,
               ck_tile::index_t window_size_right,
-              ck_tile::index_t mask_type)
+              ck_tile::index_t mask_type,
+              float p_drop,
+              bool s_randval,
+              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
     {
         Kargs kargs{{q_ptr,
                      k_ptr,
...
...
@@ -232,6 +288,7 @@ struct FmhaFwdKernel
                      seqlen_k,
                      hdim_q,
                      hdim_v,
+                     num_head_q,
                      nhead_ratio_qk,
 #if CK_TILE_FMHA_FWD_FAST_EXP2
                      static_cast<float>(scale_s * ck_tile::log2e_v<>),
...
...
@@ -250,18 +307,24 @@ struct FmhaFwdKernel
                     {}, // placeholder for mask
                     {}, // placeholder for lse
                     {}, // placeholder for fp8_static_quant args
+                    {}, // placeholder for dropout
                     batch_stride_q,
                     batch_stride_k,
                     batch_stride_v,
                     batch_stride_o};

-        if constexpr(kHasBias)
+        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
         {
             kargs.bias_ptr          = bias_ptr;
             kargs.stride_bias       = stride_bias;
             kargs.nhead_stride_bias = nhead_stride_bias;
             kargs.batch_stride_bias = batch_stride_bias;
         }
+        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
+        {
+            kargs.alibi_slope_ptr    = bias_ptr;
+            kargs.alibi_slope_stride = stride_bias;
+        }
         if constexpr(kHasMask)
         {
             kargs.window_size_left = window_size_left;
...
...
@@ -279,6 +342,15 @@ struct FmhaFwdKernel
             kargs.scale_p = scale_p;
             kargs.scale_o = scale_o;
         }
+        if constexpr(kHasDropout)
+        {
+            kargs.init_dropout(p_drop, drop_seed_offset);
+            kargs.rand_val_ptr         = rand_val_ptr;
+            kargs.stride_randval       = stride_randval;
+            kargs.nhead_stride_randval = nhead_stride_randval;
+            kargs.batch_stride_randval = batch_stride_randval;
+            kargs.is_store_randval     = s_randval;
+        }

         return kargs;
     }
...
...
@@ -289,6 +361,7 @@ struct FmhaFwdKernel
               const void* k_ptr,
               const void* v_ptr,
               const void* bias_ptr,
+              void* rand_val_ptr,
               void* lse_ptr,
               void* o_ptr,
               const void* seqstart_q_ptr,
...
...
@@ -296,6 +369,7 @@ struct FmhaFwdKernel
               const void* seqlen_k_ptr,
               ck_tile::index_t hdim_q,
               ck_tile::index_t hdim_v,
+              ck_tile::index_t num_head_q,
               ck_tile::index_t nhead_ratio_qk,
               float scale_s,
               float scale_p,
...
...
@@ -304,16 +378,22 @@ struct FmhaFwdKernel
               ck_tile::index_t stride_k,
               ck_tile::index_t stride_v,
               ck_tile::index_t stride_bias,
+              ck_tile::index_t stride_randval,
               ck_tile::index_t stride_o,
               ck_tile::index_t nhead_stride_q,
               ck_tile::index_t nhead_stride_k,
               ck_tile::index_t nhead_stride_v,
               ck_tile::index_t nhead_stride_bias,
+              ck_tile::index_t nhead_stride_randval,
               ck_tile::index_t nhead_stride_lse,
               ck_tile::index_t nhead_stride_o,
+              ck_tile::index_t batch_stride_lse,
               ck_tile::index_t window_size_left,
               ck_tile::index_t window_size_right,
-              ck_tile::index_t mask_type)
+              ck_tile::index_t mask_type,
+              float p_drop,
+              bool s_randval,
+              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
     {
         Kargs kargs{{q_ptr,
                      k_ptr,
...
...
@@ -323,6 +403,7 @@ struct FmhaFwdKernel
                      -1, //
                      hdim_q,
                      hdim_v,
+                     num_head_q,
                      nhead_ratio_qk,
 #if CK_TILE_FMHA_FWD_FAST_EXP2
                      static_cast<float>(scale_s * ck_tile::log2e_v<>),
...
...
@@ -341,16 +422,22 @@ struct FmhaFwdKernel
                     {}, // placeholder for mask
                     {}, // placeholder for lse
                     {}, // placeholder for fp8_static_quant args
+                    {}, // placeholder for dropout
                     reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                     reinterpret_cast<const int32_t*>(seqstart_k_ptr),
                     reinterpret_cast<const int32_t*>(seqlen_k_ptr)};

-        if constexpr(kHasBias)
+        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
         {
             kargs.bias_ptr          = bias_ptr;
             kargs.stride_bias       = stride_bias;
             kargs.nhead_stride_bias = nhead_stride_bias;
         }
+        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
+        {
+            kargs.alibi_slope_ptr    = bias_ptr;
+            kargs.alibi_slope_stride = stride_bias;
+        }
         if constexpr(kHasMask)
         {
             kargs.window_size_left = window_size_left;
...
...
@@ -361,12 +448,21 @@ struct FmhaFwdKernel
         {
             kargs.lse_ptr          = lse_ptr;
             kargs.nhead_stride_lse = nhead_stride_lse;
+            kargs.batch_stride_lse = batch_stride_lse;
         }
         if constexpr(kDoFp8StaticQuant)
         {
             kargs.scale_p = scale_p;
             kargs.scale_o = scale_o;
         }
+        if constexpr(kHasDropout)
+        {
+            kargs.init_dropout(p_drop, drop_seed_offset);
+            kargs.rand_val_ptr         = rand_val_ptr;
+            kargs.stride_randval       = stride_randval;
+            kargs.nhead_stride_randval = nhead_stride_randval;
+            kargs.is_store_randval     = s_randval;
+        }

         return kargs;
     }
...
...
@@ -398,12 +494,13 @@ struct FmhaFwdKernel
         const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
         const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);

-        long_index_t batch_offset_q    = 0;
-        long_index_t batch_offset_k    = 0;
-        long_index_t batch_offset_v    = 0;
-        long_index_t batch_offset_bias = 0;
-        long_index_t batch_offset_lse  = 0;
-        long_index_t batch_offset_o    = 0;
+        long_index_t batch_offset_q       = 0;
+        long_index_t batch_offset_k       = 0;
+        long_index_t batch_offset_v       = 0;
+        long_index_t batch_offset_bias    = 0;
+        long_index_t batch_offset_randval = 0;
+        long_index_t batch_offset_lse     = 0;
+        long_index_t batch_offset_o       = 0;

         if constexpr(kIsGroupMode)
         {
...
...
@@ -421,17 +518,17 @@ struct FmhaFwdKernel
             {
                 batch_offset_v = key_start;
             }
-            if constexpr(kHasBias)
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 batch_offset_bias = query_start * kargs.stride_bias + key_start;
             }
-            else
+            if constexpr(kStoreLSE)
             {
-                batch_offset_bias = key_start;
+                batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
             }
-            if constexpr(kStoreLSE)
+            if constexpr(kHasDropout)
             {
-                batch_offset_lse = query_start;
+                batch_offset_randval = query_start * kargs.stride_randval;
             }
             batch_offset_o = query_start * kargs.stride_o;
...
...
@@ -461,7 +558,7 @@ struct FmhaFwdKernel
             batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
             batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
             batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
-            if constexpr(kHasBias)
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
             }
...
...
@@ -469,6 +566,11 @@ struct FmhaFwdKernel
             {
                 batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
             }
+            if constexpr(kHasDropout)
+            {
+                batch_offset_randval =
+                    static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
+            }
             batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
         }
...
...
@@ -585,7 +687,7 @@ struct FmhaFwdKernel
         const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
             constexpr auto bias_dram_window_lengths =
                 make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
-            if constexpr(kHasBias)
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 const BiasDataType* bias_ptr =
                     reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
...
...
@@ -642,6 +744,56 @@ struct FmhaFwdKernel
             }
         }();

+        auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
+            if constexpr(kHasDropout)
+            {
+                return BlockDropout{i_batch_,
+                                    i_nhead_,
+                                    kargs.num_head_q,
+                                    kargs.drop_seed,
+                                    kargs.drop_offset,
+                                    kargs.rp_undrop,
+                                    kargs.p_undrop_in_uint8_t,
+                                    kargs.is_store_randval};
+            }
+            else
+            {
+                return NullBlockDropout{};
+            };
+        }();
+
+        auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
+            constexpr auto randval_dram_window_lengths =
+                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
+            if constexpr(kHasDropout)
+            {
+                RandValOutputDataType* rand_val_ptr =
+                    reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
+                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
+                    batch_offset_randval;
+                const auto randval_dram = [&]() {
+                    const auto randval_dram_naive =
+                        make_naive_tensor_view<address_space_enum::global>(
+                            rand_val_ptr,
+                            make_tuple(kargs.seqlen_q, kargs.seqlen_k),
+                            make_tuple(kargs.stride_randval, 1),
+                            number<1>{},
+                            number<1>{});
+                    return pad_tensor_view(randval_dram_naive,
+                                           randval_dram_window_lengths,
+                                           sequence<kPadSeqLenQ, kPadSeqLenK>{});
+                }();
+                return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
+            }
+            else
+            {
+                return make_null_tile_window(randval_dram_window_lengths);
+            }
+        }();
+
         FmhaMask mask = [&]() {
             if constexpr(kHasMask)
                 return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
...
...
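The next hunk wires ALIBI in as a position encoding: one slope per (batch, head) is read
once, shared by the whole workgroup, and folded into the attention scores. As background
(this is the standard ALIBI linear bias from the literature, not copied from this file),
the encoding amounts to something like:

// score(i, j) += -slope * distance(i, j), with the bottom-right alignment roughly
// float alibi_bias = -slope * std::abs(((seqlen_k - seqlen_q) + i) - j);

which is why only a pointer plus a per-batch stride (alibi_slope_stride) is needed instead
of a full [seqlen_q, seqlen_k] bias tensor, and why the slope is pre-multiplied by log2e
when the fast-exp2 path is enabled.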
@@ -654,6 +806,39 @@ struct FmhaFwdKernel
             return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
         }();

+        // WA: capturing the structured bindings i_batch/i_nhead is not allowed before C++20
+        auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
+            {
+                // data loading, shared by entire wg
+                // TODO: how to use s_read?
+                SaccDataType slope =
+                    *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
+                      i_batch_ * kargs.alibi_slope_stride + i_nhead_);
+#if CK_TILE_FMHA_FWD_FAST_EXP2
+                slope *= ck_tile::log2e_v<>;
+#endif
+                if constexpr(kHasMask)
+                {
+                    return make_alibi_from_lr_mask<SaccDataType, true>(slope,
+                                                                       kargs.window_size_left,
+                                                                       kargs.window_size_right,
+                                                                       kargs.seqlen_q,
+                                                                       kargs.seqlen_k,
+                                                                       kargs.mask_type);
+                }
+                else
+                {
+                    return Alibi<SaccDataType, true>{
+                        slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
+                }
+            }
+            else
+            {
+                return EmptyPositionEncoding<SaccDataType>{};
+            }
+        }();
+
         auto o_acc_tile = [&]() {
             if constexpr(kDoFp8StaticQuant)
             {
...
...
@@ -666,14 +851,17 @@ struct FmhaFwdKernel
                     identity{}, // v_element_func
                     bias_dram_window,
                     identity{}, // bias_element_func
+                    randval_dram_window,
                     lse_dram_window,
                     identity{}, // lse_element_func
                     identity{}, // s_acc_element_func
                     scales{kargs.scale_p}, // p_compute_element_func
                     composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
                     mask,
+                    position_encoding,
                     kargs.scale_s,
-                    smem_ptr);
+                    smem_ptr,
+                    dropout);
             }
             else
             {
...
...
@@ -681,10 +869,13 @@ struct FmhaFwdKernel
                     k_dram_window,
                     v_dram_window,
                     bias_dram_window,
+                    randval_dram_window,
                     lse_dram_window,
                     mask,
+                    position_encoding,
                     kargs.scale_s,
-                    smem_ptr);
+                    smem_ptr,
+                    dropout);
             }
         }();
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck_tile {

template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdSplitKVCombineKernel
{
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using FmhaPipeline     = remove_cvref_t<FmhaPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    static constexpr index_t kBlockSize  = FmhaPipeline::kBlockSize;
    static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
    static_assert(kBlockPerCu > 0);
    static constexpr index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;

    using LSEDataType  = remove_cvref_t<typename FmhaPipeline::LSEDataType>;
    using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
    using ODataType    = remove_cvref_t<typename FmhaPipeline::ODataType>;

    static constexpr bool kIsGroupMode      = FmhaPipeline::kIsGroupMode;
    static constexpr bool kPadSeqLenQ       = FmhaPipeline::kPadSeqLenQ;
    static constexpr bool kPadHeadDimV      = FmhaPipeline::kPadHeadDimV;
    static constexpr bool kStoreLSE         = FmhaPipeline::kStoreLSE;
    static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
    // clang-format on

    __host__ static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        #define _SS_  std::string
        #define _TS_  std::to_string
        auto pn = [&] () {
            std::string n;
            if (kPadSeqLenQ) n += "s";
            if (kPadHeadDimV) n += "dv";
            return n.empty() ? n : std::string("p") + n; }();
        return
            _SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" +
            _SS_(t2s<ODataType>::name) + "_" + (kIsGroupMode ? "group" : "batch") + "_"
            "b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN1) + "_" +
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
            _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) +
            (kStoreLSE ? "_lse" : "") + (kDoFp8StaticQuant ? "_squant" : "");
        #undef _SS_
        #undef _TS_
        // clang-format on
    }
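    // Example (illustrative): kHeadDimV = 128, fp16 output, batch mode, kM0 = 64, kN1 = 128,
    // kBlockPerCuInput == -1, no padding, and LSE storing enabled would yield
    // "fmha_fwd_splitkv_combine_d128_fp16_batch_b64x128_" + FmhaPipeline::name + "_lse".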
    // to avoid the duplicated-base-class problem, introduce a template arg
    template <ck_tile::index_t I>
    struct EmptyKargs
    {
    };

    // kargs use aggregate initializer, so no constructor will be provided
    // use inheritance to minimize karg size
    // users need to use the MakeKargs() function to create kargs.
    struct CommonKargs
    {
        const void* lse_acc_ptr;
        const void* o_acc_ptr;
        void* o_ptr;

        ck_tile::index_t batch;
        ck_tile::index_t max_seqlen_q;

        ck_tile::index_t seqlen_q;
        ck_tile::index_t hdim_v;

        ck_tile::index_t num_splits;

        ck_tile::index_t row_stride_o_acc;
        ck_tile::index_t row_stride_o;

        ck_tile::index_t nhead_stride_lse_acc;
        ck_tile::index_t nhead_stride_o_acc;
        ck_tile::index_t nhead_stride_o;

        ck_tile::index_t batch_stride_lse_acc;
        ck_tile::index_t batch_stride_o_acc;

        ck_tile::index_t split_stride_lse_acc;
        ck_tile::index_t split_stride_o_acc;
    };

    struct CommonLSEKargs
    {
        void* lse_ptr                     = nullptr;
        ck_tile::index_t nhead_stride_lse = 0;
        ck_tile::index_t batch_stride_lse = 0;
    };

    struct Fp8StaticQuantKargs
    {
        float scale_o;
    };

    struct BatchModeKargs
        : CommonKargs,
          std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
          std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
    {
        ck_tile::index_t batch_stride_o;
    };

    struct GroupModeKargs
        : CommonKargs,
          std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
          std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<3>>
    {
        const int32_t* seqstart_q_ptr;
    };

    using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;

    template <bool Cond = !kIsGroupMode>
    __host__ static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* lse_acc_ptr,
              const void* o_acc_ptr,
              void* lse_ptr,
              void* o_ptr,
              ck_tile::index_t batch,
              ck_tile::index_t max_seqlen_q,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_splits,
              float scale_o,
              ck_tile::index_t row_stride_o_acc,
              ck_tile::index_t row_stride_o,
              ck_tile::index_t nhead_stride_lse_acc,
              ck_tile::index_t nhead_stride_o_acc,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_lse_acc,
              ck_tile::index_t batch_stride_o_acc,
              ck_tile::index_t batch_stride_lse,
              ck_tile::index_t batch_stride_o,
              ck_tile::index_t split_stride_lse_acc,
              ck_tile::index_t split_stride_o_acc)
    {
        Kargs kargs{{lse_acc_ptr,
                     o_acc_ptr,
                     o_ptr,
                     batch,
                     max_seqlen_q,
                     seqlen_q,
                     hdim_v,
                     num_splits,
                     row_stride_o_acc,
                     row_stride_o,
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
                     nhead_stride_o,
                     batch_stride_lse_acc,
                     batch_stride_o_acc,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
                    {},                   // placeholder for lse
                    {},                   // placeholder for fp8_static_quant args
                    batch_stride_o};

        if constexpr(kStoreLSE)
        {
            kargs.lse_ptr          = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
            kargs.batch_stride_lse = batch_stride_lse;
        }
        if constexpr(kDoFp8StaticQuant)
        {
            kargs.scale_o = scale_o;
        }

        return kargs;
    }

    template <bool Cond = kIsGroupMode>
    __host__ static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* lse_acc_ptr,
              const void* o_acc_ptr,
              void* lse_ptr,
              void* o_ptr,
              ck_tile::index_t batch,
              ck_tile::index_t max_seqlen_q,
              const void* seqstart_q_ptr,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_splits,
              float scale_o,
              ck_tile::index_t row_stride_o_acc,
              ck_tile::index_t row_stride_o,
              ck_tile::index_t nhead_stride_lse_acc,
              ck_tile::index_t nhead_stride_o_acc,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_lse_acc,
              ck_tile::index_t batch_stride_o_acc,
              ck_tile::index_t batch_stride_lse,
              ck_tile::index_t split_stride_lse_acc,
              ck_tile::index_t split_stride_o_acc)
    {
        Kargs kargs{{lse_acc_ptr,
                     o_acc_ptr,
                     o_ptr,
                     batch,
                     max_seqlen_q,
                     -1, // seqlen will be updated by another pointer
                     hdim_v,
                     num_splits,
                     row_stride_o_acc,
                     row_stride_o,
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
                     nhead_stride_o,
                     batch_stride_lse_acc,
                     batch_stride_o_acc,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
                    {},                   // placeholder for lse
                    {},                   // placeholder for fp8_static_quant args
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr)};

        if constexpr(kStoreLSE)
        {
            kargs.lse_ptr          = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
            kargs.batch_stride_lse = batch_stride_lse;
        }
        if constexpr(kDoFp8StaticQuant)
        {
            kargs.scale_o = scale_o;
        }

        return kargs;
    }
    __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                            ck_tile::index_t nhead_,
                                            ck_tile::index_t seqlen_q_,
                                            ck_tile::index_t hdim_v_)
    {
        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
    }

    __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
// divide problem
const
auto
[
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
]
=
TilePartitioner
{}(
kargs
.
seqlen_q
,
kargs
.
hdim_v
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
FmhaPipeline
::
kM0
);
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
const
long_index_t
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
const
long_index_t
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
long_index_t
batch_offset_lse
=
0
;
long_index_t
batch_offset_o
=
0
;
if
constexpr
(
kStoreLSE
)
{
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
}
if
constexpr
(
kIsGroupMode
)
{
// get starting offset for each batch
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
batch_offset_o
=
query_start
*
kargs
.
row_stride_o
;
// get real # queries & # keys under group mode
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
kargs
.
seqlen_q
=
adjusted_seqstart_q_ptr
[
1
]
-
adjusted_seqstart_q_ptr
[
0
];
            // # of required blocks is different in each group; terminate unnecessary blocks
            // earlier
            if(kargs.seqlen_q <= i_m0)
            {
                return;
            }
        }
        else
        {
            batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
        }

        // for simplicity, we fold the batch stride into the pointers
        const LSEDataType* lse_acc_ptr =
            reinterpret_cast<const LSEDataType*>(kargs.lse_acc_ptr) +
            static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lse_acc +
            batch_offset_lse_acc;
        const OaccDataType* o_acc_ptr =
            reinterpret_cast<const OaccDataType*>(kargs.o_acc_ptr) +
            static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
            batch_offset_o_acc;
        ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
                           static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
                           batch_offset_o;

        // LSEacc/Oacc DRAM and DRAM windows
        const auto lse_acc_dram = [&]() {
            const auto lse_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                lse_acc_ptr,
                make_tuple(kargs.num_splits, kargs.seqlen_q),
                make_tuple(kargs.split_stride_lse_acc, 1),
                number<FmhaPipeline::kAlignmentLSEacc>{},
                number<1>{});
            return pad_tensor_view(
                lse_acc_dram_naive,
                make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{}),
                sequence<true, kPadSeqLenQ>{});
        }();

        auto o_acc_dram = [&]() {
            const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                o_acc_ptr,
                make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v),
                make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
                number<FmhaPipeline::kAlignmentOacc>{},
                number<1>{});
            auto o_acc_dram_view = pad_tensor_view(
                o_acc_dram_naive,
                make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
                sequence<false, kPadSeqLenQ, kPadHeadDimV>{});

            const index_t padded_max_seqlen_q =
                o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
            const index_t padded_hdim_v =
                o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];

            return transform_tensor_view(
                o_acc_dram_view,
                make_tuple(
                    make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)),
                    make_pass_through_transform(padded_hdim_v)),
                make_tuple(sequence<0, 1>{}, sequence<2>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));
        }();
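        // note on the merged view above: a row index r in the combined
        // (num_splits * padded_max_seqlen_q, padded_hdim_v) view should correspond to
        // split r / padded_max_seqlen_q and row r % padded_max_seqlen_q, assuming the
        // merge keeps num_splits as the slower-varying dimension (annotation, not from
        // the original source)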
        auto lse_acc_dram_window = make_tile_window(
            lse_acc_dram,
            [&]() {
                return make_tuple(number<FmhaPipeline::kMaxSplits>{},
                                  number<FmhaPipeline::kM0>{});
            }(),
            {0, i_m0});

        auto o_acc_dram_window = make_tile_window(
            o_acc_dram,
            [&]() {
                return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{});
            }(),
            {i_m0, i_n1});

        // LSE DRAM window
        auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
            if constexpr(kStoreLSE)
            {
                LSEDataType* lse_ptr =
                    reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
                    batch_offset_lse;

                const auto lse_dram = [&]() {
                    const auto lse_dram_naive =
                        make_naive_tensor_view<address_space_enum::global>(
                            lse_ptr,
                            make_tuple(kargs.seqlen_q),
                            make_tuple(1),
                            number<FmhaPipeline::kAlignmentLSE>{},
                            number<1>{});
                    return pad_tensor_view(
                        lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
                }();

                return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
            }
            else
            {
                return make_null_tile_window(lse_dram_window_lengths);
            }
        }();

        auto o_acc_tile = [&]() {
            if constexpr(kDoFp8StaticQuant)
            {
                return FmhaPipeline{}(
                    lse_acc_dram_window,
                    o_acc_dram_window,
                    lse_dram_window,
                    identity{},                                          // lse_element_func
                    composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
                    kargs.num_splits,
                    kargs.max_seqlen_q,
                    smem_ptr);
            }
            else
            {
                return FmhaPipeline{}(lse_acc_dram_window,
                                      o_acc_dram_window,
                                      lse_dram_window,
                                      kargs.num_splits,
                                      kargs.max_seqlen_q,
                                      smem_ptr);
            }
        }();

        // O DRAM and DRAM window
        auto o_dram = [&]() {
            const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                o_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_v),
                make_tuple(kargs.row_stride_o, 1),
                number<FmhaPipeline::kAlignmentO>{},
                number<1>{});
            return pad_tensor_view(
                o_dram_naive,
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
                sequence<kPadSeqLenQ, kPadHeadDimV>{});
        }();

        auto o_dram_window = make_tile_window(
            o_dram,
            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
            {i_m0, i_n1});

        EpiloguePipeline{}(o_dram_window, o_acc_tile);
    }
};

} // namespace ck_tile
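For reference, the pipeline invoked above merges the per-split partial results with a log-sum-exp reduction. Below is a minimal scalar sketch of that merge, assuming the conventional split-attention combine; all names are illustrative and none of them come from this commit.

#include <algorithm>
#include <cmath>
#include <vector>

// lse_acc: [num_splits][seqlen_q], o_acc: [num_splits][seqlen_q][hdim_v]
// outputs: lse[seqlen_q], o[seqlen_q][hdim_v]
void reference_splitkv_combine(const std::vector<std::vector<float>>& lse_acc,
                               const std::vector<std::vector<std::vector<float>>>& o_acc,
                               std::vector<float>& lse,
                               std::vector<std::vector<float>>& o)
{
    const int num_splits = static_cast<int>(lse_acc.size());
    const int seqlen_q   = static_cast<int>(lse_acc[0].size());
    const int hdim_v     = static_cast<int>(o_acc[0][0].size());

    for(int m = 0; m < seqlen_q; ++m)
    {
        // numerically stable log-sum-exp across the split dimension
        float lse_max = -INFINITY;
        for(int s = 0; s < num_splits; ++s)
            lse_max = std::max(lse_max, lse_acc[s][m]);
        float sum = 0.f;
        for(int s = 0; s < num_splits; ++s)
            sum += std::exp(lse_acc[s][m] - lse_max);
        lse[m] = lse_max + std::log(sum);

        // each partial O row is rescaled by exp(lse_s - lse) and accumulated
        for(int d = 0; d < hdim_v; ++d)
        {
            float acc = 0.f;
            for(int s = 0; s < num_splits; ++s)
                acc += std::exp(lse_acc[s][m] - lse[m]) * o_acc[s][m][d];
            o[m][d] = acc;
        }
    }
}

Since the weights exp(lse_acc[s][m] - lse[m]) sum to one per row, the merged O is an exact softmax over the full key range rather than an approximation.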
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp  0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <index_t kM0_, index_t kN1_>
struct FmhaFwdSplitKVCombineTilePartitioner
{
    static constexpr ck_tile::index_t kM0 = kM0_;
    static constexpr ck_tile::index_t kN1 = kN1_;

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_,
                                                ck_tile::index_t hdim_v_)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
                        ck_tile::integer_divide_ceil(hdim_v_, kN1),
                    nhead_,
                    batch_size_);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
    {
        // const index_t num_tile_m0 = seqlen_q / kM0;
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);

        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        const auto f = [](index_t dividend, index_t divisor) {
            index_t quotient = dividend / divisor;
            index_t modulus  = dividend - quotient * divisor;
            return ck_tile::make_tuple(quotient, modulus);
        };

        const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
    }
};

} // namespace ck_tile
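A quick worked example of the grid mapping above, with made-up tile sizes kM0 = 64 and kN1 = 32: seqlen_q = 100 and hdim_v = 64 give ceil(100/64) * ceil(64/32) = 2 * 2 = 4 blocks along x, and operator() recovers the tile coordinates with the same div/mod shown there. A host-side sketch of that decode (illustrative, not part of the commit):

#include <cstdio>
#include <tuple>

// same div/mod decode as operator() above, on the host for illustration
std::tuple<int, int> decode(int i_block, int num_tile_n1)
{
    const int i_tile_m = i_block / num_tile_n1;
    const int i_tile_n = i_block - i_tile_m * num_tile_n1;
    return {i_tile_m, i_tile_n};
}

int main()
{
    const auto [m, n] = decode(/*i_block=*/3, /*num_tile_n1=*/2);
    std::printf("i_tile_m=%d i_tile_n=%d\n", m, n); // prints: i_tile_m=1 i_tile_n=1
    return 0;
}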
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp  0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"

#include <string>
#include <type_traits>

// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]

namespace ck_tile {

template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdSplitKVKernel
{
    using TilePartitioner  = ck_tile::remove_cvref_t<TilePartitioner_>;
    using FmhaPipeline     = ck_tile::remove_cvref_t<FmhaPipeline_>;
    using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;

    static constexpr ck_tile::index_t kBlockSize  = FmhaPipeline::kBlockSize;
    static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
    static_assert(kBlockPerCu > 0);
    static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;

    using QDataType    = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
    using KDataType    = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
    using VDataType    = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
    using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
    using RandValOutputDataType =
        ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
    using LSEDataType  = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
    using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
    using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
    using VLayout      = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;

    static constexpr bool kIsGroupMode      = FmhaPipeline::kIsGroupMode;
    static constexpr bool kPadSeqLenQ       = FmhaPipeline::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK       = FmhaPipeline::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ      = FmhaPipeline::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV      = FmhaPipeline::kPadHeadDimV;
    static constexpr auto BiasEnum          = FmhaPipeline::BiasEnum;
    static constexpr bool kHasDropout       = FmhaPipeline::kHasDropout;
    static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;

    using FmhaMask                 = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
    static constexpr bool kHasMask = FmhaMask::IsMasking;

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float>           { static constexpr const char * name = "fp32"; };
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<ck_tile::fp8_t>  { static constexpr const char * name = "fp8"; };
    template <> struct t2s<ck_tile::bf8_t>  { static constexpr const char * name = "bf8"; };
    // clang-format on

    __host__ static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        using bfs = typename FmhaPipeline::BlockFmhaShape;
        using gbr = typename bfs::Gemm0BlockWarps;
        using gwt = typename bfs::Gemm0WarpTile;
        #define _SS_ std::string
        #define _TS_ std::to_string
        auto pn = [&] () {
            std::string n;
            if (kPadSeqLenQ) n += "s";
            if (kPadSeqLenK) n += "sk";
            if (kPadHeadDimQ) n += "d";
            if (kPadHeadDimV) n += "dv";
            return n.empty() ? n : std::string("p") + n; }();
        return
            _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) + "_" +
            (kIsGroupMode ? "group" : "batch") + "_"
            "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
                  _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
            "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
            "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
            _SS_(FmhaPipeline::name) + "_" +
            "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") +
            (pn.empty() ? "" : "_" + pn) +
            (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
            (kHasMask ? "_" + _SS_(FmhaMask::name) : "") +
            (kHasDropout ? "_dropout" : "") +
            (kDoFp8StaticQuant ? "_squant" : "");
        #undef _SS_
        #undef _TS_
        // clang-format on
    }

    template <ck_tile::index_t I> // to avoid the duplicated-base-class problem, introduce a
                                  // template arg
    struct EmptyKargs
    {
    };

    // kargs use aggregate initializers, so no constructor will be provided;
    // use inheritance to minimize karg size
    // user needs to use the MakeKargs() function to create kargs.
    struct CommonKargs
    {
        const void* q_ptr;
        const void* k_ptr;
        const void* v_ptr;
        void* lse_acc_ptr;
        void* o_acc_ptr;

        ck_tile::index_t batch;
        ck_tile::index_t max_seqlen_q;

        ck_tile::index_t seqlen_q;
        ck_tile::index_t seqlen_k;
        ck_tile::index_t hdim_q;
        ck_tile::index_t hdim_v;

        ck_tile::index_t num_head_q;
        // for MQA/GQA, nhead could differ. This parameter is nhead_q / nhead_k;
        // if it is larger than 1, it indicates the MQA/GQA case
        ck_tile::index_t nhead_ratio_qk;
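        // e.g. num_head_q = 8 with 2 KV heads gives nhead_ratio_qk = 4 (illustrative
        // numbers): query heads 0-3 read KV head 0 and query heads 4-7 read KV head 1,
        // via the i_nhead / nhead_ratio_qk pointer arithmetic in operator()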
        ck_tile::index_t num_splits;
        float scale_s;

        ck_tile::index_t stride_q;
        ck_tile::index_t stride_k;
        ck_tile::index_t stride_v;
        ck_tile::index_t stride_o_acc;

        ck_tile::index_t nhead_stride_q;
        ck_tile::index_t nhead_stride_k;
        ck_tile::index_t nhead_stride_v;
        ck_tile::index_t nhead_stride_lse_acc;
        ck_tile::index_t nhead_stride_o_acc;

        ck_tile::index_t batch_stride_lse_acc;
        ck_tile::index_t batch_stride_o_acc;

        ck_tile::index_t split_stride_lse_acc;
        ck_tile::index_t split_stride_o_acc;
    };

    struct CommonBiasKargs
    {
        const void* bias_ptr               = nullptr;
        ck_tile::index_t stride_bias       = 0;
        ck_tile::index_t nhead_stride_bias = 0;
    };

    struct BatchModeBiasKargs : CommonBiasKargs
    {
        ck_tile::index_t batch_stride_bias = 0;
    };

    struct AlibiKargs
    {
        // alibi is batch*nhead*1; it is laid out the same way in batch and group mode
        const void* alibi_slope_ptr;
        ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 if all batches share
                                             // the same slope
    };

    struct MaskKargs
    {
        // ck_tile::index_t window_size_left, window_size_right;
        ck_tile::index_t window_size_left, window_size_right;
        ck_tile::GenericAttentionMaskEnum mask_type;
    };

    struct Fp8StaticQuantKargs
    {
        float scale_p;
    };

    struct CommonDropoutKargs
    {
        void init_dropout(const float p_drop,
                          const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
        {
            float p_undrop = 1.0 - p_drop;
            p_undrop_in_uint8_t =
                uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
            rp_undrop = 1.0 / p_undrop;

            drop_seed   = std::get<0>(drop_seed_offset);
            drop_offset = std::get<1>(drop_seed_offset);
        }
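        // e.g. p_drop = 0.2 (illustrative): p_undrop = 0.8, rp_undrop = 1 / 0.8 = 1.25,
        // and p_undrop_in_uint8_t = floor(0.8 * 255) = 204, the keep probability
        // quantized onto the 8-bit scale of the generated random values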
        float rp_undrop             = 1;
        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
        bool is_store_randval       = false;
        uint64_t drop_seed          = 1;
        uint64_t drop_offset        = 0;
        void* rand_val_ptr          = nullptr;

        ck_tile::index_t stride_randval       = 0;
        ck_tile::index_t nhead_stride_randval = 0;
    };

    struct BatchModeDropoutKargs : CommonDropoutKargs
    {
        ck_tile::index_t batch_stride_randval = 0;
    };

    struct BatchModeKargs
        : CommonKargs,
          std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
                             BatchModeBiasKargs,
                             std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
                                                AlibiKargs,
                                                EmptyKargs<0>>>,
          std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
          std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
          std::conditional_t<kHasDropout, BatchModeDropoutKargs, EmptyKargs<3>>
    {
        ck_tile::index_t batch_stride_q;
        ck_tile::index_t batch_stride_k;
        ck_tile::index_t batch_stride_v;
    };

    struct GroupModeKargs
        : CommonKargs,
          std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
                             CommonBiasKargs,
                             std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
                                                AlibiKargs,
                                                EmptyKargs<0>>>,
          std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
          std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
          std::conditional_t<kHasDropout, CommonDropoutKargs, EmptyKargs<3>>
    {
        const int32_t* seqstart_q_ptr;
        const int32_t* seqstart_k_ptr;
        const int32_t* seqlen_k_ptr;
    };

    using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
    template <bool Cond = !kIsGroupMode>
    __host__ static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              void* rand_val_ptr,
              void* lse_acc_ptr,
              void* o_acc_ptr,
              ck_tile::index_t batch,
              ck_tile::index_t max_seqlen_q,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t seqlen_k,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              ck_tile::index_t num_splits,
              float scale_s,
              float scale_p,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_randval,
              ck_tile::index_t stride_o_acc,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_randval,
              ck_tile::index_t nhead_stride_lse_acc,
              ck_tile::index_t nhead_stride_o_acc,
              ck_tile::index_t batch_stride_q,
              ck_tile::index_t batch_stride_k,
              ck_tile::index_t batch_stride_v,
              ck_tile::index_t batch_stride_bias,
              ck_tile::index_t batch_stride_randval,
              ck_tile::index_t batch_stride_lse_acc,
              ck_tile::index_t batch_stride_o_acc,
              ck_tile::index_t split_stride_lse_acc,
              ck_tile::index_t split_stride_o_acc,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
              bool s_randval,
              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
                     v_ptr,
                     lse_acc_ptr,
                     o_acc_ptr,
                     batch,
                     max_seqlen_q,
                     seqlen_q,
                     seqlen_k,
                     hdim_q,
                     hdim_v,
                     num_head_q,
                     nhead_ratio_qk,
                     num_splits,
#if CK_TILE_FMHA_FWD_FAST_EXP2
                     static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
                     scale_s,
#endif
                     stride_q,
                     stride_k,
                     stride_v,
                     stride_o_acc,
                     nhead_stride_q,
                     nhead_stride_k,
                     nhead_stride_v,
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
                     batch_stride_lse_acc,
                     batch_stride_o_acc,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
                    {},                   // placeholder for bias
                    {},                   // placeholder for mask
                    {},                   // placeholder for fp8_static_quant args
                    {},                   // placeholder for dropout
                    batch_stride_q,
                    batch_stride_k,
                    batch_stride_v};

        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            kargs.bias_ptr          = bias_ptr;
            kargs.stride_bias       = stride_bias;
            kargs.nhead_stride_bias = nhead_stride_bias;
            kargs.batch_stride_bias = batch_stride_bias;
        }
        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
        {
            kargs.alibi_slope_ptr    = bias_ptr;
            kargs.alibi_slope_stride = stride_bias;
        }
        if constexpr(kHasMask)
        {
            kargs.window_size_left  = window_size_left;
            kargs.window_size_right = window_size_right;
            kargs.mask_type         = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
        }
        if constexpr(kDoFp8StaticQuant)
        {
            kargs.scale_p = scale_p;
        }
        if constexpr(kHasDropout)
        {
            kargs.init_dropout(p_drop, drop_seed_offset);
            kargs.rand_val_ptr         = rand_val_ptr;
            kargs.stride_randval       = stride_randval;
            kargs.nhead_stride_randval = nhead_stride_randval;
            kargs.batch_stride_randval = batch_stride_randval;
            kargs.is_store_randval     = s_randval;
        }

        return kargs;
    }
    template <bool Cond = kIsGroupMode>
    __host__ static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              void* rand_val_ptr,
              void* lse_acc_ptr,
              void* o_acc_ptr,
              ck_tile::index_t batch,
              ck_tile::index_t max_seqlen_q,
              const void* seqstart_q_ptr,
              const void* seqstart_k_ptr,
              const void* seqlen_k_ptr,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              ck_tile::index_t num_splits,
              float scale_s,
              float scale_p,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_randval,
              ck_tile::index_t stride_o_acc,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_randval,
              ck_tile::index_t nhead_stride_lse_acc,
              ck_tile::index_t nhead_stride_o_acc,
              ck_tile::index_t batch_stride_lse_acc,
              ck_tile::index_t batch_stride_o_acc,
              ck_tile::index_t split_stride_lse_acc,
              ck_tile::index_t split_stride_o_acc,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
              bool s_randval,
              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
                     v_ptr,
                     lse_acc_ptr,
                     o_acc_ptr,
                     batch,
                     max_seqlen_q,
                     -1, // seqlen will be updated by another pointer
                     -1, //
                     hdim_q,
                     hdim_v,
                     num_head_q,
                     nhead_ratio_qk,
                     num_splits,
#if CK_TILE_FMHA_FWD_FAST_EXP2
                     static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
                     scale_s,
#endif
                     stride_q,
                     stride_k,
                     stride_v,
                     stride_o_acc,
                     nhead_stride_q,
                     nhead_stride_k,
                     nhead_stride_v,
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
                     batch_stride_lse_acc,
                     batch_stride_o_acc,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
                    {},                   // placeholder for bias
                    {},                   // placeholder for mask
                    {},                   // placeholder for fp8_static_quant args
                    {},                   // placeholder for dropout
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                    reinterpret_cast<const int32_t*>(seqstart_k_ptr),
                    reinterpret_cast<const int32_t*>(seqlen_k_ptr)};

        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            kargs.bias_ptr          = bias_ptr;
            kargs.stride_bias       = stride_bias;
            kargs.nhead_stride_bias = nhead_stride_bias;
        }
        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
        {
            kargs.alibi_slope_ptr    = bias_ptr;
            kargs.alibi_slope_stride = stride_bias;
        }
        if constexpr(kHasMask)
        {
            kargs.window_size_left  = window_size_left;
            kargs.window_size_right = window_size_right;
            kargs.mask_type         = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
        }
        if constexpr(kDoFp8StaticQuant)
        {
            kargs.scale_p = scale_p;
        }
        if constexpr(kHasDropout)
        {
            kargs.init_dropout(p_drop, drop_seed_offset);
            kargs.rand_val_ptr         = rand_val_ptr;
            kargs.stride_randval       = stride_randval;
            kargs.nhead_stride_randval = nhead_stride_randval;
            kargs.is_store_randval     = s_randval;
        }

        return kargs;
    }
    __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
                                            ck_tile::index_t nhead,
                                            ck_tile::index_t seqlen_q,
                                            ck_tile::index_t hdim_v,
                                            ck_tile::index_t num_splits)
    {
        return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, hdim_v, num_splits);
    }

    __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }
    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        // divide problem
        const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] =
            TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v, kargs.num_splits);

        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
        const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);

        long_index_t batch_offset_q       = 0;
        long_index_t batch_offset_k       = 0;
        long_index_t batch_offset_v       = 0;
        long_index_t batch_offset_bias    = 0;
        long_index_t batch_offset_randval = 0;
        const long_index_t batch_offset_lse_acc =
            static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
        const long_index_t batch_offset_o_acc =
            static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;

        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
            const long_index_t key_start   = kargs.seqstart_k_ptr[i_batch];

            batch_offset_q = query_start * kargs.stride_q;
            batch_offset_k = key_start * kargs.stride_k;
            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            {
                batch_offset_v = key_start * kargs.stride_v;
            }
            else
            {
                batch_offset_v = key_start;
            }
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                batch_offset_bias = query_start * kargs.stride_bias + key_start;
            }
            if constexpr(kHasDropout)
            {
                batch_offset_randval = query_start * kargs.stride_randval;
            }

            // get real # queries & # keys under group mode
            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];

            // # of required blocks is different in each group; terminate unnecessary blocks
            // earlier
            if(kargs.seqlen_q <= i_m0)
            {
                return;
            }

            if(kargs.seqlen_k_ptr != nullptr)
            {
                kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
            }
            else
            {
                const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
                kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
            }
        }
        else
        {
            batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
            batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
            batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
            }
            if constexpr(kHasDropout)
            {
                batch_offset_randval =
                    static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
            }
        }
        // for simplicity, we fold the batch stride into the pointers
        const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
                                 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
                                 batch_offset_q;
        const KDataType* k_ptr =
            reinterpret_cast<const KDataType*>(kargs.k_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
            batch_offset_k;
        const VDataType* v_ptr =
            reinterpret_cast<const VDataType*>(kargs.v_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
            batch_offset_v;
        OaccDataType* o_acc_ptr = reinterpret_cast<OaccDataType*>(kargs.o_acc_ptr) +
                                  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
                                  batch_offset_o_acc + i_split * kargs.split_stride_o_acc;

        // Q/K/V DRAM and DRAM window
        const auto q_dram = [&]() {
            const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                q_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_q),
                make_tuple(kargs.stride_q, 1),
                number<FmhaPipeline::kAlignmentQ>{},
                number<1>{});
            if constexpr(FmhaPipeline::kQLoadOnce)
            {
                return pad_tensor_view(
                    q_dram_naive,
                    make_tuple(number<FmhaPipeline::kM0>{},
                               number<FmhaPipeline::kK0BlockLength>{}),
                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
            }
            else
            {
                return pad_tensor_view(
                    q_dram_naive,
                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
            }
        }();
        const auto k_dram = [&]() {
            const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                k_ptr,
                make_tuple(kargs.seqlen_k, kargs.hdim_q),
                make_tuple(kargs.stride_k, 1),
                number<FmhaPipeline::kAlignmentK>{},
                number<1>{});
            return pad_tensor_view(
                k_dram_naive,
                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
                sequence<kPadSeqLenK, kPadHeadDimQ>{});
        }();
        const auto v_dram = [&]() {
            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            {
                const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    v_ptr,
                    make_tuple(kargs.seqlen_k, kargs.hdim_v),
                    make_tuple(kargs.stride_v, 1),
                    number<FmhaPipeline::kAlignmentV>{},
                    number<1>{});
                const auto v_dram_transposed = transform_tensor_view(
                    v_dram_naive,
                    make_tuple(make_pass_through_transform(kargs.hdim_v),
                               make_pass_through_transform(kargs.seqlen_k)),
                    make_tuple(sequence<1>{}, sequence<0>{}),
                    make_tuple(sequence<0>{}, sequence<1>{}));
                return pad_tensor_view(
                    v_dram_transposed,
                    make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
                    sequence<kPadHeadDimV, kPadSeqLenK>{});
            }
            else
            {
                const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    v_ptr,
                    make_tuple(kargs.hdim_v, kargs.seqlen_k),
                    make_tuple(kargs.stride_v, 1),
                    number<FmhaPipeline::kAlignmentV>{},
                    number<1>{});
                return pad_tensor_view(
                    v_dram_naive,
                    make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
                    sequence<kPadHeadDimV, kPadSeqLenK>{});
            }
        }();

        auto q_dram_window = make_tile_window(
            q_dram,
            [&]() {
                if constexpr(FmhaPipeline::kQLoadOnce)
                    return make_tuple(number<FmhaPipeline::kM0>{},
                                      number<FmhaPipeline::kK0BlockLength>{});
                else
                    return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
            }(),
            {i_m0, 0});

        auto k_dram_window = make_tile_window(
            k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});

        auto v_dram_window = make_tile_window(
            v_dram,
            make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
            {i_n1, 0});

        /// FIXME: Before C++20, capturing structured binding variables is not supported. Remove
        /// the following copy capture of 'i_nhead' when moving to C++20
        const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto bias_dram_window_lengths =
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                const BiasDataType* bias_ptr =
                    reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
                    batch_offset_bias;

                const auto bias_dram = [&]() {
                    const auto bias_dram_naive =
                        make_naive_tensor_view<address_space_enum::global>(
                            bias_ptr,
                            make_tuple(kargs.seqlen_q, kargs.seqlen_k),
                            make_tuple(kargs.stride_bias, 1),
                            number<FmhaPipeline::kAlignmentBias>{},
                            number<1>{});
                    return pad_tensor_view(bias_dram_naive,
                                           bias_dram_window_lengths,
                                           sequence<kPadSeqLenQ, kPadSeqLenK>{});
                }();

                return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
            }
            else
            {
                return make_null_tile_window(bias_dram_window_lengths);
            }
        }();

        // lse acc
        auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() {
            constexpr auto lse_acc_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
            LSEDataType* lse_acc_ptr =
                reinterpret_cast<LSEDataType*>(kargs.lse_acc_ptr) +
                static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse_acc +
                batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc;

            const auto lse_acc_dram = [&]() {
                const auto lse_acc_dram_naive =
                    make_naive_tensor_view<address_space_enum::global>(lse_acc_ptr,
                                                                       make_tuple(kargs.seqlen_q),
                                                                       make_tuple(1),
                                                                       number<1>{},
                                                                       number<1>{});
                return pad_tensor_view(
                    lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence<kPadSeqLenQ>{});
            }();

            return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0});
        }();

        // dropout
        float rp_undrop             = 1;
        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
        uint64_t drop_seed          = 0;
        uint64_t drop_offset        = 0;
        bool is_store_randval       = false;

        if constexpr(kHasDropout)
        {
            rp_undrop           = kargs.rp_undrop;
            p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
            drop_seed           = kargs.drop_seed;
            drop_offset         = kargs.drop_offset;
            is_store_randval    = kargs.is_store_randval;
        }
        BlockDropout dropout(i_batch,
                             i_nhead,
                             kargs.num_head_q,
                             drop_seed,
                             drop_offset,
                             rp_undrop,
                             p_undrop_in_uint8_t,
                             is_store_randval);

        auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto randval_dram_window_lengths =
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
            if constexpr(kHasDropout)
            {
                RandValOutputDataType* rand_val_ptr =
                    reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
                    batch_offset_randval;

                const auto randval_dram = [&]() {
                    const auto randval_dram_naive =
                        make_naive_tensor_view<address_space_enum::global>(
                            rand_val_ptr,
                            make_tuple(kargs.seqlen_q, kargs.seqlen_k),
                            make_tuple(kargs.stride_randval, 1),
                            number<1>{},
                            number<1>{});
                    return pad_tensor_view(randval_dram_naive,
                                           randval_dram_window_lengths,
                                           sequence<kPadSeqLenQ, kPadSeqLenK>{});
                }();

                return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
            }
            else
            {
                return make_null_tile_window(randval_dram_window_lengths);
            }
        }();

        FmhaMask mask = [&]() {
            if constexpr(kHasMask)
                return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
                    kargs.window_size_left,
                    kargs.window_size_right,
                    kargs.seqlen_q,
                    kargs.seqlen_k,
                    kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
            else
                return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
        }();

        // workaround: 'i_batch' is a structured binding and cannot be captured before C++20
        auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                // data loading, shared by entire wg
                // TODO: how to use s_read?
                SaccDataType slope =
                    *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
                      i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                slope *= ck_tile::log2e_v<>;
#endif
                if constexpr(kHasMask)
                {
                    return make_alibi_from_lr_mask<SaccDataType, true>(slope,
                                                                       kargs.window_size_left,
                                                                       kargs.window_size_right,
                                                                       kargs.seqlen_q,
                                                                       kargs.seqlen_k,
                                                                       kargs.mask_type);
                }
                else
                {
                    return Alibi<SaccDataType, true>{
                        slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
                }
            }
            else
            {
                return EmptyPositionEncoding<SaccDataType>{};
            }
        }();

        auto o_acc_tile = [&, i_split_ = i_split]() {
            if constexpr(kDoFp8StaticQuant)
            {
                return FmhaPipeline{}(q_dram_window,
                                      identity{}, // q_element_func
                                      k_dram_window,
                                      identity{}, // k_element_func
                                      v_dram_window,
                                      identity{}, // v_element_func
                                      bias_dram_window,
                                      identity{}, // bias_element_func
                                      randval_dram_window,
                                      lse_acc_dram_window,
                                      identity{},            // lse_element_func
                                      identity{},            // s_acc_element_func
                                      scales{kargs.scale_p}, // p_compute_element_func
                                      identity{},            // o_acc_element_func
                                      kargs.num_splits,
                                      i_split_,
                                      mask,
                                      position_encoding,
                                      kargs.scale_s,
                                      smem_ptr,
                                      dropout);
            }
            else
            {
                return FmhaPipeline{}(q_dram_window,
                                      k_dram_window,
                                      v_dram_window,
                                      bias_dram_window,
                                      randval_dram_window,
                                      lse_acc_dram_window,
                                      kargs.num_splits,
                                      i_split_,
                                      mask,
                                      position_encoding,
                                      kargs.scale_s,
                                      smem_ptr,
                                      dropout);
            }
        }();

        // Oacc DRAM and Oacc DRAM window
        auto o_acc_dram = [&]() {
            const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                o_acc_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_v),
                make_tuple(kargs.hdim_v, 1),
                number<FmhaPipeline::kAlignmentO>{},
                number<1>{});
            return pad_tensor_view(
                o_acc_dram_naive,
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
                sequence<kPadSeqLenQ, kPadHeadDimV>{});
        }();

        auto o_acc_dram_window = make_tile_window(
            o_acc_dram,
            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
            {i_m0, i_n1});

        EpiloguePipeline{}(o_acc_dram_window, o_acc_tile);
    }
};

} // namespace ck_tile
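To make the S/P/O equations at the top of this file concrete for one split, here is a minimal scalar sketch of what a single split computes for one (batch, head): scores over its slice of K/V, a numerically stable row softmax with its own LSE, and a partial O that the combine kernel merges afterwards. Everything here is illustrative; bias, mask, dropout, and the fast-exp2 path are omitted, and no name below comes from this commit.

#include <algorithm>
#include <cmath>
#include <vector>

// q: [seqlen_q][hdim_q]; k, v: this split's slice of the full K/V
// outputs: o_acc[seqlen_q][hdim_v] and lse_acc[seqlen_q] for this split
void reference_one_split(const std::vector<std::vector<float>>& q,
                         const std::vector<std::vector<float>>& k,
                         const std::vector<std::vector<float>>& v,
                         float scale_s,
                         std::vector<std::vector<float>>& o_acc,
                         std::vector<float>& lse_acc)
{
    const int seqlen_q = static_cast<int>(q.size());
    const int seqlen_k = static_cast<int>(k.size()); // length of this split's K/V slice
    const int hdim_q   = static_cast<int>(q[0].size());
    const int hdim_v   = static_cast<int>(v[0].size());

    for(int m = 0; m < seqlen_q; ++m)
    {
        // S = scale_s * (Q @ K^T) over this split's keys
        std::vector<float> s(seqlen_k, 0.f);
        float s_max = -INFINITY;
        for(int n = 0; n < seqlen_k; ++n)
        {
            for(int d = 0; d < hdim_q; ++d)
                s[n] += q[m][d] * k[n][d];
            s[n] *= scale_s;
            s_max = std::max(s_max, s[n]);
        }

        // stable softmax; lse_acc records log-sum-exp so splits can be merged later
        float sum = 0.f;
        for(int n = 0; n < seqlen_k; ++n)
            sum += std::exp(s[n] - s_max);
        lse_acc[m] = s_max + std::log(sum);

        // partial O = P @ V, normalized within this split only
        for(int d = 0; d < hdim_v; ++d)
        {
            float acc = 0.f;
            for(int n = 0; n < seqlen_k; ++n)
                acc += (std::exp(s[n] - s_max) / sum) * v[n][d];
            o_acc[m][d] = acc;
        }
    }
}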
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp  0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename BlockFmhaShape_>
struct FmhaFwdSplitKVTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
    static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
    static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
    static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;

    __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
                                            ck_tile::index_t nhead,
                                            ck_tile::index_t seqlen_q,
                                            ck_tile::index_t hdim_v,
                                            ck_tile::index_t num_splits)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_q, kM0) *
                        ck_tile::integer_divide_ceil(hdim_v, kN1),
                    nhead * num_splits,
                    batch_size);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/,
                                   ck_tile::index_t hdim_v,
                                   ck_tile::index_t num_splits)
    {
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);

        const auto f = [](index_t dividend, index_t divisor) {
            index_t quotient = dividend / divisor;
            index_t modulus  = dividend - quotient * divisor;
            return ck_tile::make_tuple(quotient, modulus);
        };

        const auto [i_tile_m, i_tile_n] = f(blockIdx.x, num_tile_n1);
        const auto [i_nhead, i_split]   = f(blockIdx.y, num_splits);
        const index_t i_batch           = blockIdx.z;

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
    }
};

} // namespace ck_tile
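A worked decode of the packed blockIdx.y above, with hypothetical nhead = 4 and num_splits = 3 (so gridDim.y = nhead * num_splits = 12):

// blockIdx.y = 7   ->  i_nhead = 7 / 3 = 2,  i_split = 7 - 2 * 3 = 1
// blockIdx.y = 11  ->  i_nhead = 11 / 3 = 3, i_split = 11 - 3 * 3 = 2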
include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
@@ -18,10 +18,12 @@ struct FmhaFwdTilePartitioner
     static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
     static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;

-    __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
-                                            ck_tile::index_t nhead_,
-                                            ck_tile::index_t seqlen_q_,
-                                            ck_tile::index_t hdim_v_)
+    static constexpr const char* name = "shb";
+
+    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
+                                                ck_tile::index_t nhead_,
+                                                ck_tile::index_t seqlen_q_,
+                                                ck_tile::index_t hdim_v_)
     {
         // TODO: this may need tuning
         return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
...
@@ -51,4 +53,53 @@ struct FmhaFwdTilePartitioner
     }
 };

+template <typename BlockFmhaShape_>
+using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner<BlockFmhaShape_>;
+
+template <typename BlockFmhaShape_>
+struct FmhaFwdTilePartitioner_HBS
+{
+    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
+
+    static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
+    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
+    static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
+    static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
+    static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
+
+    static constexpr const char* name = "hbs";
+
+    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
+                                                ck_tile::index_t nhead_,
+                                                ck_tile::index_t seqlen_q_,
+                                                ck_tile::index_t hdim_v_)
+    {
+        // TODO: this may need tuning
+        return dim3(nhead_,
+                    batch_size_,
+                    ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
+                        ck_tile::integer_divide_ceil(hdim_v_, kN1));
+    }
+
+    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
+    {
+        // const index_t num_tile_m0 = seqlen_q / kM0;
+        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
+
+        const index_t i_block = blockIdx.z;
+        const index_t i_nhead = blockIdx.x;
+        const index_t i_batch = blockIdx.y;
+
+        const auto f = [](index_t dividend, index_t divisor) {
+            index_t quotient = dividend / divisor;
+            index_t modulus  = dividend - quotient * divisor;
+            return ck_tile::make_tuple(quotient, modulus);
+        };
+
+        const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
+
+        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
+    }
+};
+
 } // namespace ck_tile
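The two partitioners in this file differ only in how they place the same three loop dimensions onto the launch grid: SHB ("shb") keeps the seqlen*hdim tile index in x with head in y and batch in z, while the new HBS ("hbs") puts head in x, batch in y, and the tile index in z. A hypothetical launch with nhead = 8, batch = 2, and 4 tiles illustrates the two layouts (numbers are made up for illustration):

// SHB: dim3(4, 8, 2)  // x: tiles, y: nhead, z: batch
// HBS: dim3(8, 2, 4)  // x: nhead, y: batch, z: tiles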
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp  0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"

namespace ck_tile {

template <typename Problem, typename Policy = BlockFmhaBwdOGradDotODefaultPolicy>
struct BlockFmhaBwdOGradDotO
{
    using ODataType     = remove_cvref_t<typename Problem::ODataType>;
    using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
    using DDataType     = remove_cvref_t<typename Problem::DDataType>;

    static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
    static constexpr index_t kBlockSize  = Problem::kBlockSize;
    static constexpr index_t kVHeaddim   = Problem::kVHeaddim;

    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
    static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;

    static constexpr index_t kAlignmentO =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
    static constexpr index_t kAlignmentOGrad =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

    template <typename ODramBlockWindowTmp,
              typename OGradDramBlockWindowTmp,
              typename DDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE void operator()(const ODramBlockWindowTmp& o_dram_block_window_tmp,
                                        const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
                                        DDramBlockWindowTmp& d_dram_block_window_tmp,
                                        float p_undrop) const
    {
        static_assert(
            std::is_same_v<ODataType, remove_cvref_t<typename ODramBlockWindowTmp::DataType>> &&
                std::is_same_v<OGradDataType,
                               remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
                std::is_same_v<DDataType,
                               remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(kBlockSize == ODramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kBlockSize ==
                              OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kBlockSize == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
                      "wrong!");

        auto o_dram_window =
            make_tile_window(o_dram_block_window_tmp.get_bottom_tensor_view(),
                             o_dram_block_window_tmp.get_window_lengths(),
                             o_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakePreODramTileDistribution<Problem>());

        auto o = load_tile(o_dram_window);

        auto do_dram_window =
            make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
                             do_dram_block_window_tmp.get_window_lengths(),
                             do_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakePreOGradDramTileDistribution<Problem>());

        auto do_ = load_tile(do_dram_window);

        // declare d
        constexpr auto d_dstr =
            make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
                o.get_tile_distribution().get_static_tile_distribution_encoding(),
                sequence<1>{}));

        auto d = make_static_distributed_tensor<DDataType>(d_dstr);

        clear_tile(d); // Initialize D

        constexpr auto o_spans = decltype(o)::get_distributed_spans();

        sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);

                d(i_idx) += (type_convert<DDataType>(o[i_j_idx]) *
                             type_convert<DDataType>(do_[i_j_idx]));
            });
        });

        tile_elementwise_inout([&p_undrop](auto& x) { x = x * p_undrop; }, d);

        store_tile(d_dram_block_window_tmp, d);
    }
};

} // namespace ck_tile
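For reference, the block above computes one scalar per query row, D[m] = p_undrop * sum_k O[m][k] * dO[m][k], which the later backward stages consume. A minimal host-side sketch under those assumptions (names are illustrative, not from the commit):

#include <cstddef>
#include <vector>

// d[m] = p_undrop * sum_k o[m][k] * do[m][k]
std::vector<float> reference_dot_do_o(const std::vector<std::vector<float>>& o,
                                      const std::vector<std::vector<float>>& do_,
                                      float p_undrop)
{
    std::vector<float> d(o.size(), 0.f);
    for(std::size_t m = 0; m < o.size(); ++m)
    {
        for(std::size_t k = 0; k < o[m].size(); ++k)
            d[m] += o[m][k] * do_[m][k]; // row-wise dot product of O and dO
        d[m] *= p_undrop;                // dropout keep-probability scaling, as in the block above
    }
    return d;
}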
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp  0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

namespace ck_tile {

// These templates are not used here.
using BlockFmhaBwdOGradDotODefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_      = */ false,
                                      /* QTLoadOnce_     = */ false,
                                      /* KLoadOnce_      = */ false,
                                      /* KTLoadOnce_     = */ false,
                                      /* VLoadOnce_      = */ false,
                                      /* OGradLoadOnce_  = */ false,
                                      /* OGradTLoadOnce_ = */ false>;

} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp  0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"

namespace ck_tile {

template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
{
    using QDataType             = remove_cvref_t<typename Problem::QDataType>;
    using KDataType             = remove_cvref_t<typename Problem::KDataType>;
    using VDataType             = remove_cvref_t<typename Problem::VDataType>;
    using GemmDataType          = remove_cvref_t<typename Problem::GemmDataType>;
    using BiasDataType          = remove_cvref_t<typename Problem::BiasDataType>;
    using LSEDataType           = remove_cvref_t<typename Problem::LSEDataType>;
    using AccDataType           = remove_cvref_t<typename Problem::AccDataType>;
    using DDataType             = remove_cvref_t<typename Problem::DDataType>;
    using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
    using ODataType             = remove_cvref_t<typename Problem::ODataType>;
    using OGradDataType         = remove_cvref_t<typename Problem::OGradDataType>;
    using QGradDataType         = remove_cvref_t<typename Problem::QGradDataType>;
    using KGradDataType         = remove_cvref_t<typename Problem::KGradDataType>;
    using VGradDataType         = remove_cvref_t<typename Problem::VGradDataType>;
    using BiasGradDataType      = remove_cvref_t<typename Problem::BiasGradDataType>;
    using FmhaMask              = remove_cvref_t<typename Problem::FmhaMask>;
    using BlockFmhaShape        = remove_cvref_t<typename Problem::BlockFmhaShape>;

    static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
    static constexpr index_t kBlockSize  = Problem::kBlockSize;

    static constexpr index_t kM0        = BlockFmhaShape::kM0;
    static constexpr index_t kN0        = BlockFmhaShape::kN0;
    static constexpr index_t kK0        = BlockFmhaShape::kK0;
    static constexpr index_t kK1        = BlockFmhaShape::kK1;
    static constexpr index_t kK2        = BlockFmhaShape::kK2;
    static constexpr index_t kK3        = BlockFmhaShape::kK3;
    static constexpr index_t kK4        = BlockFmhaShape::kK4;
    static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
    static constexpr index_t kVHeaddim  = BlockFmhaShape::kVHeaddim;

    static constexpr bool kQLoadOnce      = false;
    static constexpr bool kQTLoadOnce     = false;
    static constexpr bool kKLoadOnce      = true;
    static constexpr bool kKTLoadOnce     = true;
    static constexpr bool kVLoadOnce      = true;
    static constexpr bool kOGradLoadOnce  = false;
    static constexpr bool kOGradTLoadOnce = false;

    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
    static constexpr auto BiasEnum     = Problem::BiasEnum;
    static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
    static constexpr bool kHasDropout  = Problem::kHasDropout;

    // last dimension vector length used to create the tensor view (and decide the buffer_load
    // vector length) together with the tensor distribution; the tensor dist should be able to
    // overwrite this
    static constexpr index_t kAlignmentQ =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
    static constexpr index_t kAlignmentK =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
    static constexpr index_t kAlignmentV =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
    static constexpr index_t kAlignmentO =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
    static constexpr index_t kAlignmentOGrad =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
    static constexpr index_t kAlignmentQGrad =
        kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
    static constexpr index_t kAlignmentKGrad =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
    static constexpr index_t kAlignmentVGrad =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
    static constexpr index_t kAlignmentBias =
        kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();

    static constexpr const char* name = "ks_kts_vr";

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    template <typename QDramBlockWindowTmp,
              typename QTDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename KTDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
              typename RandValDramBlockWindowTmp,
              typename OGradDramBlockWindowTmp,
              typename OGradTDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp,
              typename DDramBlockWindowTmp,
              typename QGradDramBlockWindowTmp,
              typename BiasGradDramBlockWindowTmp,
              typename PositionEncoding>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
               const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
               const KDramBlockWindowTmp& k_dram_block_window_tmp,
               const KTDramBlockWindowTmp& kt_dram_block_window_tmp,
               const VDramBlockWindowTmp& v_dram_block_window_tmp,
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
               const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
               const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
               const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
               const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
               const DDramBlockWindowTmp& d_dram_block_window_tmp,
               const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
               const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
               FmhaMask mask,
               PositionEncoding position_encoding,
               float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
               float scale,
#endif
               float rp_undrop,
               float scale_rp_undrop,
               void* smem_ptr,
               BlockDropout& dropout) const
    {
        static_assert(
            std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QDataType,
                               remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
                std::is_same_v<KDataType,
                               remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
                std::is_same_v<KDataType,
                               remove_cvref_t<typename KTDramBlockWindowTmp::DataType>> &&
                std::is_same_v<VDataType,
                               remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
                std::is_same_v<OGradDataType,
                               remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
                std::is_same_v<OGradDataType,
                               remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
                std::is_same_v<LSEDataType,
                               remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
                std::is_same_v<DDataType,
                               remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType,
                               remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kQKHeaddim == KTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                          kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kVHeaddim ==
                              OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

        // Q tile in LDS
        QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto q_lds = make_tensor_view<address_space_enum::lds>(
            q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
        auto q_lds_window =
            make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});

        // QT tile in LDS
        QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto qt_lds = make_tensor_view<address_space_enum::lds>(
            qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
        auto qt_lds_window =
            make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});

        // K tile in LDS
        auto k_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<KDataType*>(smem_ptr),
            Policy::template MakeKLdsBlockDescriptor<Problem>());
        auto k_lds_window =
            make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});

        // KT tile in LDS
        KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto kt_lds = make_tensor_view<address_space_enum::lds>(
            kt_lds_ptr, Policy::template MakeKTLdsBlockDescriptor<Problem>());
        auto kt_lds_window =
            make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});

        // OGrad tile in LDS
        OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto do_lds = make_tensor_view<address_space_enum::lds>(
            do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
        auto do_lds_window =
            make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});

        // OGradT tile in LDS
        OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto dot_lds = make_tensor_view<address_space_enum::lds>(
            dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
        auto dot_lds_window =
            make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});

        // SGrad tile in LDS
        GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto ds_lds = make_tensor_view<address_space_enum::lds>(
            ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
        auto ds_lds_window =
            make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});

        // BiasT/BiasGradT tile in LDS, use the same size and layout
        BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
            Policy::template GetSmemSizeKT<Problem>()));
        auto biast_lds = make_tensor_view<address_space_enum::lds>(
            biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
        auto biast_lds_shuffle_window =
            make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
        auto dbiast_lds_shuffle_window =
            make_tile_window(biast_lds,
                             make_tuple(number<kM0>{}, number<kN0>{}),
                             {0, 0},
                             Policy::template MakeShuffledBiasTileDistribution<Problem>());
        static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
                      "BiasDataType and BiasGradDataType should be the same!");
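        // note: the q/qt/do/dot/ds/biast LDS tiles above are all placed at
        // smem_ptr + GetSmemSizeK + GetSmemSizeKT, i.e. they alias one LDS region that
        // the pipeline stages reuse at different times (observation from the pointer
        // arithmetic above, not a comment carried over from the source)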
        // Block GEMM
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
        constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
        constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
        constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
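        // reading the policy names, the five GEMMs appear to follow the usual
        // attention-backward dataflow: gemm_0: S = Q @ K^T, gemm_1: dV = P^T @ dO,
        // gemm_2: dP = dO @ V^T, gemm_3: dK = dS^T @ Q (via the Q^T tile),
        // gemm_4: dQ = dS @ K (via the K^T tile)
        // (annotation inferred from the names, not stated in the original source)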
auto v_dram_window =
    make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
                     v_dram_block_window_tmp.get_window_lengths(),
                     v_dram_block_window_tmp.get_window_origin(),
                     Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());

auto v = load_tile(v_dram_window); // persistent V register tile

using SPTBlockTileType     = decltype(gemm_0.MakeCBlockTile());
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType   = decltype(gemm_4.MakeCBlockTile());

// init VGrad & KGrad
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
clear_tile(dv_acc);
clear_tile(dk_acc);
auto k_dram_window =
    make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
                     k_dram_block_window_tmp.get_window_lengths(),
                     k_dram_block_window_tmp.get_window_origin(),
                     Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for load

__builtin_amdgcn_sched_barrier(0);
const auto k_origin = k_dram_window.get_window_origin();
const auto [seqlen_q_start, seqlen_q_end] =
    mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
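// For this K/V block, masking limits the query rows that can interact with it to
// [seqlen_q_start, seqlen_q_end), so the main loop below only needs to visit
// ceil((seqlen_q_end - seqlen_q_start) / kM0) M-tiles instead of the full seqlen_q.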
// Early exit if masking leaves no work for this K/V block.
if constexpr(FmhaMask::IsMasking)
{
    if(num_total_loop <= 0)
    {
        // Note: dk_acc & dv_acc were cleared above, so returning them is safe.
        // Note: v has been loaded without a fence; it is simply ignored here.
        return ck_tile::make_tuple(dk_acc, dv_acc);
    }
}
auto k_block_tile = load_tile(k_dram_window);
store_tile(k_lds_window, k_block_tile); // persistent K in LDS

auto kt_dram_block_window = kt_dram_block_window_tmp;
auto kt_dram_window =
    make_tile_window(kt_dram_block_window.get_bottom_tensor_view(),
                     kt_dram_block_window.get_window_lengths(),
                     kt_dram_block_window.get_window_origin(),
                     Policy::template MakeKTDramTileDistribution<Problem>()); // K^T DRAM tile window for load

auto kt_block_tile  = load_tile(kt_dram_window);
auto kt_shuffle_tmp = make_static_distributed_tensor<KDataType>(
    Policy::template MakeShuffledKTRegBlockDescriptor<Problem>());
shuffle_tile(kt_shuffle_tmp, kt_block_tile);
store_tile(kt_lds_window, kt_shuffle_tmp); // persistent K^T in LDS
auto q_dram_block_window =
    make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                     q_dram_block_window_tmp.get_window_lengths(),
                     {seqlen_q_start, 0});

auto qt_dram_block_window =
    make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
                     qt_dram_block_window_tmp.get_window_lengths(),
                     {0, seqlen_q_start});

auto do_dram_block_window =
    make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
                     do_dram_block_window_tmp.get_window_lengths(),
                     {seqlen_q_start, 0});

auto dot_dram_block_window =
    make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
                     dot_dram_block_window_tmp.get_window_lengths(),
                     {0, seqlen_q_start});

auto dq_dram_block_window =
    make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
                     dq_dram_block_window_tmp.get_window_lengths(),
                     {seqlen_q_start, 0});

auto lse_dram_block_window =
    make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
                     lse_dram_block_window_tmp.get_window_lengths(),
                     {seqlen_q_start});

auto d_dram_block_window =
    make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
                     d_dram_block_window_tmp.get_window_lengths(),
                     {seqlen_q_start});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_block_window =
    make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
                     bias_dram_block_window_tmp.get_window_lengths(),
                     {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N

const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
auto dbias_dram_block_window =
    make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
                     dbias_dram_block_window_tmp.get_window_lengths(),
                     {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
auto qt_dram_window =
    make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
                     qt_dram_block_window.get_window_lengths(),
                     qt_dram_block_window.get_window_origin(),
                     Policy::template MakeQTDramTileDistribution<Problem>());

auto dot_dram_window =
    make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
                     dot_dram_block_window.get_window_lengths(),
                     dot_dram_block_window.get_window_origin(),
                     Policy::template MakeOGradTDramTileDistribution<Problem>());

auto lse_dram_window =
    make_tile_window(lse_dram_block_window.get_bottom_tensor_view(),
                     lse_dram_block_window.get_window_lengths(),
                     lse_dram_block_window.get_window_origin(),
                     Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());

auto d_dram_window =
    make_tile_window(d_dram_block_window.get_bottom_tensor_view(),
                     d_dram_block_window.get_window_lengths(),
                     d_dram_block_window.get_window_origin(),
                     Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto bias_dram_window =
    make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
                     bias_dram_block_window.get_window_lengths(),
                     bias_dram_block_window.get_window_origin(),
                     Policy::template MakeBiasTileDistribution<Problem>());

auto biast_lds_window =
    make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
                     biast_lds_shuffle_window.get_window_lengths(),
                     biast_lds_shuffle_window.get_window_origin(),
                     Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());

auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
    randval_dram_block_window_tmp, seqlen_q_start);
index_t i_total_loops      = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kM0 / kK1;
constexpr index_t k2_loops = kVHeaddim / kK2;
constexpr index_t k3_loops = kM0 / kK3;
constexpr index_t k4_loops = kN0 / kK4;
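// Each GEMM stage tiles its reduction dimension into fixed-size kK* chunks; these
// counts drive the software-pipelined load/store/compute sequences below (e.g.
// gemm_0 reduces over kQKHeaddim in k0_loops chunks of kK0).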
do
{
    auto q_dram_window =
        make_tile_window(q_dram_block_window.get_bottom_tensor_view(),
                         q_dram_block_window.get_window_lengths(),
                         q_dram_block_window.get_window_origin(),
                         Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for load

    auto do_dram_window =
        make_tile_window(do_dram_block_window.get_bottom_tensor_view(),
                         do_dram_block_window.get_window_lengths(),
                         do_dram_block_window.get_window_origin(),
                         Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile window for load

    // STAGE 1, Q@K Gemm0
    auto st_acc       = SPTBlockTileType{};
    auto q_block_tile = load_tile(q_dram_window);
    {
        move_tile_window(q_dram_window, {0, kK0});
        clear_tile(st_acc);                      // Initialize S^T
        store_tile(q_lds_window, q_block_tile);  // LDS write 0
        q_block_tile = load_tile(q_dram_window); // global read 1
    }
    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
    {
        __builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
    }
    const auto bias_tile = load_tile(bias_dram_window); // load bias tile
    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
    {
        __builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
    }
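    // Main K0 loop: a two-stage software pipeline. While gemm_0 consumes Q chunk
    // i from LDS, chunk i + 1 is written into LDS and chunk i + 2 is fetched from
    // global memory; the last two chunks are drained in the tail block below.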
    if constexpr(k0_loops > 2)
    {
        static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
            block_sync_lds();
            gemm_0(st_acc,
                   q_lds_window,
                   get_slice_tile(
                       k_lds_window, sequence<0, i_k0 * kK0>{}, sequence<kN0, (i_k0 + 1) * kK0>{}));
            block_sync_lds();
            move_tile_window(q_dram_window, {0, kK0});
            store_tile(q_lds_window, q_block_tile);  // LDS write i + 1
            q_block_tile = load_tile(q_dram_window); // global read i + 2
        });
    }
    const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile

    { // tail
        block_sync_lds();
        gemm_0(st_acc,
               q_lds_window,
               get_slice_tile(k_lds_window,
                              sequence<0, (k0_loops - 2) * kK0>{},
                              sequence<kN0, (k0_loops - 1) * kK0>{}));
        block_sync_lds();
        store_tile(q_lds_window, q_block_tile);
        block_sync_lds();
        gemm_0(st_acc,
               q_lds_window,
               get_slice_tile(k_lds_window,
                              sequence<0, (k0_loops - 1) * kK0>{},
                              sequence<kN0, k0_loops * kK0>{}));
    }
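    // With CK_TILE_FMHA_FWD_FAST_EXP2, exp(x) is evaluated as exp2(log2(e) * x)
    // and the log2(e) factor is pre-folded into `scale`; hence the explicit
    // log2e_v multiplier on the bias term below, while the non-fast path uses
    // raw_scale directly.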
    // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
    {
        block_sync_lds();
        auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
            Policy::template MakeShuffledBiasTileDistribution<Problem>());
        shuffle_tile(bias_shuffle_tmp, bias_tile);
        store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
        block_sync_lds();
        auto biast_tile = load_tile(biast_lds_window);
        tile_elementwise_inout(
            [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                x = raw_scale * x + type_convert<AccDataType>(y);
#else
                x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
            },
            st_acc,
            biast_tile);
        move_tile_window(bias_dram_window, {kM0, 0});
    }
    else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
    {
        const auto q_origin     = q_dram_block_window.get_window_origin();
        constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
        sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
            sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
                const auto tile_idx = get_x_indices_from_distributed_indices(
                    st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
                const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                st_acc(i_j_idx) *= raw_scale;
#else
                st_acc(i_j_idx) *= scale;
#endif
                position_encoding.update(st_acc(i_j_idx), row, col);
            });
        });
    }
    else
    {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
        tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
    }
    if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
    {
        const auto q_origin      = q_dram_block_window.get_window_origin();
        bool need_perpixel_check = mask.IsEdgeTile(
            q_origin.at(number<0>{}), k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
        if(need_perpixel_check)
        {
            set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
                const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                return mask.IsOutOfBound(row, col);
            });
        }
    }
    const auto lse = load_tile(lse_dram_window);

    static const auto get_validated_lse = [](LSEDataType raw_lse) {
        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || FmhaMask::IsMasking)
        {
            return raw_lse == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
                                                                : raw_lse;
        }
        else
        {
            return raw_lse;
        }
    };
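    // P^T is recomputed from the forward pass's log-sum-exp instead of being
    // stored: p = exp(s - lse). A fully-masked row carries lse == -inf, and
    // substituting 0 for it avoids exp(-inf - (-inf)) = NaN, yielding the
    // intended p = exp(-inf) = 0 instead.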
    auto pt                 = SPTBlockTileType{};
    constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
    sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
        constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
        auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif
        sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
            constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                         BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
            }
            else
            {
                pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
            }
#else
            pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
        });
    });
    auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
        Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
    block_sync_lds();
    {
        shuffle_tile(dot_shuffle_tmp, dot_prefetch);
        store_tile(dot_lds_window, dot_shuffle_tmp); // store the prefetch
    }
    move_tile_window(dot_dram_window, {0, kK1});
    if constexpr(kHasDropout)
    {
        dropout.Run<decltype(gemm_0), RandValOutputDataType>(
            seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
    }
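    // Dropped elements of pt come back from dropout.Run with their sign flipped,
    // so the sign of pt serves as the keep/drop flag for the rest of the
    // iteration (pt_gemm below zeroes negatives; STAGE 5 tests pt >= 0).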
    // STAGE 3, P^T@OGrad^T Gemm1
    const auto pt_gemm = [&]() {
        if constexpr(kHasDropout)
        {
            return tile_elementwise_in(
                [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); }, pt);
        }
        else
        {
            return cast_tile<GemmDataType>(pt);
        }
    }();
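    // Zeroing the negative (dropped) entries here applies the dropout mask to
    // P^T before the dV GEMM; the 1/(1-p) rescale is not applied per element but
    // once, in the VGrad scale after the main loop.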
    if constexpr(k1_loops > 1)
    {
        static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
            const auto dot = load_tile(dot_dram_window); // load next OGrad^T
            block_sync_lds();
            gemm_1(dv_acc,
                   get_slice_tile(
                       pt_gemm, sequence<i_k1 * kK1, 0>{}, sequence<(i_k1 + 1) * kK1, kN0>{}),
                   dot_lds_window);
            block_sync_lds();
            shuffle_tile(dot_shuffle_tmp, dot);
            store_tile(dot_lds_window, dot_shuffle_tmp); // store the prefetch
            move_tile_window(dot_dram_window, {0, kK1});
        });
    }
    auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
    // tail
    {
        block_sync_lds();
        gemm_1(dv_acc,
               get_slice_tile(pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
               dot_lds_window);
        block_sync_lds();
    }
    // STAGE 4, OGrad@V Gemm2
    auto dpt_acc = SPGradTBlockTileType{};
    {
        move_tile_window(do_dram_window, {0, kK2});
        clear_tile(dpt_acc);                       // Initialize PGrad^T
        store_tile(do_lds_window, do_block_tile);  // LDS write 0
        do_block_tile = load_tile(do_dram_window); // global read 1
    }
    if constexpr(k2_loops > 2)
    {
        static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
            block_sync_lds();
            gemm_2(dpt_acc,
                   do_lds_window,
                   get_slice_tile(v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
            block_sync_lds();
            move_tile_window(do_dram_window, {0, kK2});
            store_tile(do_lds_window, do_block_tile);  // LDS write i + 1
            do_block_tile = load_tile(do_dram_window); // global read i + 2
        });
    }
    const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile

    { // tail
        block_sync_lds();
        gemm_2(dpt_acc,
               do_lds_window,
               get_slice_tile(
                   v, sequence<0, (k2_loops - 2) * kK2>{}, sequence<kN0, (k2_loops - 1) * kK2>{}));
        block_sync_lds();
        store_tile(do_lds_window, do_block_tile);
        block_sync_lds();
        gemm_2(dpt_acc,
               do_lds_window,
               get_slice_tile(
                   v, sequence<0, (k2_loops - 1) * kK2>{}, sequence<kN0, k2_loops * kK2>{}));
    }
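    // STAGE 5 evaluates the identity dS = P .* (dP - D) on transposed tiles,
    // where D = rowsum(dO .* O) arrives precomputed through d_dram_window. For
    // dropped entries (pt < 0, see the dropout note above) dP is treated as
    // zero, so the product reduces to pt * d, i.e. p * (0 - D).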
    // STAGE 5, P^T(PGrad^T - D)
    const auto d             = load_tile(d_dram_window);
    auto dst                 = SPGradTBlockTileType{};
    constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
    sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
        constexpr auto i_idx = make_tuple(idx0);
        sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
            constexpr auto i_j_idx = make_tuple(idx0, idx1);
            bool undrop_flag       = pt[i_j_idx] >= 0;
            dst(i_j_idx) =
                pt[i_j_idx] *
                (!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
        });
    });
    if constexpr(kHasBiasGrad)
    {
        const auto dbiast = [&]() {
            if constexpr(kHasDropout)
            {
                return tile_elementwise_in(
                    [&rp_undrop](const auto& x) {
                        return type_convert<BiasGradDataType>(x * rp_undrop);
                    },
                    dst);
            }
            else
            {
                return cast_tile<BiasGradDataType>(dst);
            }
        }();
        store_tile(biast_lds_shuffle_window, dbiast);
        block_sync_lds();
        auto dbiast_tile        = load_tile(dbiast_lds_shuffle_window);
        auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
            Policy::template MakeBiasTileDistribution<Problem>());
        shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
        store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
        move_tile_window(dbias_dram_block_window, {kM0, 0});
    }
    // STAGE 6, SGrad^T@Q^T Gemm3
    auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
        Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
    block_sync_lds();
    {
        shuffle_tile(qt_shuffle_tmp, qt_prefetch);
        store_tile(qt_lds_window, qt_shuffle_tmp); // store the prefetch
    }
    move_tile_window(qt_dram_window, {0, kK3});
    const auto dst_gemm = cast_tile<GemmDataType>(dst);
    if constexpr(k3_loops > 1)
    {
        static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
            const auto qt = load_tile(qt_dram_window); // load next Q^T
            block_sync_lds();
            gemm_3(dk_acc,
                   get_slice_tile(
                       dst_gemm, sequence<i_k3 * kK3, 0>{}, sequence<(i_k3 + 1) * kK3, kN0>{}),
                   qt_lds_window);
            block_sync_lds();
            shuffle_tile(qt_shuffle_tmp, qt);
            store_tile(qt_lds_window, qt_shuffle_tmp); // store the prefetch
            move_tile_window(qt_dram_window, {0, kK3});
        });
    }
    // tail
    {
        block_sync_lds();
        gemm_3(dk_acc,
               get_slice_tile(dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
               qt_lds_window);
        block_sync_lds();
    }
    // STAGE 7, SGrad@K^T Gemm4
    store_tile(ds_lds_window, dst_gemm);
    auto dq_acc = QGradBlockTileType{};
    clear_tile(dq_acc); // Initialize QGrad
    block_sync_lds();
    static_for<0, k4_loops, 1>{}([&](auto i_k4) {
        gemm_4(dq_acc,
               get_slice_tile(
                   ds_lds_window, sequence<0, i_k4 * kK4>{}, sequence<kM0, (i_k4 + 1) * kK4>{}),
               get_slice_tile(kt_lds_window,
                              sequence<0, i_k4 * kK4>{},
                              sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
    });
    // QGrad Scale
    if constexpr(kHasDropout)
    {
        tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, dq_acc);
    }
    else
    {
        tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
    }
    const auto dq = cast_tile<QGradDataType>(dq_acc);
    update_tile(dq_dram_block_window, dq);
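    // This K/V block only produces a partial dQ for these query rows; update_tile
    // accumulates it into global memory rather than overwriting, since other
    // workgroups handling other key blocks contribute to the same dQ tile.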
    // move tile windows
    move_tile_window(q_dram_block_window, {kM0, 0});
    move_tile_window(dq_dram_block_window, {kM0, 0});
    move_tile_window(do_dram_block_window, {kM0, 0});
    move_tile_window(lse_dram_window, {kM0});
    move_tile_window(d_dram_window, {kM0});
} while(++i_total_loops < num_total_loop);
// KGrad Scale
if constexpr(kHasDropout)
{
    tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, dk_acc);
}
else
{
    tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}

// VGrad Scale
if constexpr(kHasDropout)
{
    tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
}

return ck_tile::make_tuple(dk_acc, dv_acc);
}
};
} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp 0 → 100644
View file @ e6bb1dd7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

namespace ck_tile {

// In this pipeline, V stays resident in registers while K and K^T stay resident in LDS.
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_      = */ false,
                                      /* QTLoadOnce_     = */ false,
                                      /* KLoadOnce_      = */ true,
                                      /* KTLoadOnce_     = */ true,
                                      /* VLoadOnce_      = */ true,
                                      /* OGradLoadOnce_  = */ false,
                                      /* OGradTLoadOnce_ = */ false>;

} // namespace ck_tile
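The seven boolean template parameters select which operand stays resident across the main loop (K and K^T in LDS, V in registers for this pipeline) and which operands are re-streamed from global memory on every iteration. As a hedged sketch only (the alias name below is hypothetical and not part of this commit), a variant that additionally kept Q resident would flip the corresponding flag:

// Hypothetical sketch (not part of this commit): a variant policy that also
// keeps Q resident across the main loop.
using BlockFmhaBwdPipelineQResidentExamplePolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_      = */ true,
                                      /* QTLoadOnce_     = */ false,
                                      /* KLoadOnce_      = */ true,
                                      /* KTLoadOnce_     = */ true,
                                      /* VLoadOnce_      = */ true,
                                      /* OGradLoadOnce_  = */ false,
                                      /* OGradTLoadOnce_ = */ false>;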