Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4885c38a
Commit
4885c38a
authored
Sep 03, 2024
by
aska-0096
Browse files
Merge branch 'transpose_opt' of
https://github.com/ROCm/composable_kernel
into rowwise_opt
parents
cbf14ee1
7c8e92fa
Changes
83
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2228 additions
and
1091 deletions
+2228
-1091
include/ck_tile/core/numeric/math.hpp
include/ck_tile/core/numeric/math.hpp
+10
-3
include/ck_tile/core/tensor/tile_window.hpp
include/ck_tile/core/tensor/tile_window.hpp
+52
-1
include/ck_tile/core/utility/type_traits.hpp
include/ck_tile/core/utility/type_traits.hpp
+17
-0
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+1
-0
include/ck_tile/host/host_tensor.hpp
include/ck_tile/host/host_tensor.hpp
+9
-0
include/ck_tile/host/kernel_launch.hpp
include/ck_tile/host/kernel_launch.hpp
+5
-5
include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp
...reference/reference_batched_rotary_position_embedding.hpp
+73
-0
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+6
-2
include/ck_tile/ops/fmha/block/block_position_encoding.hpp
include/ck_tile/ops/fmha/block/block_position_encoding.hpp
+19
-3
include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp
include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp
+108
-0
include/ck_tile/ops/fmha/block/page_block_navigator.hpp
include/ck_tile/ops/fmha/block/page_block_navigator.hpp
+279
-0
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp
+679
-0
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp
...le/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp
+42
-0
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+182
-177
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp
...le/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp
+277
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp
...eline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp
+288
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
...mha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+98
-80
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
...peline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
+0
-770
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp
...ha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp
+0
-19
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
...ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
+83
-31
No files found.
include/ck_tile/core/numeric/math.hpp
View file @
4885c38a
...
...
@@ -536,13 +536,20 @@ float log(float x) { return __logf(x); };
CK_TILE_HOST
float
log
(
float
x
)
{
return
std
::
logf
(
x
);
};
CK_TILE_DEVICE
uint
32
_t
sad
(
uint
32
_t
x
,
uint
32
_t
y
,
uint
32
_t
acc
)
CK_TILE_DEVICE
uint
16
_t
sad
_u16
(
uint
16
_t
x
,
uint
16
_t
y
,
uint
16
_t
acc
)
{
// TODO: this is hacky, we use u16
return
__builtin_amdgcn_sad_u16
(
x
,
y
,
acc
);
}
CK_TILE_HOST
uint32_t
sad
(
uint32_t
x
,
uint32_t
y
,
uint32_t
acc
)
CK_TILE_DEVICE
uint32_t
sad_u32
(
uint32_t
x
,
uint32_t
y
,
uint32_t
acc
)
{
/// TODO: replace inline asm when intrinsic is available
uint32_t
res
;
asm
volatile
(
"v_sad_u32 %0, %1, %2, %3"
:
"=v"
(
res
)
:
"v"
(
x
),
"v"
(
y
),
"v"
(
acc
));
return
res
;
}
CK_TILE_HOST
uint32_t
sad_u32
(
uint32_t
x
,
uint32_t
y
,
uint32_t
acc
)
{
return
(
x
>
y
?
(
x
-
y
)
:
(
y
-
x
))
+
acc
;
}
...
...
include/ck_tile/core/tensor/tile_window.hpp
View file @
4885c38a
...
...
@@ -214,6 +214,12 @@ struct tile_window_with_static_distribution
CK_TILE_DEVICE
constexpr
auto
get_window_origin
()
const
{
return
window_origin_
;
}
CK_TILE_DEVICE
constexpr
void
set_bottom_tensor_view_data_ptr
(
typename
BottomTensorView
::
DataType
*
data
)
{
bottom_tensor_view_
.
buf_
.
p_data_
=
data
;
}
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
CK_TILE_DEVICE
void
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
...
...
@@ -393,7 +399,8 @@ struct tile_window_with_static_distribution
bottom_tensor_thread_coord
,
bool_constant
<
oob_conditional_check
>
{},
pre_nop_
);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
asm
volatile
(
""
);
// this is starting from rocm-6.2, but same sympton, reuse this flag
#endif
...
...
@@ -843,6 +850,17 @@ struct tile_window_with_static_lengths
CK_TILE_DEVICE
constexpr
auto
get_window_origin
()
const
{
return
window_origin_
;
}
CK_TILE_DEVICE
void
set_window_origin
(
const
BottomTensorIndex
&
new_window_origin
)
{
window_origin_
=
new_window_origin
;
}
CK_TILE_DEVICE
constexpr
void
set_bottom_tensor_view_data_ptr
(
typename
BottomTensorView
::
DataType
*
data
)
{
bottom_tensor_view_
.
buf_
.
p_data_
=
data
;
}
// move window-origin
CK_TILE_DEVICE
void
move
(
const
BottomTensorIndex
&
step
)
{
window_origin_
+=
step
;
}
...
...
@@ -871,6 +889,39 @@ make_tile_window(const TensorView_& tensor_view,
tensor_view
,
window_lengths
,
origin
};
}
// duplicate tile window and replace its origin
template
<
typename
TensorView
,
typename
WindowLengths
>
CK_TILE_DEVICE
constexpr
auto
make_tile_window
(
const
tile_window_with_static_lengths
<
TensorView
,
WindowLengths
>&
tile_window
,
const
multi_index
<
TensorView
::
get_num_of_dimension
()
>&
origin
)
{
return
tile_window_with_static_lengths
<
TensorView
,
WindowLengths
>
{
tile_window
.
get_bottom_tensor_view
(),
tile_window
.
get_window_lengths
(),
origin
};
}
template
<
typename
TensorView
,
typename
WindowLengths
,
typename
StaticTileDistribution
>
CK_TILE_DEVICE
constexpr
auto
make_tile_window
(
const
tile_window_with_static_lengths
<
TensorView
,
WindowLengths
>&
tile_window
,
const
multi_index
<
TensorView
::
get_num_of_dimension
()
>&
origin
,
const
StaticTileDistribution
&
tile_distribution
)
{
return
make_tile_window
(
tile_window
.
get_bottom_tensor_view
(),
tile_window
.
get_window_lengths
(),
origin
,
tile_distribution
);
}
template
<
typename
TensorView
,
typename
WindowLengths
,
typename
StaticTileDistribution
>
CK_TILE_DEVICE
constexpr
auto
make_tile_window
(
const
tile_window_with_static_lengths
<
TensorView
,
WindowLengths
>&
tile_window
,
const
StaticTileDistribution
&
tile_distribution
)
{
return
make_tile_window
(
tile_window
.
get_bottom_tensor_view
(),
tile_window
.
get_window_lengths
(),
tile_window
.
get_window_origin
(),
tile_distribution
);
}
template
<
typename
TensorView_
,
typename
WindowLengths_
>
CK_TILE_DEVICE
void
move_tile_window
(
tile_window_with_static_lengths
<
TensorView_
,
WindowLengths_
>&
window
,
...
...
include/ck_tile/core/utility/type_traits.hpp
View file @
4885c38a
...
...
@@ -22,6 +22,23 @@ using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
template
<
typename
T
>
using
remove_pointer_t
=
typename
std
::
remove_pointer
<
T
>::
type
;
template
<
typename
From
,
typename
To
>
struct
copy_const
{
static_assert
(
!
std
::
is_const_v
<
From
>
);
using
type
=
To
;
};
template
<
typename
From
,
typename
To
>
struct
copy_const
<
const
From
,
To
>
{
using
type
=
std
::
add_const_t
<
typename
copy_const
<
From
,
To
>::
type
>
;
};
template
<
typename
From
,
typename
To
>
using
copy_const_t
=
typename
copy_const
<
From
,
To
>::
type
;
namespace
detail
{
template
<
class
Default
,
class
AlwaysVoid
,
template
<
class
...
>
class
Op
,
class
...
Args
>
struct
detector
...
...
include/ck_tile/host.hpp
View file @
4885c38a
...
...
@@ -15,6 +15,7 @@
#include "ck_tile/host/reference/reference_batched_elementwise.hpp"
#include "ck_tile/host/reference/reference_batched_gemm.hpp"
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
...
...
include/ck_tile/host/host_tensor.hpp
View file @
4885c38a
...
...
@@ -155,7 +155,12 @@ struct HostTensorDescriptor
return
space
;
}
std
::
size_t
get_length
(
std
::
size_t
dim
)
const
{
return
mLens
[
dim
];
}
const
std
::
vector
<
std
::
size_t
>&
get_lengths
()
const
{
return
mLens
;
}
std
::
size_t
get_stride
(
std
::
size_t
dim
)
const
{
return
mStrides
[
dim
];
}
const
std
::
vector
<
std
::
size_t
>&
get_strides
()
const
{
return
mStrides
;
}
template
<
typename
...
Is
>
...
...
@@ -325,8 +330,12 @@ struct HostTensor
{
}
std
::
size_t
get_length
(
std
::
size_t
dim
)
const
{
return
mDesc
.
get_length
(
dim
);
}
decltype
(
auto
)
get_lengths
()
const
{
return
mDesc
.
get_lengths
();
}
std
::
size_t
get_stride
(
std
::
size_t
dim
)
const
{
return
mDesc
.
get_stride
(
dim
);
}
decltype
(
auto
)
get_strides
()
const
{
return
mDesc
.
get_strides
();
}
std
::
size_t
get_num_of_dimension
()
const
{
return
mDesc
.
get_num_of_dimension
();
}
...
...
include/ck_tile/host/kernel_launch.hpp
View file @
4885c38a
...
...
@@ -73,17 +73,17 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
{
// clang-format off
if
(
!
s
.
time_kernel_
)
{
(
callables
(
s
),...);
hip_check_error
(
hipGetLastError
());
(
callables
(
s
),...);
HIP_CHECK_ERROR
(
hipGetLastError
());
return
0
;
}
if
(
s
.
is_gpu_timer_
)
{
gpu_timer
timer
{};
// warmup
for
(
int
i
=
0
;
i
<
s
.
cold_niters_
;
i
++
)
{
(
callables
(
s
),...);
}
hip_check_error
(
hipGetLastError
());
for
(
int
i
=
0
;
i
<
s
.
cold_niters_
;
i
++
)
{
(
callables
(
s
),...);
}
HIP_CHECK_ERROR
(
hipGetLastError
());
timer
.
start
(
s
.
stream_id_
);
for
(
int
i
=
0
;
i
<
s
.
nrepeat_
;
i
++
)
{
(
callables
(
s
),...);
}
hip_check_error
(
hipGetLastError
());
for
(
int
i
=
0
;
i
<
s
.
nrepeat_
;
i
++
)
{
(
callables
(
s
),...);
}
HIP_CHECK_ERROR
(
hipGetLastError
());
timer
.
stop
(
s
.
stream_id_
);
return
timer
.
duration
()
/
s
.
nrepeat_
;
...
...
@@ -92,10 +92,10 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
cpu_timer
timer
{};
// warmup
for
(
int
i
=
0
;
i
<
s
.
cold_niters_
;
i
++
)
{
(
callables
(
s
),...);
}
hip_check_error
(
hipGetLastError
());
for
(
int
i
=
0
;
i
<
s
.
cold_niters_
;
i
++
)
{
(
callables
(
s
),...);
}
HIP_CHECK_ERROR
(
hipGetLastError
());
timer
.
start
(
s
.
stream_id_
);
for
(
int
i
=
0
;
i
<
s
.
nrepeat_
;
i
++
)
{
(
callables
(
s
),...);
}
hip_check_error
(
hipGetLastError
());
for
(
int
i
=
0
;
i
<
s
.
nrepeat_
;
i
++
)
{
(
callables
(
s
),...);
}
HIP_CHECK_ERROR
(
hipGetLastError
());
timer
.
stop
(
s
.
stream_id_
);
return
timer
.
duration
()
/
s
.
nrepeat_
;
...
...
include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <cassert>
#include <thread>
namespace
ck_tile
{
template
<
typename
DataType
,
typename
ComputeDataType
=
float
>
CK_TILE_HOST
void
reference_batched_rotary_position_embedding
(
const
HostTensor
<
DataType
>&
input_bsd
,
const
HostTensor
<
DataType
>&
cos_sd
,
const
HostTensor
<
DataType
>&
sin_sd
,
bool
interleaved
,
HostTensor
<
DataType
>&
output_bsd
,
bool
use_1_row_sin_cos
=
false
)
{
assert
(
cos_sd
.
get_num_of_dimension
()
==
2
&&
sin_sd
.
get_num_of_dimension
()
==
2
);
assert
(
cos_sd
.
get_length
(
0
)
==
sin_sd
.
get_length
(
0
)
&&
cos_sd
.
get_length
(
1
)
==
sin_sd
.
get_length
(
1
));
const
index_t
rotary_dim
=
cos_sd
.
get_length
(
1
)
*
2
;
assert
(
static_cast
<
std
::
size_t
>
(
rotary_dim
)
<=
input_bsd
.
get_length
(
2
));
output_bsd
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
const
index_t
i_d
=
i
[
2
];
if
(
rotary_dim
<=
i_d
)
{
self
(
i
)
=
input_bsd
(
i
);
return
;
}
assert
(
i_d
<
rotary_dim
);
const
index_t
i_s
=
i
[
1
];
const
index_t
i_s_cos_sin
=
(
use_1_row_sin_cos
?
0
:
i_s
);
const
ComputeDataType
cos
=
type_convert
<
ComputeDataType
>
(
interleaved
?
cos_sd
(
i_s_cos_sin
,
i_d
/
2
)
:
cos_sd
(
i_s_cos_sin
,
i_d
%
cos_sd
.
get_length
(
1
)));
const
ComputeDataType
sin
=
type_convert
<
ComputeDataType
>
(
interleaved
?
sin_sd
(
i_s_cos_sin
,
i_d
/
2
)
:
sin_sd
(
i_s_cos_sin
,
i_d
%
sin_sd
.
get_length
(
1
)));
const
ComputeDataType
half_rotated_input
=
[
&
]
{
const
index_t
i_b
=
i
[
0
];
if
(
interleaved
)
{
const
bool
is_even
=
(
i_d
%
2
==
0
);
const
index_t
pos
=
i_d
+
(
is_even
?
1
:
-
1
);
const
ComputeDataType
sign
=
(
is_even
?
-
1
:
1
);
return
sign
*
type_convert
<
ComputeDataType
>
(
input_bsd
(
i_b
,
i_s
,
pos
));
}
else
{
const
index_t
half_rdim
=
(
rotary_dim
/
2
);
const
index_t
pos
=
(
i_d
+
half_rdim
)
%
rotary_dim
;
const
ComputeDataType
sign
=
(
pos
<
half_rdim
?
1
:
-
1
);
return
sign
*
type_convert
<
ComputeDataType
>
(
input_bsd
(
i_b
,
i_s
,
pos
));
}
}();
ComputeDataType
result
=
type_convert
<
ComputeDataType
>
(
input_bsd
(
i
))
*
cos
+
half_rotated_input
*
sin
;
self
(
i
)
=
type_convert
<
DataType
>
(
result
);
});
}
}
// namespace ck_tile
include/ck_tile/ops/fmha.hpp
View file @
4885c38a
...
...
@@ -7,7 +7,11 @@
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
#include "ck_tile/ops/fmha/block/page_block_navigator.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
...
...
@@ -21,11 +25,11 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
...
...
include/ck_tile/ops/fmha/block/block_position_encoding.hpp
View file @
4885c38a
...
...
@@ -43,9 +43,12 @@ enum struct AlibiMode
FROM_BOTTOM_RIGHT
=
2
,
};
template
<
typename
DataType
,
bool
RowMajor
=
true
>
template
<
typename
DataType
,
bool
RowMajor
=
true
,
unsigned
LogMaxSadOprndSize
=
16
>
struct
Alibi
{
static_assert
(
1
<=
LogMaxSadOprndSize
&&
LogMaxSadOprndSize
<=
32
,
"for LogMaxSadOprndSize <= 16, we use SAD uint16_t, otherwise, use SAD uint32_t"
);
// RowMajor here means if pixel within the same thread are along the row, or col
// this may impact the performance of update(), while the result are the same.
// e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false
...
...
@@ -79,6 +82,19 @@ struct Alibi
mode
=
mode_
;
}
CK_TILE_HOST
uint32_t
sad
(
uint32_t
x
,
uint32_t
y
,
uint32_t
acc
)
{
return
sad_u32
(
x
,
y
,
acc
);
}
CK_TILE_DEVICE
uint32_t
sad
(
uint32_t
x
,
uint32_t
y
,
uint32_t
acc
)
{
if
constexpr
(
LogMaxSadOprndSize
<=
16
)
{
return
sad_u16
(
static_cast
<
uint16_t
>
(
x
),
static_cast
<
uint16_t
>
(
y
),
static_cast
<
uint16_t
>
(
acc
));
}
return
sad_u32
(
x
,
y
,
acc
);
}
CK_TILE_HOST_DEVICE
void
update
(
DataType
&
pixel
,
index_t
row_idx
,
index_t
col_idx
)
{
if
constexpr
(
RowMajor
)
...
...
@@ -128,7 +144,7 @@ struct EmptyPositionEncoding
// can convert from the FA style left/right to our generic coordinate
// if left_size < 0 && right_size = 0, it is normal causal mask
// local is left_size >=0 or right_size >=0
template
<
typename
DataType
,
bool
RowMajor
=
true
>
template
<
typename
DataType
,
bool
RowMajor
=
true
,
unsigned
LogMaxSadOprndSize
=
16
>
CK_TILE_HOST_DEVICE
auto
make_alibi_from_lr_mask
(
DataType
slope
,
index_t
window_left_size
,
index_t
window_right_size
,
...
...
@@ -142,7 +158,7 @@ CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
AlibiMode
alibi_mode
=
is_causal
?
AlibiMode
::
VERTICAL
:
static_cast
<
AlibiMode
>
(
mask_enum
)
/*either top-left or bottom-right*/
;
return
Alibi
<
DataType
,
RowMajor
>
{
slope
,
y_total
,
x_total
,
alibi_mode
};
return
Alibi
<
DataType
,
RowMajor
,
LogMaxSadOprndSize
>
{
slope
,
y_total
,
x_total
,
alibi_mode
};
}
// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
...
...
include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
namespace
ck_tile
{
// This class is used for codegen pattern matching
enum
class
RotaryEmbeddingEnum
{
NONE
=
0
,
INTERLEAVED
=
1
,
// combine dimensions 0 & 1, 2 & 3, etc
HALF_ROTATED
=
2
,
// combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
};
template
<
RotaryEmbeddingEnum
>
struct
RotaryEmbeddingEnumToStr
;
template
<
>
struct
RotaryEmbeddingEnumToStr
<
RotaryEmbeddingEnum
::
NONE
>
{
static
constexpr
const
char
*
name
=
""
;
};
template
<
>
struct
RotaryEmbeddingEnumToStr
<
RotaryEmbeddingEnum
::
INTERLEAVED
>
{
static
constexpr
const
char
*
name
=
"inter"
;
};
template
<
>
struct
RotaryEmbeddingEnumToStr
<
RotaryEmbeddingEnum
::
HALF_ROTATED
>
{
static
constexpr
const
char
*
name
=
"half"
;
};
template
<
RotaryEmbeddingEnum
RotaryEnum
,
typename
ComputeDataType
=
float
>
struct
BlockRotaryEmbedding
{
template
<
typename
DistributedTensor
,
typename
OtherDramBlockWindow
,
typename
RotaryCosDramBlockWindow
,
typename
RotarySinDramBlockWindow
>
CK_TILE_HOST_DEVICE
static
void
apply
(
DistributedTensor
&
tile
,
OtherDramBlockWindow
other_window
,
RotaryCosDramBlockWindow
rotary_cos_window
,
RotarySinDramBlockWindow
rotary_sin_window
,
index_t
rotary_dim
,
index_t
thread_end
)
{
using
DataType
=
typename
remove_cvref_t
<
DistributedTensor
>::
DataType
;
if
constexpr
(
RotaryEnum
==
RotaryEmbeddingEnum
::
INTERLEAVED
)
{
auto
rotary_cos_tile
=
load_tile
(
rotary_cos_window
);
auto
rotary_sin_tile
=
load_tile
(
rotary_sin_window
);
if
(
thread_end
<=
rotary_dim
)
{
constexpr
index_t
thread_buffer_size
=
decltype
(
tile
.
thread_buf_
)
::
size
();
static_for
<
0
,
thread_buffer_size
,
2
>
{}([
&
](
auto
idx
)
{
const
auto
left
=
type_convert
<
ComputeDataType
>
(
tile
.
thread_buf_
[
idx
]);
const
auto
right
=
type_convert
<
ComputeDataType
>
(
tile
.
thread_buf_
[
idx
+
1
]);
const
auto
cos
=
type_convert
<
ComputeDataType
>
(
rotary_cos_tile
.
thread_buf_
[
idx
/
2
]);
const
auto
sin
=
type_convert
<
ComputeDataType
>
(
rotary_sin_tile
.
thread_buf_
[
idx
/
2
]);
tile
.
thread_buf_
[
idx
]
=
type_convert
<
DataType
>
(
left
*
cos
-
right
*
sin
);
tile
.
thread_buf_
[
idx
+
1
]
=
type_convert
<
DataType
>
(
right
*
cos
+
left
*
sin
);
});
}
}
else
if
constexpr
(
RotaryEnum
==
RotaryEmbeddingEnum
::
HALF_ROTATED
)
{
if
(
thread_end
<=
rotary_dim
)
{
const
bool
is_left
=
(
thread_end
<=
(
rotary_dim
/
2
));
move_tile_window
(
other_window
,
{
0
,
is_left
?
rotary_dim
/
2
:
-
(
rotary_dim
/
2
)});
auto
other_tile
=
load_tile
(
other_window
);
move_tile_window
(
rotary_cos_window
,
{
0
,
is_left
?
0
:
-
(
rotary_dim
/
2
)});
auto
rotary_cos_tile
=
load_tile
(
rotary_cos_window
);
move_tile_window
(
rotary_sin_window
,
{
0
,
is_left
?
0
:
-
(
rotary_dim
/
2
)});
auto
rotary_sin_tile
=
load_tile
(
rotary_sin_window
);
constexpr
index_t
thread_buffer_size
=
decltype
(
tile
.
thread_buf_
)
::
size
();
static_for
<
0
,
thread_buffer_size
,
1
>
{}([
&
](
auto
idx
)
{
const
auto
curr
=
type_convert
<
ComputeDataType
>
(
tile
.
thread_buf_
[
idx
]);
const
auto
other
=
type_convert
<
ComputeDataType
>
(
other_tile
.
thread_buf_
[
idx
]);
const
auto
cos
=
type_convert
<
ComputeDataType
>
(
rotary_cos_tile
.
thread_buf_
[
idx
]);
const
auto
sin
=
type_convert
<
ComputeDataType
>
(
rotary_sin_tile
.
thread_buf_
[
idx
]);
tile
.
thread_buf_
[
idx
]
=
type_convert
<
DataType
>
(
curr
*
cos
+
other
*
(
is_left
?
-
sin
:
sin
));
});
}
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/block/page_block_navigator.hpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
namespace
ck_tile
{
// assume that we have only 1 page-block/tensor view
template
<
typename
TensorView
>
struct
TrivialPageBlockNavigator
{
using
DataType
=
typename
TensorView
::
DataType
;
using
WindowOrigin
=
multi_index
<
2
>
;
CK_TILE_HOST_DEVICE
constexpr
TrivialPageBlockNavigator
(
const
TensorView
&
tensor_view_
)
:
tensor_view
(
tensor_view_
)
{
}
template
<
typename
WindowLengths
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_tile_window
(
const
WindowLengths
&
window_lengths
,
const
WindowOrigin
&
window_origin
)
const
{
return
make_tuple
(
/*block_index=*/
0
,
ck_tile
::
make_tile_window
(
tensor_view
,
window_lengths
,
window_origin
));
}
template
<
typename
WindowLengths
,
typename
TileDistribution
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_tile_window
(
const
WindowLengths
&
window_lengths
,
const
WindowOrigin
&
window_origin
,
const
TileDistribution
&
tile_distribution
)
const
{
return
make_tuple
(
/*block_index=*/
0
,
ck_tile
::
make_tile_window
(
tensor_view
,
window_lengths
,
window_origin
,
tile_distribution
));
}
template
<
typename
TileWindow
>
CK_TILE_HOST_DEVICE
static
index_t
move_tile_window
(
index_t
/*block_index*/
,
TileWindow
&
tile_window
,
const
typename
remove_cvref_t
<
TileWindow
>::
BottomTensorIndex
&
step
)
{
ck_tile
::
move_tile_window
(
tile_window
,
step
);
return
/*block_index=*/
0
;
}
CK_TILE_HOST_DEVICE
static
constexpr
WindowOrigin
to_local_window_origin
(
const
WindowOrigin
&
global_window_origin
)
{
return
global_window_origin
;
}
CK_TILE_HOST_DEVICE
static
constexpr
WindowOrigin
to_global_window_origin
(
index_t
/*block_index*/
,
const
WindowOrigin
&
local_window_origin
)
{
return
local_window_origin
;
}
private:
TensorView
tensor_view
;
};
// default page-block navigator, assume that tensor view size is same as page-block size or smaller
// if tile window on last page-block
template
<
typename
DataType_
,
index_t
VirtualDim
,
typename
TensorView
>
struct
PageBlockNavigator
{
using
DataType
=
DataType_
;
static_assert
(
std
::
is_same_v
<
DataType
,
typename
TensorView
::
DataType
>
);
static_assert
(
VirtualDim
==
0
||
VirtualDim
==
1
,
"only support 2d tile window"
);
using
WindowOrigin
=
multi_index
<
2
>
;
CK_TILE_HOST_DEVICE
constexpr
PageBlockNavigator
(
copy_const_t
<
DataType
,
void
>*
physical_blocks_
,
long_index_t
block_stride_
,
long_index_t
fixed_offset_
,
const
int32_t
*
physical_block_indices_
,
index_t
num_blocks_
,
index_t
page_block_size_
,
const
TensorView
&
complete_view_
,
const
TensorView
&
last_view_
)
:
physical_blocks
(
reinterpret_cast
<
DataType
*>
(
physical_blocks_
)),
block_stride
(
block_stride_
),
fixed_offset
(
fixed_offset_
),
physical_block_indices
(
physical_block_indices_
),
num_blocks
(
num_blocks_
),
page_block_size
(
page_block_size_
),
complete_view
(
complete_view_
),
last_view
(
last_view_
)
{
}
template
<
typename
WindowLengths
>
CK_TILE_HOST_DEVICE
auto
make_tile_window
(
const
WindowLengths
&
window_lengths
,
const
WindowOrigin
&
window_origin
)
const
{
const
index_t
block_index
=
get_block_index
(
window_origin
);
const
WindowOrigin
local_window_origin
=
to_local_window_origin
(
window_origin
);
auto
new_tile_window
=
ck_tile
::
make_tile_window
(
is_last_block
(
block_index
)
?
last_view
:
complete_view
,
window_lengths
,
local_window_origin
);
new_tile_window
.
set_bottom_tensor_view_data_ptr
(
get_block_ptr
(
block_index
));
return
make_tuple
(
block_index
,
new_tile_window
);
}
template
<
typename
WindowLengths
,
typename
TileDistribution
>
CK_TILE_HOST_DEVICE
auto
make_tile_window
(
const
WindowLengths
&
window_lengths
,
const
WindowOrigin
&
window_origin
,
const
TileDistribution
&
tile_distribution
)
const
{
const
index_t
block_index
=
get_block_index
(
window_origin
);
const
WindowOrigin
local_window_origin
=
to_local_window_origin
(
window_origin
);
auto
new_tile_window
=
ck_tile
::
make_tile_window
(
is_last_block
(
block_index
)
?
last_view
:
complete_view
,
window_lengths
,
local_window_origin
,
tile_distribution
);
new_tile_window
.
set_bottom_tensor_view_data_ptr
(
get_block_ptr
(
block_index
));
return
make_tuple
(
block_index
,
new_tile_window
);
}
template
<
typename
TileWindow
>
CK_TILE_HOST_DEVICE
index_t
move_tile_window
(
index_t
block_index
,
TileWindow
&
tile_window
,
const
typename
remove_cvref_t
<
TileWindow
>::
BottomTensorIndex
&
step
)
const
{
ck_tile
::
move_tile_window
(
tile_window
,
step
);
const
WindowOrigin
global_window_origin
=
to_global_window_origin
(
block_index
,
tile_window
.
get_window_origin
());
const
WindowOrigin
local_window_origin
=
to_local_window_origin
(
global_window_origin
);
const
index_t
new_block_index
=
get_block_index
(
global_window_origin
);
/// TODO: only update necessary attributes
tile_window
.
bottom_tensor_view_
.
desc_
=
(
is_last_block
(
new_block_index
)
?
last_view
:
complete_view
).
get_tensor_descriptor
();
tile_window
.
set_window_origin
(
local_window_origin
);
tile_window
.
set_bottom_tensor_view_data_ptr
(
get_block_ptr
(
new_block_index
));
return
new_block_index
;
}
CK_TILE_HOST_DEVICE
bool
is_last_block
(
index_t
block_index
)
const
{
return
block_index
==
num_blocks
-
1
;
}
template
<
typename
TileWindow
>
CK_TILE_HOST_DEVICE
bool
is_cross_block
(
index_t
block_index
,
const
TileWindow
&
tile_window
)
const
{
const
index_t
origin
=
tile_window
.
get_window_origin
().
at
(
number
<
VirtualDim
>
{});
const
index_t
length
=
tile_window
.
get_window_lengths
().
at
(
number
<
VirtualDim
>
{});
return
(
block_index
<
num_blocks
-
1
)
&&
(
page_block_size
<
origin
+
length
);
}
template
<
typename
TileWindow
>
CK_TILE_HOST_DEVICE
void
move_to_block
(
index_t
block_index
,
TileWindow
&
tile_window
,
index_t
new_block_index
)
const
{
const
multi_index
<
2
>
step
=
[
&
]()
{
const
index_t
origin_diff
=
(
block_index
-
new_block_index
)
*
page_block_size
;
if
constexpr
(
VirtualDim
==
0
)
{
return
make_multi_index
(
origin_diff
,
0
);
}
else
{
return
make_multi_index
(
0
,
origin_diff
);
}
}();
/// TODO: only update necessary attributes
tile_window
.
bottom_tensor_view_
.
desc_
=
(
is_last_block
(
new_block_index
)
?
last_view
:
complete_view
).
get_tensor_descriptor
();
tile_window
.
set_window_origin
(
tile_window
.
get_window_origin
()
+
step
);
tile_window
.
set_bottom_tensor_view_data_ptr
(
get_block_ptr
(
new_block_index
));
}
CK_TILE_HOST_DEVICE
WindowOrigin
to_local_window_origin
(
const
WindowOrigin
&
global_window_origin
)
const
{
if
constexpr
(
VirtualDim
==
0
)
{
const
index_t
length
=
global_window_origin
.
at
(
number
<
0
>
{});
const
index_t
num_complete_blocks
=
integer_divide_floor
(
length
,
page_block_size
);
return
make_multi_index
(
length
-
page_block_size
*
num_complete_blocks
,
global_window_origin
.
at
(
number
<
1
>
{}));
}
else
{
const
index_t
length
=
global_window_origin
.
at
(
number
<
1
>
{});
const
index_t
num_complete_blocks
=
integer_divide_floor
(
length
,
page_block_size
);
return
make_multi_index
(
global_window_origin
.
at
(
number
<
0
>
{}),
length
-
page_block_size
*
num_complete_blocks
);
}
}
CK_TILE_HOST_DEVICE
WindowOrigin
to_global_window_origin
(
index_t
block_index
,
const
WindowOrigin
&
local_window_origin
)
const
{
if
constexpr
(
VirtualDim
==
0
)
{
return
make_multi_index
(
block_index
*
page_block_size
+
local_window_origin
.
at
(
number
<
0
>
{}),
local_window_origin
.
at
(
number
<
1
>
{}));
}
else
{
return
make_multi_index
(
local_window_origin
.
at
(
number
<
0
>
{}),
block_index
*
page_block_size
+
local_window_origin
.
at
(
number
<
1
>
{}));
}
}
private:
CK_TILE_HOST_DEVICE
DataType
*
get_block_ptr
(
index_t
block_index
)
const
{
return
physical_blocks
+
physical_block_indices
[
block_index
]
*
block_stride
+
fixed_offset
;
}
CK_TILE_HOST_DEVICE
int32_t
get_block_index
(
const
WindowOrigin
&
global_window_origin
)
const
{
return
integer_divide_floor
(
global_window_origin
.
at
(
number
<
VirtualDim
>
{}),
page_block_size
);
}
DataType
*
physical_blocks
;
long_index_t
block_stride
;
long_index_t
fixed_offset
;
const
int32_t
*
physical_block_indices
;
index_t
num_blocks
;
index_t
page_block_size
;
TensorView
complete_view
;
TensorView
last_view
;
};
template
<
typename
TensorView
>
CK_TILE_HOST_DEVICE
auto
make_page_block_navigator
(
const
TensorView
&
tensor_view
)
{
return
TrivialPageBlockNavigator
<
TensorView
>
(
tensor_view
);
}
template
<
typename
DataType
,
index_t
VirtualDim
,
typename
TensorView
>
CK_TILE_HOST_DEVICE
auto
make_page_block_navigator
(
copy_const_t
<
DataType
,
void
>*
physical_blocks
,
long_index_t
block_stride
,
long_index_t
fixed_offset
,
const
int32_t
*
physical_block_indices
,
index_t
num_blocks
,
index_t
page_block_size
,
const
TensorView
&
complete_view
,
const
TensorView
&
last_view
)
{
return
PageBlockNavigator
<
DataType
,
VirtualDim
,
TensorView
>
(
physical_blocks
,
block_stride
,
fixed_offset
,
physical_block_indices
,
num_blocks
,
page_block_size
,
complete_view
,
last_view
);
}
}
// namespace ck_tile
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
>
struct
FmhaFwdAppendKVKernel
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaPipeline
=
ck_tile
::
remove_cvref_t
<
FmhaPipeline_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaPipeline
::
kBlockSize
;
static
constexpr
ck_tile
::
index_t
kBlockPerCu
=
FmhaPipeline
::
kBlockPerCu
;
static_assert
(
kBlockPerCu
>
0
);
static
constexpr
ck_tile
::
index_t
kBlockPerCuInput
=
FmhaPipeline
::
Problem
::
kBlockPerCu
;
using
QDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
QDataType
>
;
using
KDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
KDataType
>
;
using
VDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
VDataType
>
;
using
VLayout
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
VLayout
>
;
static
constexpr
bool
kApplyRoPE
=
FmhaPipeline
::
RotaryEnum
!=
RotaryEmbeddingEnum
::
NONE
;
static
constexpr
bool
kIsPagedKV
=
FmhaPipeline
::
kIsPagedKV
;
static
constexpr
bool
kPadSeqLenQ
=
FmhaPipeline
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
FmhaPipeline
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
FmhaPipeline
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
FmhaPipeline
::
kPadHeadDimV
;
// clang-format off
template
<
typename
T
>
struct
t2s
;
template
<
>
struct
t2s
<
float
>
{
static
constexpr
const
char
*
name
=
"fp32"
;
};
template
<
>
struct
t2s
<
ck_tile
::
fp16_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
// clang-format on
__host__
static
std
::
string
GetName
()
{
// sync with generate.py
// clang-format off
#define _SS_ std::string
#define _TS_ std::to_string
auto
pn
=
[
&
]
()
{
std
::
string
n
;
if
(
kPadSeqLenQ
)
n
+=
"s"
;
if
(
kPadSeqLenK
)
n
+=
"sk"
;
if
(
kPadHeadDimQ
)
n
+=
"d"
;
if
(
kPadHeadDimV
)
n
+=
"dv"
;
return
n
.
empty
()
?
n
:
std
::
string
(
"p"
)
+
n
;
}();
return
_SS_
(
"fmha_fwd_appendkv_d"
)
+
_TS_
(
FmhaPipeline
::
kK0
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
"_"
"b"
+
_TS_
(
FmhaPipeline
::
kM0
)
+
"x"
+
_TS_
(
FmhaPipeline
::
kN0
)
+
"x"
+
_TS_
(
FmhaPipeline
::
kK0
)
+
"x"
+
_TS_
(
FmhaPipeline
::
kN1
)
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
!
kApplyRoPE
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
RotaryEmbeddingEnumToStr
<
FmhaPipeline
::
RotaryEnum
>::
name
))
+
(
kIsPagedKV
?
"_pagedkv"
:
""
);
#undef _SS_
#undef _TS_
// clang-format on
}
template
<
ck_tile
::
index_t
I
>
// to avoid duplicated base class prblem, introduce an template
// arg
struct
EmptyKargs
{
};
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
struct
BasicKargs
{
void
*
q_ptr
;
void
*
k_ptr
;
const
void
*
knew_ptr
;
void
*
v_ptr
;
const
void
*
vnew_ptr
;
const
int32_t
*
seqlen_k_ptr
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_k
;
ck_tile
::
index_t
seqlen_knew
;
ck_tile
::
index_t
hdim_q
;
ck_tile
::
index_t
hdim_v
;
ck_tile
::
index_t
num_head_q
;
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
// if this param is larger than 1, indicate MQA/GQA case
ck_tile
::
index_t
nhead_ratio_qk
;
ck_tile
::
index_t
stride_q
;
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_knew
;
ck_tile
::
index_t
stride_v
;
ck_tile
::
index_t
stride_vnew
;
ck_tile
::
index_t
nhead_stride_q
;
ck_tile
::
index_t
nhead_stride_k
;
ck_tile
::
index_t
nhead_stride_knew
;
ck_tile
::
index_t
nhead_stride_v
;
ck_tile
::
index_t
nhead_stride_vnew
;
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_knew
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_vnew
;
};
struct
RoPEKargs
{
const
void
*
rotary_cos_ptr
;
const
void
*
rotary_sin_ptr
;
ck_tile
::
index_t
rotary_dim
;
bool
has_mask
;
};
struct
PageBlockTableKargs
{
const
int32_t
*
block_table_ptr
;
ck_tile
::
index_t
batch_stride_block_table
;
ck_tile
::
index_t
page_block_size
;
};
struct
CacheBatchIdxKargs
{
const
int32_t
*
cache_batch_idx
;
};
struct
Kargs
:
BasicKargs
,
std
::
conditional_t
<
kApplyRoPE
,
RoPEKargs
,
EmptyKargs
<
0
>>
,
std
::
conditional_t
<
kIsPagedKV
,
PageBlockTableKargs
,
CacheBatchIdxKargs
>
{
};
__host__
static
constexpr
Kargs
MakeKargs
(
void
*
q_ptr
,
void
*
k_ptr
,
const
void
*
knew_ptr
,
void
*
v_ptr
,
const
void
*
vnew_ptr
,
ck_tile
::
index_t
seqlen_q
,
const
void
*
seqlen_k_ptr
,
ck_tile
::
index_t
seqlen_knew
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_head_q
,
ck_tile
::
index_t
nhead_ratio_qk
,
const
void
*
rotary_cos_ptr
,
const
void
*
rotary_sin_ptr
,
ck_tile
::
index_t
rotary_dim
,
bool
has_mask
,
const
void
*
block_table_ptr
,
ck_tile
::
index_t
batch_stride_block_table
,
ck_tile
::
index_t
page_block_size
,
const
void
*
cache_batch_idx
,
ck_tile
::
index_t
stride_q
,
ck_tile
::
index_t
stride_k
,
ck_tile
::
index_t
stride_knew
,
ck_tile
::
index_t
stride_v
,
ck_tile
::
index_t
stride_vnew
,
ck_tile
::
index_t
nhead_stride_q
,
ck_tile
::
index_t
nhead_stride_k
,
ck_tile
::
index_t
nhead_stride_knew
,
ck_tile
::
index_t
nhead_stride_v
,
ck_tile
::
index_t
nhead_stride_vnew
,
ck_tile
::
index_t
batch_stride_q
,
ck_tile
::
index_t
batch_stride_k
,
ck_tile
::
index_t
batch_stride_knew
,
ck_tile
::
index_t
batch_stride_v
,
ck_tile
::
index_t
batch_stride_vnew
)
{
Kargs
kargs
{
{
q_ptr
,
k_ptr
,
knew_ptr
,
v_ptr
,
vnew_ptr
,
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
),
seqlen_q
,
-
1
,
// seqlen_k will be updated by content of seqlen_k_ptr
seqlen_knew
,
hdim_q
,
hdim_v
,
num_head_q
,
nhead_ratio_qk
,
stride_q
,
stride_k
,
stride_knew
,
stride_v
,
stride_vnew
,
nhead_stride_q
,
nhead_stride_k
,
nhead_stride_knew
,
nhead_stride_v
,
nhead_stride_vnew
,
batch_stride_q
,
batch_stride_k
,
batch_stride_knew
,
batch_stride_v
,
batch_stride_vnew
},
// args for common karg
{},
// placeholder for rope
{}
// placeholder for paged-block table or cache_batch_idx
};
if
constexpr
(
kApplyRoPE
)
{
kargs
.
rotary_cos_ptr
=
rotary_cos_ptr
;
kargs
.
rotary_sin_ptr
=
rotary_sin_ptr
;
kargs
.
rotary_dim
=
rotary_dim
;
kargs
.
has_mask
=
has_mask
;
}
if
constexpr
(
kIsPagedKV
)
{
kargs
.
block_table_ptr
=
reinterpret_cast
<
const
int32_t
*>
(
block_table_ptr
);
kargs
.
batch_stride_block_table
=
batch_stride_block_table
;
kargs
.
page_block_size
=
page_block_size
;
}
else
{
kargs
.
cache_batch_idx
=
reinterpret_cast
<
const
int32_t
*>
(
cache_batch_idx
);
}
return
kargs
;
}
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_knew
)
{
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
seqlen_q
,
seqlen_knew
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
// divide problem
const
auto
[
i_tile
,
i_nhead
,
i_batch
]
=
TilePartitioner
{}();
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile
*
FmhaPipeline
::
kM0
);
const
index_t
i_n0
=
__builtin_amdgcn_readfirstlane
(
i_tile
*
FmhaPipeline
::
kN0
);
const
index_t
i_cache_batch
=
[
&
,
i_batch_
=
i_batch
]
{
if
constexpr
(
kIsPagedKV
)
{
return
i_batch_
;
}
else
{
return
(
kargs
.
cache_batch_idx
!=
nullptr
?
kargs
.
cache_batch_idx
[
i_batch_
]
:
i_batch_
);
}
}();
const
long_index_t
batch_offset_q
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_q
;
const
long_index_t
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_cache_batch
)
*
kargs
.
batch_stride_k
;
const
long_index_t
batch_offset_knew
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_knew
;
const
long_index_t
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_cache_batch
)
*
kargs
.
batch_stride_v
;
const
long_index_t
batch_offset_vnew
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_vnew
;
kargs
.
seqlen_k
=
kargs
.
seqlen_k_ptr
[
i_batch
];
// for simplicity, batch stride we just modify the pointer
QDataType
*
q_ptr
=
reinterpret_cast
<
QDataType
*>
(
kargs
.
q_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_q
+
batch_offset_q
;
KDataType
*
k_ptr
=
reinterpret_cast
<
KDataType
*>
(
kargs
.
k_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
/
kargs
.
nhead_ratio_qk
)
*
kargs
.
nhead_stride_k
+
batch_offset_k
;
const
KDataType
*
knew_ptr
=
reinterpret_cast
<
const
KDataType
*>
(
kargs
.
knew_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
/
kargs
.
nhead_ratio_qk
)
*
kargs
.
nhead_stride_knew
+
batch_offset_knew
;
VDataType
*
v_ptr
=
reinterpret_cast
<
VDataType
*>
(
kargs
.
v_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
/
kargs
.
nhead_ratio_qk
)
*
kargs
.
nhead_stride_v
+
batch_offset_v
;
const
VDataType
*
vnew_ptr
=
reinterpret_cast
<
const
VDataType
*>
(
kargs
.
vnew_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
/
kargs
.
nhead_ratio_qk
)
*
kargs
.
nhead_stride_vnew
+
batch_offset_vnew
;
// Q/K/V DRAM and DRAM window
const
auto
q_dram
=
[
&
]()
{
const
auto
q_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
q_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_q
,
1
),
number
<
FmhaPipeline
::
kAlignmentQ
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}();
const
auto
make_k_dram
=
[
&
](
KDataType
*
data
,
index_t
height
)
{
const
auto
k_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
data
,
// will update this pointer if using paged-kvcache
make_tuple
(
height
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_k
,
1
),
number
<
FmhaPipeline
::
kAlignmentK
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
};
const
auto
k_dram
=
[
&
]()
{
if
constexpr
(
kIsPagedKV
)
{
return
make_k_dram
(
nullptr
,
kargs
.
page_block_size
);
}
else
{
return
make_k_dram
(
k_ptr
,
kargs
.
seqlen_k
+
kargs
.
seqlen_knew
);
}
}();
const
auto
knew_dram
=
[
&
]()
{
const
auto
knew_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
knew_ptr
,
make_tuple
(
kargs
.
seqlen_knew
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_knew
,
1
),
number
<
FmhaPipeline
::
kAlignmentK
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
knew_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
}();
const
auto
make_v_dram
=
[
&
](
VDataType
*
data
,
index_t
length
)
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
const
auto
v_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
data
,
// will update this pointer if using paged-kvcache
make_tuple
(
length
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
stride_v
,
1
),
number
<
FmhaPipeline
::
kAlignmentV
>
{},
number
<
1
>
{});
const
auto
v_dram_transposed
=
transform_tensor_view
(
v_dram_naive
,
make_tuple
(
make_pass_through_transform
(
kargs
.
hdim_v
),
make_pass_through_transform
(
length
)),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
pad_tensor_view
(
v_dram_transposed
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kN0
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenK
>
{});
}
else
{
const
auto
v_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
data
,
// will update this pointer if using paged-kvcache
make_tuple
(
kargs
.
hdim_v
,
length
),
make_tuple
(
kargs
.
stride_v
,
1
),
number
<
FmhaPipeline
::
kAlignmentV
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
v_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kN0
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenK
>
{});
}
};
const
auto
v_dram
=
[
&
]()
{
if
constexpr
(
kIsPagedKV
)
{
return
make_v_dram
(
nullptr
,
kargs
.
page_block_size
);
}
else
{
return
make_v_dram
(
v_ptr
,
kargs
.
seqlen_k
+
kargs
.
seqlen_knew
);
}
}();
const
auto
vnew_dram
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
const
auto
vnew_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
vnew_ptr
,
make_tuple
(
kargs
.
seqlen_knew
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
stride_vnew
,
1
),
number
<
FmhaPipeline
::
kAlignmentV
>
{},
number
<
1
>
{});
const
auto
vnew_dram_transposed
=
transform_tensor_view
(
vnew_dram_naive
,
make_tuple
(
make_pass_through_transform
(
kargs
.
hdim_v
),
make_pass_through_transform
(
kargs
.
seqlen_knew
)),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
pad_tensor_view
(
vnew_dram_transposed
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kN0
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenK
>
{});
}
else
{
const
auto
vnew_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
vnew_ptr
,
make_tuple
(
kargs
.
hdim_v
,
kargs
.
seqlen_knew
),
make_tuple
(
kargs
.
stride_vnew
,
1
),
number
<
FmhaPipeline
::
kAlignmentV
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
vnew_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kN0
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenK
>
{});
}
}();
constexpr
auto
q_rotary_cos_sin_dram_window_lengths
=
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
/
2
>
{});
const
auto
q_rotary_cos_dram_window
=
[
&
]()
{
if
constexpr
(
kApplyRoPE
)
{
const
auto
rotary_cos_dram_native
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
reinterpret_cast
<
const
QDataType
*>
(
kargs
.
rotary_cos_ptr
)
+
kargs
.
seqlen_k
*
(
kargs
.
rotary_dim
/
2
),
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
rotary_dim
/
2
),
make_tuple
(
kargs
.
has_mask
*
(
kargs
.
rotary_dim
/
2
),
1
),
number
<
8
>
{},
number
<
1
>
{});
const
auto
rotary_cos_dram
=
[
&
]()
{
return
pad_tensor_view
(
rotary_cos_dram_native
,
q_rotary_cos_sin_dram_window_lengths
,
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}();
return
make_tile_window
(
rotary_cos_dram
,
q_rotary_cos_sin_dram_window_lengths
,
{
i_m0
,
0
});
}
else
{
return
make_null_tile_window
(
q_rotary_cos_sin_dram_window_lengths
);
}
}();
const
auto
q_rotary_sin_dram_window
=
[
&
]()
{
if
constexpr
(
kApplyRoPE
)
{
const
auto
rotary_sin_dram_native
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
reinterpret_cast
<
const
QDataType
*>
(
kargs
.
rotary_sin_ptr
)
+
kargs
.
seqlen_k
*
(
kargs
.
rotary_dim
/
2
),
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
rotary_dim
/
2
),
make_tuple
(
kargs
.
has_mask
*
(
kargs
.
rotary_dim
/
2
),
1
),
number
<
8
>
{},
number
<
1
>
{});
const
auto
rotary_sin_dram
=
[
&
]()
{
return
pad_tensor_view
(
rotary_sin_dram_native
,
q_rotary_cos_sin_dram_window_lengths
,
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}();
return
make_tile_window
(
rotary_sin_dram
,
q_rotary_cos_sin_dram_window_lengths
,
{
i_m0
,
0
});
}
else
{
return
make_null_tile_window
(
q_rotary_cos_sin_dram_window_lengths
);
}
}();
constexpr
auto
knew_rotary_cos_sin_dram_window_lengths
=
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
/
2
>
{});
const
auto
knew_rotary_cos_dram_window
=
[
&
]()
{
if
constexpr
(
kApplyRoPE
)
{
const
auto
rotary_cos_dram_native
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
reinterpret_cast
<
const
KDataType
*>
(
kargs
.
rotary_cos_ptr
)
+
kargs
.
seqlen_k
*
(
kargs
.
rotary_dim
/
2
),
make_tuple
(
kargs
.
seqlen_knew
,
kargs
.
rotary_dim
/
2
),
make_tuple
(
kargs
.
rotary_dim
/
2
,
1
),
number
<
8
>
{},
number
<
1
>
{});
const
auto
rotary_cos_dram
=
[
&
]()
{
return
pad_tensor_view
(
rotary_cos_dram_native
,
knew_rotary_cos_sin_dram_window_lengths
,
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
}();
return
make_tile_window
(
rotary_cos_dram
,
knew_rotary_cos_sin_dram_window_lengths
,
{
i_n0
,
0
});
}
else
{
return
make_null_tile_window
(
knew_rotary_cos_sin_dram_window_lengths
);
}
}();
const
auto
knew_rotary_sin_dram_window
=
[
&
]()
{
if
constexpr
(
kApplyRoPE
)
{
const
auto
rotary_sin_dram_native
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
reinterpret_cast
<
const
KDataType
*>
(
kargs
.
rotary_sin_ptr
)
+
kargs
.
seqlen_k
*
(
kargs
.
rotary_dim
/
2
),
make_tuple
(
kargs
.
seqlen_knew
,
kargs
.
rotary_dim
/
2
),
make_tuple
(
kargs
.
rotary_dim
/
2
,
1
),
number
<
8
>
{},
number
<
1
>
{});
const
auto
rotary_sin_dram
=
[
&
]()
{
return
pad_tensor_view
(
rotary_sin_dram_native
,
knew_rotary_cos_sin_dram_window_lengths
,
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
}();
return
make_tile_window
(
rotary_sin_dram
,
knew_rotary_cos_sin_dram_window_lengths
,
{
i_n0
,
0
});
}
else
{
return
make_null_tile_window
(
knew_rotary_cos_sin_dram_window_lengths
);
}
}();
auto
k_page_block_navigator
=
[
&
,
i_batch_
=
i_batch
,
i_nhead_
=
i_nhead
]()
{
if
constexpr
(
kIsPagedKV
)
{
const
auto
*
block_indices
=
reinterpret_cast
<
const
int32_t
*>
(
kargs
.
block_table_ptr
)
+
i_batch_
*
kargs
.
batch_stride_block_table
;
const
index_t
num_blocks
=
integer_divide_ceil
(
kargs
.
seqlen_k
+
kargs
.
seqlen_knew
,
kargs
.
page_block_size
);
const
long_index_t
fixed_offset
=
static_cast
<
long_index_t
>
(
i_nhead_
/
kargs
.
nhead_ratio_qk
)
*
kargs
.
nhead_stride_k
;
return
make_page_block_navigator
<
KDataType
,
0
>
(
kargs
.
k_ptr
,
kargs
.
batch_stride_k
,
fixed_offset
,
block_indices
,
num_blocks
,
kargs
.
page_block_size
,
k_dram
,
make_k_dram
(
nullptr
,
(
kargs
.
seqlen_k
+
kargs
.
seqlen_knew
)
-
(
num_blocks
-
1
)
*
kargs
.
page_block_size
));
}
else
{
return
make_page_block_navigator
(
k_dram
);
}
}();
auto
v_page_block_navigator
=
[
&
,
i_batch_
=
i_batch
,
i_nhead_
=
i_nhead
]()
{
if
constexpr
(
kIsPagedKV
)
{
const
auto
*
block_indices
=
reinterpret_cast
<
const
int32_t
*>
(
kargs
.
block_table_ptr
)
+
i_batch_
*
kargs
.
batch_stride_block_table
;
const
index_t
num_blocks
=
integer_divide_ceil
(
kargs
.
seqlen_k
+
kargs
.
seqlen_knew
,
kargs
.
page_block_size
);
const
long_index_t
fixed_offset
=
static_cast
<
long_index_t
>
(
i_nhead_
/
kargs
.
nhead_ratio_qk
)
*
kargs
.
nhead_stride_v
;
return
make_page_block_navigator
<
VDataType
,
1
>
(
kargs
.
v_ptr
,
kargs
.
batch_stride_v
,
fixed_offset
,
block_indices
,
num_blocks
,
kargs
.
page_block_size
,
v_dram
,
make_v_dram
(
nullptr
,
(
kargs
.
seqlen_k
+
kargs
.
seqlen_knew
)
-
(
num_blocks
-
1
)
*
kargs
.
page_block_size
));
}
else
{
return
make_page_block_navigator
(
v_dram
);
}
}();
auto
q_dram_window
=
make_tile_window
(
q_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
{
i_m0
,
0
});
const
bool
skip_append_kv
=
kargs
.
seqlen_knew
<=
i_n0
;
// window origin = (0, 0) if no work to do for current block
auto
[
i_page_block_k
,
k_dram_window
]
=
k_page_block_navigator
.
make_tile_window
(
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
{
!
skip_append_kv
*
(
kargs
.
seqlen_k
+
i_n0
),
0
});
auto
knew_dram_window
=
make_tile_window
(
knew_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
{
i_n0
,
0
});
// window origin = (0, 0) if no work to do for current block
auto
[
i_page_block_v
,
v_dram_window
]
=
v_page_block_navigator
.
make_tile_window
(
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kN0
>
{}),
{
0
,
!
skip_append_kv
*
(
kargs
.
seqlen_k
+
i_n0
)});
auto
vnew_dram_window
=
make_tile_window
(
vnew_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kN0
>
{}),
{
0
,
i_n0
});
if
constexpr
(
kApplyRoPE
)
{
FmhaPipeline
{}(
q_dram_window
,
k_dram_window
,
i_page_block_k
,
k_page_block_navigator
,
knew_dram_window
,
v_dram_window
,
i_page_block_v
,
v_page_block_navigator
,
vnew_dram_window
,
q_rotary_cos_dram_window
,
q_rotary_sin_dram_window
,
knew_rotary_cos_dram_window
,
knew_rotary_sin_dram_window
,
kargs
.
rotary_dim
,
kargs
.
seqlen_q
<=
i_m0
,
skip_append_kv
);
}
else
{
FmhaPipeline
{}(
q_dram_window
,
k_dram_window
,
i_page_block_k
,
k_page_block_navigator
,
knew_dram_window
,
v_dram_window
,
i_page_block_v
,
v_page_block_navigator
,
vnew_dram_window
,
q_rotary_cos_dram_window
,
q_rotary_sin_dram_window
,
knew_rotary_cos_dram_window
,
knew_rotary_sin_dram_window
,
0
,
// rotary_dim not used
kargs
.
seqlen_q
<=
i_m0
,
skip_append_kv
);
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
index_t
kM0_
,
index_t
kN0_
,
index_t
kK0_
,
index_t
kN1_
>
struct
FmhaFwdAppendKVTilePartitioner
{
static
constexpr
ck_tile
::
index_t
kM0
=
kM0_
;
static
constexpr
ck_tile
::
index_t
kN0
=
kN0_
;
static
constexpr
ck_tile
::
index_t
kK0
=
kK0_
;
static
constexpr
ck_tile
::
index_t
kN1
=
kN1_
;
static_assert
(
kK0
==
kN1
);
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_knew
)
{
// TODO: this may need tuning
return
dim3
(
std
::
max
(
ck_tile
::
integer_divide_ceil
(
seqlen_q
,
kM0
),
ck_tile
::
integer_divide_ceil
(
seqlen_knew
,
kN0
)),
nhead
,
batch_size
);
}
CK_TILE_DEVICE
auto
operator
()()
{
const
index_t
i_tile
=
blockIdx
.
x
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
return
ck_tile
::
make_tuple
(
i_tile
,
i_nhead
,
i_batch
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
4885c38a
...
...
@@ -32,8 +32,6 @@ struct FmhaFwdSplitKVKernel
using
KDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
KDataType
>
;
using
VDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
VDataType
>
;
using
BiasDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
BiasDataType
>
;
using
RandValOutputDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
RandValOutputDataType
>
;
using
LSEDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
LSEDataType
>
;
using
SaccDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
SaccDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
FmhaPipeline
::
OaccDataType
>
;
...
...
@@ -46,8 +44,10 @@ struct FmhaFwdSplitKVKernel
static
constexpr
bool
kPadHeadDimQ
=
FmhaPipeline
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
FmhaPipeline
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
FmhaPipeline
::
BiasEnum
;
static
constexpr
bool
kHasDropout
=
FmhaPipeline
::
kHasDropout
;
static
constexpr
bool
kDoFp8StaticQuant
=
FmhaPipeline
::
Problem
::
kDoFp8StaticQuant
;
static
constexpr
bool
kIsPagedKV
=
FmhaPipeline
::
Problem
::
kIsPagedKV
;
static_assert
(
!
kIsGroupMode
||
(
kIsGroupMode
&&
!
kIsPagedKV
),
"paged-kvcache only supported by batch mode kernels"
);
using
FmhaMask
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
FmhaMask
>
;
static
constexpr
bool
kHasMask
=
FmhaMask
::
IsMasking
;
...
...
@@ -85,8 +85,8 @@ struct FmhaFwdSplitKVKernel
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
k
HasDropout
?
"_dropou
t"
:
""
)
+
(
k
DoFp8StaticQuant
?
"_squant
"
:
""
);
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
k
DoFp8StaticQuant
?
"_squan
t"
:
""
)
+
(
k
IsPagedKV
?
"_pagedkv
"
:
""
);
#undef _SS_
#undef _TS_
// clang-format on
...
...
@@ -110,7 +110,6 @@ struct FmhaFwdSplitKVKernel
void
*
o_acc_ptr
;
ck_tile
::
index_t
batch
;
ck_tile
::
index_t
max_seqlen_q
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_k
;
...
...
@@ -136,6 +135,7 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
nhead_stride_lse_acc
;
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
...
...
@@ -173,32 +173,16 @@ struct FmhaFwdSplitKVKernel
float
scale_p
;
};
struct
CommonDropout
Kargs
struct
PageBlockTable
Kargs
{
void
init_dropout
(
const
float
p_drop
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
float
p_undrop
=
1.0
-
p_drop
;
p_undrop_in_uint8_t
=
uint8_t
(
std
::
floor
(
p_undrop
*
std
::
numeric_limits
<
uint8_t
>::
max
()));
rp_undrop
=
1.0
/
p_undrop
;
drop_seed
=
std
::
get
<
0
>
(
drop_seed_offset
);
drop_offset
=
std
::
get
<
1
>
(
drop_seed_offset
);
}
float
rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
bool
is_store_randval
=
false
;
uint64_t
drop_seed
=
1
;
uint64_t
drop_offset
=
0
;
void
*
rand_val_ptr
=
nullptr
;
ck_tile
::
index_t
stride_randval
=
0
;
ck_tile
::
index_t
nhead_stride_randval
=
0
;
const
int32_t
*
block_table_ptr
;
ck_tile
::
index_t
batch_stride_block_table
;
ck_tile
::
index_t
page_block_size
;
};
struct
BatchModeDropoutKargs
:
CommonDropoutKargs
struct
CacheBatchIdxKargs
{
c
k_tile
::
index_t
batch_stride_randval
=
0
;
c
onst
int32_t
*
cache_batch_idx
;
};
struct
BatchModeKargs
...
...
@@ -210,12 +194,13 @@ struct FmhaFwdSplitKVKernel
EmptyKargs
<
0
>>>
,
std
::
conditional_t
<
kHasMask
,
MaskKargs
,
EmptyKargs
<
1
>>
,
std
::
conditional_t
<
kDoFp8StaticQuant
,
Fp8StaticQuantKargs
,
EmptyKargs
<
2
>>
,
std
::
conditional_t
<
k
HasDropout
,
BatchModeDropoutKargs
,
Empty
Kargs
<
3
>
>
std
::
conditional_t
<
k
IsPagedKV
,
PageBlockTableKargs
,
CacheBatchIdx
Kargs
>
{
const
int32_t
*
seqlen_k_ptr
;
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_lse_acc
;
};
struct
GroupModeKargs
...
...
@@ -226,12 +211,14 @@ struct FmhaFwdSplitKVKernel
AlibiKargs
,
EmptyKargs
<
0
>>>
,
std
::
conditional_t
<
kHasMask
,
MaskKargs
,
EmptyKargs
<
1
>>
,
std
::
conditional_t
<
kDoFp8StaticQuant
,
Fp8StaticQuantKargs
,
EmptyKargs
<
2
>>
,
std
::
conditional_t
<
kHasDropout
,
CommonDropoutKargs
,
EmptyKargs
<
3
>>
std
::
conditional_t
<
kDoFp8StaticQuant
,
Fp8StaticQuantKargs
,
EmptyKargs
<
2
>>
{
const
int32_t
*
seqstart_q_ptr
;
const
int32_t
*
seqstart_k_ptr
;
const
int32_t
*
seqlen_k_ptr
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
};
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
GroupModeKargs
,
BatchModeKargs
>
;
...
...
@@ -242,48 +229,45 @@ struct FmhaFwdSplitKVKernel
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
void
*
rand_val_ptr
,
void
*
lse_acc_ptr
,
void
*
o_acc_ptr
,
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_k
,
ck_tile
::
index_t
seqlen_k
,
// only used if 'seqlen_k_ptr' is not specified
const
void
*
seqlen_k_ptr
,
// only used for (paged-) kvcache
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_head_q
,
ck_tile
::
index_t
nhead_ratio_qk
,
ck_tile
::
index_t
num_splits
,
const
void
*
block_table_ptr
,
ck_tile
::
index_t
batch_stride_block_table
,
ck_tile
::
index_t
page_block_size
,
const
void
*
cache_batch_idx
,
float
scale_s
,
float
scale_p
,
ck_tile
::
index_t
stride_q
,
ck_tile
::
index_t
stride_k
,
ck_tile
::
index_t
stride_v
,
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_o_acc
,
ck_tile
::
index_t
nhead_stride_q
,
ck_tile
::
index_t
nhead_stride_k
,
ck_tile
::
index_t
nhead_stride_v
,
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_lse_acc
,
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
batch_stride_q
,
ck_tile
::
index_t
batch_stride_k
,
ck_tile
::
index_t
batch_stride_v
,
ck_tile
::
index_t
batch_stride_bias
,
ck_tile
::
index_t
batch_stride_randval
,
ck_tile
::
index_t
batch_stride_lse_acc
,
ck_tile
::
index_t
batch_stride_o_acc
,
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_o_acc
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
ck_tile
::
index_t
mask_type
)
{
Kargs
kargs
{{
q_ptr
,
k_ptr
,
...
...
@@ -291,7 +275,6 @@ struct FmhaFwdSplitKVKernel
lse_acc_ptr
,
o_acc_ptr
,
batch
,
max_seqlen_q
,
seqlen_q
,
seqlen_k
,
hdim_q
,
...
...
@@ -313,17 +296,18 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
batch_stride_lse_acc
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for mask
{},
// placeholder for fp8_static_quant args
{},
// placeholder for dropout
{},
// placeholder for paged-block table or cache_batch_idx
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
),
batch_stride_q
,
batch_stride_k
,
batch_stride_v
,
batch_stride_lse_acc
};
batch_stride_v
};
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
...
...
@@ -347,14 +331,15 @@ struct FmhaFwdSplitKVKernel
{
kargs
.
scale_p
=
scale_p
;
}
if
constexpr
(
kHasDropout
)
if
constexpr
(
kIsPagedKV
)
{
kargs
.
block_table_ptr
=
reinterpret_cast
<
const
int32_t
*>
(
block_table_ptr
);
kargs
.
batch_stride_block_table
=
batch_stride_block_table
;
kargs
.
page_block_size
=
page_block_size
;
}
else
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
);
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
batch_stride_randval
=
batch_stride_randval
;
kargs
.
is_store_randval
=
s_randval
;
kargs
.
cache_batch_idx
=
reinterpret_cast
<
const
int32_t
*>
(
cache_batch_idx
);
}
return
kargs
;
...
...
@@ -366,11 +351,9 @@ struct FmhaFwdSplitKVKernel
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
void
*
rand_val_ptr
,
void
*
lse_acc_ptr
,
void
*
o_acc_ptr
,
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
max_seqlen_q
,
const
void
*
seqstart_q_ptr
,
const
void
*
seqstart_k_ptr
,
const
void
*
seqlen_k_ptr
,
...
...
@@ -385,24 +368,22 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
stride_k
,
ck_tile
::
index_t
stride_v
,
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_o_acc
,
ck_tile
::
index_t
nhead_stride_q
,
ck_tile
::
index_t
nhead_stride_k
,
ck_tile
::
index_t
nhead_stride_v
,
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_lse_acc
,
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
batch_stride_k
,
ck_tile
::
index_t
batch_stride_v
,
ck_tile
::
index_t
batch_stride_lse_acc
,
ck_tile
::
index_t
batch_stride_o_acc
,
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_o_acc
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
ck_tile
::
index_t
mask_type
)
{
Kargs
kargs
{{
q_ptr
,
k_ptr
,
...
...
@@ -410,9 +391,8 @@ struct FmhaFwdSplitKVKernel
lse_acc_ptr
,
o_acc_ptr
,
batch
,
max_seqlen_q
,
-
1
,
// seqlen will be updated by another pointer
-
1
,
//
-
1
,
// seqlen_q will be updated by another pointer
-
1
,
// seqlen_k will be updated by another pointer
hdim_q
,
hdim_v
,
num_head_q
,
...
...
@@ -432,16 +412,18 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
batch_stride_lse_acc
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for mask
{},
// placeholder for fp8_static_quant args
{},
// placeholder for dropout
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqstart_k_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
)};
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
),
batch_stride_k
,
batch_stride_v
};
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
...
...
@@ -464,14 +446,6 @@ struct FmhaFwdSplitKVKernel
{
kargs
.
scale_p
=
scale_p
;
}
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
);
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
is_store_randval
=
s_randval
;
}
return
kargs
;
}
...
...
@@ -508,7 +482,6 @@ struct FmhaFwdSplitKVKernel
long_index_t
batch_offset_k
=
0
;
long_index_t
batch_offset_v
=
0
;
long_index_t
batch_offset_bias
=
0
;
long_index_t
batch_offset_randval
=
0
;
long_index_t
batch_offset_lse_acc
=
0
;
const
long_index_t
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
...
...
@@ -534,14 +507,9 @@ struct FmhaFwdSplitKVKernel
{
batch_offset_bias
=
query_start
*
kargs
.
stride_bias
+
key_start
;
}
if
constexpr
(
kHasDropout
)
{
batch_offset_randval
=
query_start
*
kargs
.
stride_randval
;
}
// get real # queries & # keys under group mode
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
kargs
.
seqlen_q
=
adjusted_seqstart_q_ptr
[
1
]
-
adjusted_seqstart_q_ptr
[
0
];
kargs
.
seqlen_q
=
kargs
.
seqstart_q_ptr
[
i_batch
+
1
]
-
kargs
.
seqstart_q_ptr
[
i_batch
];
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
...
...
@@ -556,24 +524,36 @@ struct FmhaFwdSplitKVKernel
}
else
{
const
auto
adjusted_seqstart_k_ptr
=
kargs
.
seqstart_k_ptr
+
i_batch
;
kargs
.
seqlen_k
=
adjusted_seqstart_k_ptr
[
1
]
-
adjusted_seqstart_k_ptr
[
0
];
kargs
.
seqlen_k
=
kargs
.
seqstart_k_ptr
[
i_batch
+
1
]
-
kargs
.
seqstart_k_ptr
[
i_batch
];
}
}
else
{
const
index_t
i_cache_batch
=
[
&
,
i_batch_
=
i_batch
]
{
if
constexpr
(
kIsPagedKV
)
{
return
i_batch_
;
}
else
{
return
(
kargs
.
cache_batch_idx
!=
nullptr
?
kargs
.
cache_batch_idx
[
i_batch_
]
:
i_batch_
);
}
}();
batch_offset_q
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_q
;
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_
cache_
batch
)
*
kargs
.
batch_stride_k
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_
cache_
batch
)
*
kargs
.
batch_stride_v
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
batch_offset_bias
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_bias
;
}
if
constexpr
(
kHasDropout
)
if
(
kargs
.
seqlen_k_ptr
!=
nullptr
)
{
batch_offset_randval
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_randval
;
kargs
.
seqlen_k
=
kargs
.
seqlen_k_ptr
[
i_batch
];
}
}
...
...
@@ -589,6 +569,7 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast
<
const
VDataType
*>
(
kargs
.
v_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
/
kargs
.
nhead_ratio_qk
)
*
kargs
.
nhead_stride_v
+
batch_offset_v
;
OaccDataType
*
o_acc_ptr
=
reinterpret_cast
<
OaccDataType
*>
(
kargs
.
o_acc_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_o_acc
+
batch_offset_o_acc
+
i_split
*
kargs
.
split_stride_o_acc
;
...
...
@@ -616,10 +597,11 @@ struct FmhaFwdSplitKVKernel
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}
}();
const
auto
k_dram
=
[
&
]()
{
const
auto
make_k_dram
=
[
&
](
const
KDataType
*
data
,
index_t
height
)
{
const
auto
k_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
k_ptr
,
make_tuple
(
kargs
.
seqlen_k
,
kargs
.
hdim_q
),
data
,
// will update this pointer if using paged-kvcache
make_tuple
(
height
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_k
,
1
),
number
<
FmhaPipeline
::
kAlignmentK
>
{},
number
<
1
>
{});
...
...
@@ -628,13 +610,24 @@ struct FmhaFwdSplitKVKernel
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
};
const
auto
k_dram
=
[
&
]()
{
if
constexpr
(
kIsPagedKV
)
{
return
make_k_dram
(
nullptr
,
kargs
.
page_block_size
);
}
else
{
return
make_k_dram
(
k_ptr
,
kargs
.
seqlen_k
);
}
}();
const
auto
v_dram
=
[
&
]()
{
const
auto
make_v_dram
=
[
&
](
const
VDataType
*
data
,
index_t
length
)
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
const
auto
v_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
v_ptr
,
make_tuple
(
kargs
.
seqlen_k
,
kargs
.
hdim_v
),
data
,
// will update this pointer if using paged-kvcache
make_tuple
(
length
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
stride_v
,
1
),
number
<
FmhaPipeline
::
kAlignmentV
>
{},
number
<
1
>
{});
...
...
@@ -642,7 +635,7 @@ struct FmhaFwdSplitKVKernel
const
auto
v_dram_transposed
=
transform_tensor_view
(
v_dram_naive
,
make_tuple
(
make_pass_through_transform
(
kargs
.
hdim_v
),
make_pass_through_transform
(
kargs
.
seqlen_k
)),
make_pass_through_transform
(
length
)),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
...
...
@@ -654,8 +647,8 @@ struct FmhaFwdSplitKVKernel
else
{
const
auto
v_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
v_ptr
,
make_tuple
(
kargs
.
hdim_v
,
kargs
.
seqlen_k
),
data
,
// will update this pointer if using paged-kvcache
make_tuple
(
kargs
.
hdim_v
,
length
),
make_tuple
(
kargs
.
stride_v
,
1
),
number
<
FmhaPipeline
::
kAlignmentV
>
{},
number
<
1
>
{});
...
...
@@ -665,6 +658,76 @@ struct FmhaFwdSplitKVKernel
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenK
>
{});
}
};
const
auto
v_dram
=
[
&
]()
{
if
constexpr
(
kIsPagedKV
)
{
return
make_v_dram
(
nullptr
,
kargs
.
page_block_size
);
}
else
{
return
make_v_dram
(
v_ptr
,
kargs
.
seqlen_k
);
}
}();
auto
k_page_block_navigator
=
[
&
,
i_batch_
=
i_batch
,
i_nhead_
=
i_nhead
]()
{
if
constexpr
(
kIsPagedKV
)
{
const
auto
*
block_indices
=
reinterpret_cast
<
const
int32_t
*>
(
kargs
.
block_table_ptr
)
+
i_batch_
*
kargs
.
batch_stride_block_table
;
const
index_t
num_blocks
=
integer_divide_ceil
(
kargs
.
seqlen_k
,
kargs
.
page_block_size
);
const
long_index_t
fixed_offset
=
static_cast
<
long_index_t
>
(
i_nhead_
/
kargs
.
nhead_ratio_qk
)
*
kargs
.
nhead_stride_k
;
return
make_page_block_navigator
<
const
KDataType
,
0
>
(
kargs
.
k_ptr
,
kargs
.
batch_stride_k
,
fixed_offset
,
block_indices
,
num_blocks
,
kargs
.
page_block_size
,
k_dram
,
make_k_dram
(
nullptr
,
kargs
.
seqlen_k
-
(
num_blocks
-
1
)
*
kargs
.
page_block_size
));
}
else
{
return
make_page_block_navigator
(
k_dram
);
}
}();
auto
v_page_block_navigator
=
[
&
,
i_batch_
=
i_batch
,
i_nhead_
=
i_nhead
]()
{
if
constexpr
(
kIsPagedKV
)
{
const
auto
*
block_indices
=
reinterpret_cast
<
const
int32_t
*>
(
kargs
.
block_table_ptr
)
+
i_batch_
*
kargs
.
batch_stride_block_table
;
const
index_t
num_blocks
=
integer_divide_ceil
(
kargs
.
seqlen_k
,
kargs
.
page_block_size
);
const
long_index_t
fixed_offset
=
static_cast
<
long_index_t
>
(
i_nhead_
/
kargs
.
nhead_ratio_qk
)
*
kargs
.
nhead_stride_v
;
return
make_page_block_navigator
<
const
VDataType
,
1
>
(
kargs
.
v_ptr
,
kargs
.
batch_stride_v
,
fixed_offset
,
block_indices
,
num_blocks
,
kargs
.
page_block_size
,
v_dram
,
make_v_dram
(
nullptr
,
kargs
.
seqlen_k
-
(
num_blocks
-
1
)
*
kargs
.
page_block_size
));
}
else
{
return
make_page_block_navigator
(
v_dram
);
}
}();
auto
q_dram_window
=
make_tile_window
(
...
...
@@ -678,13 +741,11 @@ struct FmhaFwdSplitKVKernel
}(),
{
i_m0
,
0
});
auto
k_dram_window
=
make_tile_window
(
k_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
{
0
,
0
});
auto
k_dram_window_lengths
=
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{});
auto
v_dram_window_lengths
=
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{});
auto
v_dram_window
=
make_tile_window
(
v_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
{
i_n1
,
0
});
/// FIXME: Before C++20, capturing structured binding variables are not supported. Remove
/// following copy capture of the 'i_nhead' if in C++20
const
auto
bias_dram_window
=
[
&
,
i_nhead_
=
i_nhead
]()
{
...
...
@@ -741,62 +802,6 @@ struct FmhaFwdSplitKVKernel
return
make_tile_window
(
lse_acc_dram
,
lse_acc_dram_window_lengths
,
{
i_m0
});
}();
// dropout
float
rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
uint64_t
drop_seed
=
0
;
uint64_t
drop_offset
=
0
;
bool
is_store_randval
=
false
;
if
constexpr
(
kHasDropout
)
{
rp_undrop
=
kargs
.
rp_undrop
;
p_undrop_in_uint8_t
=
kargs
.
p_undrop_in_uint8_t
;
drop_seed
=
kargs
.
drop_seed
;
drop_offset
=
kargs
.
drop_offset
;
is_store_randval
=
kargs
.
is_store_randval
;
}
BlockDropout
dropout
(
i_batch
,
i_nhead
,
kargs
.
num_head_q
,
drop_seed
,
drop_offset
,
rp_undrop
,
p_undrop_in_uint8_t
,
is_store_randval
);
auto
randval_dram_window
=
[
&
,
i_nhead_
=
i_nhead
]()
{
constexpr
auto
randval_dram_window_lengths
=
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN0
>
{});
if
constexpr
(
kHasDropout
)
{
RandValOutputDataType
*
rand_val_ptr
=
reinterpret_cast
<
RandValOutputDataType
*>
(
kargs
.
rand_val_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead_
)
*
kargs
.
nhead_stride_randval
+
batch_offset_randval
;
const
auto
randval_dram
=
[
&
]()
{
const
auto
randval_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
rand_val_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
seqlen_k
),
make_tuple
(
kargs
.
stride_randval
,
1
),
number
<
1
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
randval_dram_naive
,
randval_dram_window_lengths
,
sequence
<
kPadSeqLenQ
,
kPadSeqLenK
>
{});
}();
return
make_tile_window
(
randval_dram
,
randval_dram_window_lengths
,
{
i_m0
,
0
});
}
else
{
return
make_null_tile_window
(
randval_dram_window_lengths
);
}
}();
FmhaMask
mask
=
[
&
]()
{
if
constexpr
(
kHasMask
)
return
ck_tile
::
make_generic_attention_mask_from_lr_window
<
FmhaMask
>
(
...
...
@@ -823,16 +828,16 @@ struct FmhaFwdSplitKVKernel
#endif
if
constexpr
(
kHasMask
)
{
return
make_alibi_from_lr_mask
<
SaccDataType
,
true
>
(
slope
,
kargs
.
window_size_left
,
kargs
.
window_size_right
,
kargs
.
seqlen_q
,
kargs
.
seqlen_k
,
kargs
.
mask_type
);
return
make_alibi_from_lr_mask
<
SaccDataType
,
true
,
32
>
(
slope
,
kargs
.
window_size_left
,
kargs
.
window_size_right
,
kargs
.
seqlen_q
,
kargs
.
seqlen_k
,
kargs
.
mask_type
);
}
else
{
return
Alibi
<
SaccDataType
,
true
>
{
return
Alibi
<
SaccDataType
,
true
,
32
>
{
slope
,
kargs
.
seqlen_q
,
kargs
.
seqlen_k
,
AlibiMode
::
FROM_BOTTOM_RIGHT
};
}
}
...
...
@@ -847,13 +852,14 @@ struct FmhaFwdSplitKVKernel
{
return
FmhaPipeline
{}(
q_dram_window
,
identity
{},
// q_element_func
k_dram_window
,
k_dram_window_lengths
,
k_page_block_navigator
,
identity
{},
// k_element_func
v_dram_window
,
v_dram_window_lengths
,
v_page_block_navigator
,
identity
{},
// v_element_func
bias_dram_window
,
identity
{},
// bias_element_func
randval_dram_window
,
lse_acc_dram_window
,
identity
{},
// lse_element_func
identity
{},
// s_acc_element_func
...
...
@@ -864,24 +870,23 @@ struct FmhaFwdSplitKVKernel
mask
,
position_encoding
,
kargs
.
scale_s
,
smem_ptr
,
dropout
);
smem_ptr
);
}
else
{
return
FmhaPipeline
{}(
q_dram_window
,
k_dram_window
,
v_dram_window
,
k_dram_window_lengths
,
k_page_block_navigator
,
v_dram_window_lengths
,
v_page_block_navigator
,
bias_dram_window
,
randval_dram_window
,
lse_acc_dram_window
,
kargs
.
num_splits
,
i_split_
,
mask
,
position_encoding
,
kargs
.
scale_s
,
smem_ptr
,
dropout
);
smem_ptr
);
}
}();
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaFwdAppendKVPipelineDefaultPolicy
>
struct
BlockFmhaFwdAppendKVPipeline
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
QDataType
=
typename
Problem
::
QDataType
;
using
KDataType
=
typename
Problem
::
KDataType
;
using
VDataType
=
typename
Problem
::
VDataType
;
using
VLayout
=
typename
Problem
::
VLayout
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
Problem
::
kM0
;
static
constexpr
index_t
kN0
=
Problem
::
kN0
;
static
constexpr
index_t
kK0
=
Problem
::
kK0
;
static
constexpr
index_t
kN1
=
Problem
::
kN1
;
static
constexpr
auto
RotaryEnum
=
Problem
::
RotaryEnum
;
static
constexpr
bool
kIsPagedKV
=
Problem
::
kIsPagedKV
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static
constexpr
index_t
kAlignmentQ
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
return
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
else
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
kK0
<=
32
)
{
return
2
;
}
else
if
constexpr
(
kK0
<=
64
)
{
return
3
;
}
else
if
constexpr
(
kK0
<=
128
)
{
return
2
;
}
else
if
constexpr
(
kK0
<=
256
)
{
return
1
;
}
}
}();
template
<
typename
QDramBlockWindow
,
typename
KDramBlockWindow
,
typename
KPageBlockNavigator
,
typename
KnewDramBlockWindow
,
typename
VDramBlockWindow
,
typename
VPageBlockNavigator
,
typename
VnewDramBlockWindow
,
typename
QElementFunction
,
typename
KnewElementFunction
,
typename
VnewElementFunction
,
typename
QRotaryCosDramBlockWindow
,
typename
QRotarySinDramBlockWindow
,
typename
KnewRotaryCosDramBlockWindow
,
typename
KnewRotarySinDramBlockWindow
>
CK_TILE_HOST_DEVICE
auto
operator
()(
QDramBlockWindow
&
q_dram_block_window
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
KDramBlockWindow
&
k_dram_block_window
,
// N0*K0 tile
index_t
i_page_block_k
,
const
KPageBlockNavigator
&
k_page_block_navigator
,
const
KnewDramBlockWindow
&
knew_dram_block_window
,
// N0*K0 tile
const
KnewElementFunction
&
knew_element_func
,
VDramBlockWindow
&
v_dram_block_window
,
// N1*N0 tile
index_t
i_page_block_v
,
const
VPageBlockNavigator
&
v_page_block_navigator
,
const
VnewDramBlockWindow
&
vnew_dram_block_window
,
// N1*N0 tile
const
VnewElementFunction
&
vnew_element_func
,
const
QRotaryCosDramBlockWindow
q_rotary_cos_dram_block_window
,
const
QRotarySinDramBlockWindow
q_rotary_sin_dram_block_window
,
const
KnewRotaryCosDramBlockWindow
knew_rotary_cos_dram_block_window
,
const
KnewRotarySinDramBlockWindow
knew_rotary_sin_dram_block_window
,
index_t
rotary_dim
,
bool
skip_rotate_q
,
bool
skip_rotate_append_kv
)
const
{
if
(
!
skip_rotate_append_kv
)
{
// append Knew to K
auto
knew_window
=
make_tile_window
(
knew_dram_block_window
,
Policy
::
template
MakeKnewDramTileDistribution
<
Problem
>());
auto
knew_tile
=
[
&
]()
{
auto
knew
=
load_tile
(
knew_window
);
return
tile_elementwise_in
(
knew_element_func
,
knew
);
}();
// optionally apply rotary embedding to Knew
if
constexpr
(
RotaryEnum
!=
RotaryEmbeddingEnum
::
NONE
)
{
auto
rotary_cos_window
=
make_tile_window
(
knew_rotary_cos_dram_block_window
,
Policy
::
template
MakeRotaryCosSinTileDistribution
<
Problem
,
/*IsRotaryCosSinForQ=*/
false
>());
auto
rotary_sin_window
=
make_tile_window
(
knew_rotary_sin_dram_block_window
,
Policy
::
template
MakeRotaryCosSinTileDistribution
<
Problem
,
/*IsRotaryCosSinForQ=*/
false
>());
// We assume that each thread owns contiguous elements on head dimention. And we
// will use the distribution to enable/disable threads in order to override partial
// knew_tile content
auto
[
thread_start
,
thread_end
]
=
Policy
::
template
GetKnewThreadRangeAlongK
<
Problem
>();
ignore
=
thread_start
;
BlockRotaryEmbedding
<
RotaryEnum
>::
apply
(
knew_tile
,
knew_window
,
rotary_cos_window
,
rotary_sin_window
,
rotary_dim
,
thread_end
);
}
store_tile
(
k_dram_block_window
,
knew_tile
);
// write tile to another block if nesscary
if
constexpr
(
kIsPagedKV
)
{
if
(
k_page_block_navigator
.
is_cross_block
(
i_page_block_k
,
k_dram_block_window
))
{
k_page_block_navigator
.
move_to_block
(
i_page_block_k
,
k_dram_block_window
,
i_page_block_k
+
1
);
store_tile
(
k_dram_block_window
,
knew_tile
);
}
}
// append Vnew to V
auto
vnew_window
=
make_tile_window
(
vnew_dram_block_window
,
Policy
::
template
MakeVnewDramTileDistribution
<
Problem
>());
auto
vnew_tile
=
[
&
]()
{
auto
vnew
=
load_tile
(
vnew_window
);
return
tile_elementwise_in
(
vnew_element_func
,
vnew
);
}();
store_tile
(
v_dram_block_window
,
vnew_tile
);
// write tile to another block if nesscary
if
constexpr
(
kIsPagedKV
)
{
if
(
v_page_block_navigator
.
is_cross_block
(
i_page_block_v
,
v_dram_block_window
))
{
v_page_block_navigator
.
move_to_block
(
i_page_block_v
,
v_dram_block_window
,
i_page_block_v
+
1
);
store_tile
(
v_dram_block_window
,
vnew_tile
);
}
}
}
if
(
!
skip_rotate_q
)
{
// optionally apply rotary embedding to Q
if
constexpr
(
RotaryEnum
!=
RotaryEmbeddingEnum
::
NONE
)
{
auto
q_window
=
make_tile_window
(
q_dram_block_window
,
Policy
::
template
MakeQDramTileDistribution
<
Problem
>());
auto
q_tile
=
[
&
]()
{
auto
q
=
load_tile
(
q_window
);
return
tile_elementwise_in
(
q_element_func
,
q
);
}();
auto
rotary_cos_window
=
make_tile_window
(
q_rotary_cos_dram_block_window
,
Policy
::
template
MakeRotaryCosSinTileDistribution
<
Problem
,
/*IsRotaryCosSinForQ=*/
true
>());
auto
rotary_sin_window
=
make_tile_window
(
q_rotary_sin_dram_block_window
,
Policy
::
template
MakeRotaryCosSinTileDistribution
<
Problem
,
/*IsRotaryCosSinForQ=*/
true
>());
// We assume that each thread owns contiguous elements on head dimention. And we
// will use the distribution to enable/disable threads in order to override partial
// q_tile content
auto
[
thread_start
,
thread_end
]
=
Policy
::
template
GetQThreadRangeAlongK
<
Problem
>();
ignore
=
thread_start
;
BlockRotaryEmbedding
<
RotaryEnum
>::
apply
(
q_tile
,
q_window
,
rotary_cos_window
,
rotary_sin_window
,
rotary_dim
,
thread_end
);
store_tile
(
q_dram_block_window
,
q_tile
);
}
}
}
template
<
typename
QDramBlockWindow
,
typename
KDramBlockWindow
,
typename
KPageBlockNavigator
,
typename
KnewDramBlockWindow
,
typename
VDramBlockWindow
,
typename
VPageBlockNavigator
,
typename
VnewDramBlockWindow
,
typename
QRotaryCosDramBlockWindow
,
typename
QRotarySinDramBlockWindow
,
typename
KnewRotaryCosDramBlockWindow
,
typename
KnewRotarySinDramBlockWindow
>
CK_TILE_HOST_DEVICE
auto
operator
()(
QDramBlockWindow
&
q_dram_block_window
,
KDramBlockWindow
&
k_dram_block_window
,
index_t
i_page_block_k
,
const
KPageBlockNavigator
&
k_page_block_navigator
,
const
KnewDramBlockWindow
&
knew_dram_block_window
,
VDramBlockWindow
&
v_dram_block_window
,
index_t
i_page_block_v
,
const
VPageBlockNavigator
&
v_page_block_navigator
,
const
VnewDramBlockWindow
&
vnew_dram_block_window
,
const
QRotaryCosDramBlockWindow
&
q_rotary_cos_dram_block_window
,
const
QRotarySinDramBlockWindow
&
q_rotary_sin_dram_block_window
,
const
KnewRotaryCosDramBlockWindow
&
knew_rotary_cos_dram_block_window
,
const
KnewRotarySinDramBlockWindow
&
knew_rotary_sin_dram_block_window
,
index_t
rotary_dim
,
bool
skip_rotate_q
,
bool
skip_rotate_append_kv
)
const
{
return
operator
()(
q_dram_block_window
,
identity
{},
k_dram_block_window
,
i_page_block_k
,
k_page_block_navigator
,
knew_dram_block_window
,
identity
{},
v_dram_block_window
,
i_page_block_v
,
v_page_block_navigator
,
vnew_dram_block_window
,
identity
{},
q_rotary_cos_dram_block_window
,
q_rotary_sin_dram_block_window
,
knew_rotary_cos_dram_block_window
,
knew_rotary_sin_dram_block_window
,
rotary_dim
,
skip_rotate_q
,
skip_rotate_append_kv
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
struct
BlockFmhaFwdAppendKVPipelineDefaultPolicy
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentQ
()
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
return
16
/
sizeof
(
QDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentK
()
{
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
return
16
/
sizeof
(
KDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentV
()
{
using
VLayout
=
remove_cvref_t
<
typename
Problem
::
VLayout
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
kN1
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
// TODO: not correct!
if
constexpr
(
total_pixels
>
4
)
return
4
;
else
return
2
;
}
else
{
return
16
/
sizeof
(
VDataType
);
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQNumElemsPerRead
()
{
using
DataType
=
typename
Problem
::
QDataType
;
if
constexpr
(
Problem
::
RotaryEnum
==
RotaryEmbeddingEnum
::
HALF_ROTATED
)
{
/// NOTICE: we might need to lower down this to support smaller rotary_dim
return
16
/
sizeof
(
DataType
);
}
else
{
return
16
/
sizeof
(
DataType
);
}
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
auto
GetQThreadRangeAlongK
()
{
static_assert
(
Problem
::
RotaryEnum
!=
RotaryEmbeddingEnum
::
NONE
);
if
constexpr
(
Problem
::
RotaryEnum
==
RotaryEmbeddingEnum
::
INTERLEAVED
)
{
constexpr
index_t
KPerThread
=
GetQNumElemsPerRead
<
Problem
>
();
static_assert
(
Problem
::
kK0
%
KPerThread
==
0
);
constexpr
index_t
KThreadPerBlock
=
Problem
::
kK0
/
KPerThread
;
index_t
start_pos
=
(
get_thread_id
()
%
KThreadPerBlock
)
*
KPerThread
;
return
make_tuple
(
start_pos
,
start_pos
+
KPerThread
);
}
else
{
constexpr
index_t
KPerThread
=
GetQNumElemsPerRead
<
Problem
>
();
static_assert
(
Problem
::
kK0
%
KPerThread
==
0
);
constexpr
index_t
KThreadPerBlock
=
Problem
::
kK0
/
KPerThread
;
index_t
start_pos
=
(
get_thread_id
()
%
KThreadPerBlock
)
*
KPerThread
;
return
make_tuple
(
start_pos
,
start_pos
+
KPerThread
);
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQDramTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
kK0
;
constexpr
index_t
KPerThread
=
GetQNumElemsPerRead
<
Problem
>
();
constexpr
index_t
KThreadPerBlock
=
kKPerBlock
/
KPerThread
;
constexpr
index_t
MThreadPerWarp
=
get_warp_size
()
/
KThreadPerBlock
;
constexpr
index_t
NumWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
NumWarps
*
MThreadPerWarp
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
MPerThread
,
NumWarps
,
MThreadPerWarp
>
,
sequence
<
KThreadPerBlock
,
KPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKnewNumElemsPerRead
()
{
using
DataType
=
typename
Problem
::
KDataType
;
if
constexpr
(
Problem
::
RotaryEnum
==
RotaryEmbeddingEnum
::
HALF_ROTATED
)
{
/// NOTICE: we might need to lower down this to support smaller rotary_dim
return
16
/
sizeof
(
DataType
);
}
else
{
return
16
/
sizeof
(
DataType
);
}
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
auto
GetKnewThreadRangeAlongK
()
{
static_assert
(
Problem
::
RotaryEnum
!=
RotaryEmbeddingEnum
::
NONE
);
if
constexpr
(
Problem
::
RotaryEnum
==
RotaryEmbeddingEnum
::
INTERLEAVED
)
{
constexpr
index_t
KPerThread
=
GetKnewNumElemsPerRead
<
Problem
>
();
constexpr
index_t
KThreadPerBlock
=
Problem
::
kK0
/
KPerThread
;
index_t
start_pos
=
(
get_thread_id
()
%
KThreadPerBlock
)
*
KPerThread
;
return
make_tuple
(
start_pos
,
start_pos
+
KPerThread
);
}
else
{
constexpr
index_t
KPerThread
=
GetKnewNumElemsPerRead
<
Problem
>
();
constexpr
index_t
KThreadPerBlock
=
Problem
::
kK0
/
KPerThread
;
index_t
start_pos
=
(
get_thread_id
()
%
KThreadPerBlock
)
*
KPerThread
;
return
make_tuple
(
start_pos
,
start_pos
+
KPerThread
);
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKnewDramTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
kK0
;
constexpr
index_t
KPerThread
=
GetKnewNumElemsPerRead
<
Problem
>
();
constexpr
index_t
KThreadPerBlock
=
kKPerBlock
/
KPerThread
;
constexpr
index_t
NThreadPerWarp
=
get_warp_size
()
/
KThreadPerBlock
;
constexpr
index_t
NumWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
NPerThread
=
kNPerBlock
/
(
NumWarps
*
NThreadPerWarp
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
NPerThread
,
NumWarps
,
NThreadPerWarp
>
,
sequence
<
KThreadPerBlock
,
KPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackV
()
{
// TODO: this is for 3d layout
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
return
16
/
sizeof
(
VDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVnewDramTileDistribution
()
{
using
VLayout
=
remove_cvref_t
<
typename
Problem
::
VLayout
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
kN1
;
constexpr
index_t
kKPerBlock
=
Problem
::
kN0
;
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
constexpr
index_t
NPerThread
=
16
/
sizeof
(
VDataType
);
constexpr
index_t
NThreadPerBlock
=
kNPerBlock
/
NPerThread
;
constexpr
index_t
KThreadPerWarp
=
get_warp_size
()
/
NThreadPerBlock
;
constexpr
index_t
NumWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
KPerThread
=
kKPerBlock
/
(
NumWarps
*
KThreadPerWarp
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
NThreadPerBlock
,
NPerThread
>
,
sequence
<
KPerThread
,
NumWarps
,
KThreadPerWarp
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
0
>>
{});
}
else
{
constexpr
index_t
KPerThread
=
16
/
sizeof
(
VDataType
);
constexpr
index_t
KThreadPerBlock
=
kKPerBlock
/
KPerThread
;
constexpr
index_t
NThreadPerWarp
=
get_warp_size
()
/
KThreadPerBlock
;
constexpr
index_t
NumWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
NPerThread
=
kNPerBlock
/
(
NumWarps
*
NThreadPerWarp
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
NPerThread
,
NumWarps
,
NThreadPerWarp
>
,
sequence
<
KThreadPerBlock
,
KPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
}
template
<
typename
Problem
,
bool
IsRotaryCosSinForQ
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetRotaryCosSinTileSize
()
{
constexpr
index_t
height
=
(
IsRotaryCosSinForQ
?
Problem
::
kM0
:
Problem
::
kN0
);
if
constexpr
(
Problem
::
RotaryEnum
==
RotaryEmbeddingEnum
::
HALF_ROTATED
)
{
return
make_tuple
(
number
<
height
>
{},
number
<
Problem
::
kK0
>
{});
}
else
{
return
make_tuple
(
number
<
height
>
{},
number
<
Problem
::
kK0
/
2
>
{});
}
}
template
<
typename
Problem
,
bool
IsRotaryCosSinForQ
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeRotaryCosSinTileDistribution
()
{
using
DataType
=
std
::
conditional_t
<
IsRotaryCosSinForQ
,
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
>
;
constexpr
auto
TileSize
=
GetRotaryCosSinTileSize
<
Problem
,
IsRotaryCosSinForQ
>
();
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
TileSize
[
number
<
0
>
{}];
constexpr
index_t
kKPerBlock
=
TileSize
[
number
<
1
>
{}];
constexpr
index_t
KPerThread
=
[]()
{
if
constexpr
(
Problem
::
RotaryEnum
==
RotaryEmbeddingEnum
::
HALF_ROTATED
)
{
/// NOTICE: we might need to lower down this to support smaller rotary_dim
return
16
/
sizeof
(
DataType
);
}
else
{
return
8
/
sizeof
(
DataType
);
}
}();
constexpr
index_t
KThreadPerBlock
=
kKPerBlock
/
KPerThread
;
constexpr
index_t
NThreadPerWarp
=
get_warp_size
()
/
KThreadPerBlock
;
constexpr
index_t
NumWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
NPerThread
=
kNPerBlock
/
(
NumWarps
*
NThreadPerWarp
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
NPerThread
,
NumWarps
,
NThreadPerWarp
>
,
sequence
<
KThreadPerBlock
,
KPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
View file @
4885c38a
...
...
@@ -6,7 +6,6 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
...
...
@@ -15,19 +14,18 @@ namespace ck_tile {
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
>
struct
BlockFmhaFwdSplitKVPipelineQRKSVS
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
RandValOutputDataType
=
remove_cvref_t
<
typename
Problem
::
RandValOutputDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
...
...
@@ -49,8 +47,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
true
;
// always store LSE (acc)
static
constexpr
bool
k
HasDropout
=
false
;
// ignore this flag
static
constexpr
bool
kStoreLSE
=
true
;
// always store LSE (acc)
static
constexpr
bool
k
IsPagedKV
=
Problem
::
kIsPagedKV
;
static
constexpr
bool
kHasUnevenSplits
=
Problem
::
kHasUnevenSplits
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
...
...
@@ -106,10 +104,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
KDramBlockWindowLengths
,
typename
KPageBlockNavigator
,
typename
VDramBlockWindowLengths
,
typename
VPageBlockNavigator
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEaccDramBlockWindowTmp
,
typename
QElementFunction
,
typename
KElementFunction
,
...
...
@@ -123,13 +122,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
KDramBlockWindowLengths
&
k_dram_block_window_lengths
,
// N0*K0 tile
const
KPageBlockNavigator
&
k_page_block_navigator
,
const
KElementFunction
&
k_element_func
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
VDramBlockWindowLengths
&
v_dram_block_window_lengths
,
// N1*K1 tile
const
VPageBlockNavigator
&
v_page_block_navigator
,
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasElementFunction
&
bias_element_func
,
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
LSEaccDramBlockWindowTmp
&
lse_acc_dram_window_tmp
,
// M0*1 tile
const
LSEaccElementFunction
&
lse_acc_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
...
...
@@ -140,20 +140,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
void
*
smem_ptr
,
BlockDropout
&
dropout
)
const
void
*
smem_ptr
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
K
Dram
Block
WindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
V
Dram
Block
WindowTmp
::
DataType
>>
,
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
K
Page
Block
Navigator
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
V
Page
Block
Navigator
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindow
Tmp
{}.
get_window_l
engths
()
[
number
<
0
>
{}]
&&
kK0
==
KDramBlockWindow
Tmp
{}.
get_window_l
engths
()
[
number
<
1
>
{}]
&&
kN1
==
VDramBlockWindow
Tmp
{}.
get_window_l
engths
()
[
number
<
0
>
{}]
&&
kK1
==
VDramBlockWindow
Tmp
{}.
get_window_l
engths
()
[
number
<
1
>
{}]
&&
kN0
==
KDramBlockWindow
L
engths
{}
[
number
<
0
>
{}]
&&
kK0
==
KDramBlockWindow
L
engths
{}
[
number
<
1
>
{}]
&&
kN1
==
VDramBlockWindow
L
engths
{}
[
number
<
0
>
{}]
&&
kK1
==
VDramBlockWindow
L
engths
{}
[
number
<
1
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
...
...
@@ -213,12 +212,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{},
num_splits
,
i_split
);
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
// check early exit if masked and no work to do.
if
constexpr
(
FmhaMask
::
IsMasking
||
kHasUnevenSplits
)
{
if
(
num_total_loop
<=
0
)
const
index_t
original_num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
if
(
original_num_total_loop
<=
0
)
{
if
constexpr
(
kStoreLSE
)
{
...
...
@@ -237,26 +236,34 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
}
auto
k_dram_block_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_k_start
,
0
});
// make sure the first tile is completely located in page-block
const
index_t
adjusted_seqlen_k_start
=
[
&
,
seqlen_k_start_
=
seqlen_k_start
]
{
if
constexpr
(
kIsPagedKV
)
{
return
kN0
*
integer_divide_floor
(
seqlen_k_start_
,
kN0
);
}
else
{
return
seqlen_k_start_
;
}
}();
const
index_t
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
adjusted_seqlen_k_start
,
kN0
);
auto
[
i_page_block_k
,
k_dram_block_window
]
=
k_page_block_navigator
.
make_tile_window
(
k_dram_block_window_lengths
,
{
adjusted_seqlen_k_start
,
0
});
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
{
bias_origin
.
at
(
number
<
0
>
{}),
adjusted_
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
MakeRandvalDramWindow
<
decltype
(
gemm_0
)
>
(
randval_dram_block_window_tmp
,
seqlen_k_start
);
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
auto
[
i_page_block_v
,
v_dram_window
]
=
v_page_block_navigator
.
make_tile_window
(
v_dram_block_window_lengths
,
{
0
,
adjusted_seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
auto
q_tile
=
tile_elementwise_in
(
q_element_func
,
q
);
...
...
@@ -271,14 +278,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{
// STAGE 1, QK gemm
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
k_dram_block_window
,
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
auto
k_block_tile
=
load_tile
(
k_dram_window
);
{
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
clear_tile
(
s_acc
);
// initialize C
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
...
...
@@ -355,7 +362,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
const
auto
k_origin
=
k_page_block_navigator
.
to_global_window_origin
(
i_page_block_k
,
k_dram_block_window
.
get_window_origin
());
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
sweep_tile_span
(
s_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
...
...
@@ -381,22 +389,32 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
move_tile_window
(
bias_dram_window
,
{
0
,
kN0
});
/// TODO: only check in last iteration without increasing code size
/// TODO: only check in
first/
last iteration without increasing code size
if
constexpr
(
kHasUnevenSplits
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
const
auto
k_origin
=
k_page_block_navigator
.
to_global_window_origin
(
i_page_block_k
,
k_dram_block_window
.
get_window_origin
());
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
,
seqlen_k_end_
=
seqlen_k_end
](
auto
tile_idx
)
{
[
&
,
seqlen_k_start_
=
seqlen_k_start
,
seqlen_k_end_
=
seqlen_k_end
](
auto
tile_idx
)
{
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
seqlen_k_end_
<=
col
;
if
constexpr
(
kIsPagedKV
)
{
return
col
<
seqlen_k_start_
||
seqlen_k_end_
<=
col
;
}
else
{
return
seqlen_k_end_
<=
col
;
}
});
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
const
auto
k_origin
=
k_page_block_navigator
.
to_global_window_origin
(
i_page_block_k
,
k_dram_block_window
.
get_window_origin
());
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
...
...
@@ -501,12 +519,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
});
});
if
constexpr
(
kHasDropout
)
{
dropout
.
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>
(
smem_ptr
,
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
randval_dram_window
);
}
block_sync_lds
();
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
...
...
@@ -522,7 +534,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v_prefetch
));
// store the prefetch
}
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
i_page_block_v
=
v_page_block_navigator
.
move_tile_window
(
i_page_block_v
,
v_dram_window
,
{
0
,
kK1
});
const
auto
p
=
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
...
...
@@ -530,8 +543,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
// STAGE 3, KV gemm
if
constexpr
(
k1_loops
>
1
)
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
const
auto
v
=
load_tile
(
v_dram_window
);
// load next v
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
,
&
i_page_block_v_
=
i_page_block_v
,
&
v_dram_window_
=
v_dram_window
](
auto
i_k1
)
{
const
auto
v
=
load_tile
(
v_dram_window_
);
// load next v
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
...
...
@@ -552,11 +567,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v
));
// store next v
}
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
i_page_block_v_
=
v_page_block_navigator
.
move_tile_window
(
i_page_block_v_
,
v_dram_window_
,
{
0
,
kK1
});
});
}
// move K tile windows
move_tile_window
(
k_dram_block_window
,
{
kN0
,
0
});
i_page_block_k
=
k_page_block_navigator
.
move_tile_window
(
i_page_block_k
,
k_dram_block_window
,
{
kN0
,
0
});
// tail
{
block_sync_lds
();
...
...
@@ -618,36 +635,38 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
KDramBlockWindowLengths
,
typename
KPageBlockNavigator
,
typename
VDramBlockWindowLengths
,
typename
VPageBlockNavigator
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEaccDramBlockWindowTmp
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
KDramBlockWindowLengths
&
k_dram_block_window_lengths
,
// N0*K0 tile
const
KPageBlockNavigator
&
k_page_block_navigator
,
const
VDramBlockWindowLengths
&
v_dram_block_window_lengths
,
// N1*K1 tile
const
VPageBlockNavigator
&
v_page_block_navigator
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
// M0*N0 tile
LSEaccDramBlockWindowTmp
&
lse_acc_dram_block_window_tmp
,
// M0*1 tile
index_t
num_splits
,
index_t
i_split
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
void
*
smem_ptr
,
BlockDropout
&
dropout
)
const
void
*
smem_ptr
)
const
{
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
k_dram_block_window_tmp
,
k_dram_block_window_lengths
,
k_page_block_navigator
,
identity
{},
v_dram_block_window_tmp
,
v_dram_block_window_lengths
,
v_page_block_navigator
,
identity
{},
bias_dram_block_window_tmp
,
identity
{},
randval_dram_block_window_tmp
,
lse_acc_dram_block_window_tmp
,
identity
{},
identity
{},
...
...
@@ -658,8 +677,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
mask
,
position_encoding
,
scale_s
,
smem_ptr
,
dropout
);
smem_ptr
);
}
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
deleted
100644 → 0
View file @
cbf14ee1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy
>
struct
BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
RandValOutputDataType
=
remove_cvref_t
<
typename
Problem
::
RandValOutputDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kK0BlockLength
=
BlockFmhaShape
::
kK0BlockLength
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
static_assert
(
Problem
::
kPadSeqLenQ
==
true
&&
Problem
::
kPadHeadDimQ
==
true
&&
Problem
::
kPadHeadDimV
==
true
);
static
constexpr
bool
kPadSeqLenQ
=
true
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
true
;
// support multiple of vector(like 8x)
static
constexpr
bool
kPadHeadDimV
=
true
;
// support multiple of vector(like 8x)
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
true
;
// always store LSE (acc)
static
constexpr
bool
kHasDropout
=
false
;
// ignore this flag
static
constexpr
bool
kHasUnevenSplits
=
Problem
::
kHasUnevenSplits
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static
constexpr
index_t
kAlignmentQ
=
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
return
Policy
::
template
GetAlignmentV
<
Problem
>();
else
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kAlignmentO
=
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
#if CK_TILE_FMHA_FWD_FAST_EXP2
static
constexpr
auto
R_LOG2E
=
1.0
/
log2e_v
<
SaccDataType
>
;
#endif
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
kK0BlockLength
<=
32
)
{
if
constexpr
(
kPadSeqLenK
&&
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
&&
FmhaMask
::
IsMasking
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
64
)
{
if
constexpr
(
kPadSeqLenK
&&
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
2
;
else
return
3
;
}
else
if
constexpr
(
kK0BlockLength
<=
128
)
{
if
constexpr
(
kPadSeqLenK
&&
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
256
)
{
return
1
;
}
}
}();
static
constexpr
const
char
*
name
=
"qr_async"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEaccDramBlockWindowTmp
,
typename
QElementFunction
,
typename
KElementFunction
,
typename
VElementFunction
,
typename
BiasElementFunction
,
typename
LSEaccElementFunction
,
typename
SAccElementFunction
,
typename
PComputeElementFunction
,
typename
OAccElementFunction
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
KElementFunction
&
/*k_element_func*/
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasElementFunction
&
bias_element_func
,
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
LSEaccDramBlockWindowTmp
&
lse_acc_dram_window_tmp
,
// M0*1 tile
const
LSEaccElementFunction
&
lse_acc_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
PComputeElementFunction
&
p_compute_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
index_t
num_splits
,
index_t
i_split
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
void
*
smem_ptr
,
BlockDropout
&
dropout
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kN1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
constexpr
auto
LdsSeq
=
Policy
::
template
GetLdsBufferSequence
<
Problem
>();
// K tile in LDS
auto
k_lds_ptr
=
reinterpret_cast
<
KDataType
*>
(
smem_ptr
);
auto
k_lds_store
=
generate_tuple
(
[
&
](
auto
i_buf
)
{
return
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsStoreBlockDescriptor
<
Problem
>(
i_buf
)),
Policy
::
template
MakeKLdsStoreBlockDescriptor
<
Problem
>(
i_buf
).
get_lengths
(),
{
0
,
0
,
0
});
},
number
<
Policy
::
NumPrefetchK
>
{});
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
auto
k_lds_load
=
generate_tuple
(
[
&
](
auto
i_buf
)
{
return
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>(
i_buf
)),
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>(
i_buf
).
get_lengths
(),
{
0
,
0
});
},
number
<
Policy
::
NumPrefetchK
>
{});
#else
auto
k_lds_Load_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>());
auto
k_lds_load
=
make_tile_window
(
k_lds_Load_view
,
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
#endif
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
VDataType
*>
(
smem_ptr
),
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>());
auto
v_lds_window
=
make_tile_window
(
v_lds
,
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
auto
q
=
decltype
(
load_tile
(
q_dram_window
)){};
set_tile
(
q
,
number
<
0
>
{});
// use per-dword clear to avoid scratch
load_tile_raw
(
q
,
q_dram_window
);
__builtin_amdgcn_sched_barrier
(
0
);
using
SaccBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
auto
s_acc
=
SaccBlockTileType
{};
// reduction function for softmax
const
auto
f_max
=
[](
auto
e0
,
auto
e1
)
{
return
max
(
e0
,
e1
);
};
const
auto
f_sum
=
[](
auto
e0
,
auto
e1
)
{
return
e0
+
e1
;
};
// infer Sacc, S, P, M, L, Oacc type
using
SBlockTileType
=
decltype
(
cast_tile
<
SMPLComputeDataType
>
(
s_acc
));
using
MLBlockTileType
=
decltype
(
block_tile_reduce
<
SMPLComputeDataType
>
(
SBlockTileType
{},
sequence
<
1
>
{},
f_max
,
SMPLComputeDataType
{
0
}));
using
OaccBlockTileType
=
decltype
(
gemm_1
.
MakeCBlockTile
());
// init Oacc, M, L
auto
o_acc
=
OaccBlockTileType
{};
auto
m
=
MLBlockTileType
{};
auto
l
=
MLBlockTileType
{};
clear_tile
(
o_acc
);
set_tile
(
m
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
clear_tile
(
l
);
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
q_origin
=
q_dram_window
.
get_window_origin
();
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{},
num_splits
,
i_split
);
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
// check early exit if masked and no work to do.
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
||
kHasUnevenSplits
)
{
if
(
num_total_loop
<=
0
)
{
if
constexpr
(
kStoreLSE
)
{
auto
lse_acc
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
set_tile
(
lse_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
store_tile
(
lse_acc_dram_window_tmp
,
tile_elementwise_in
(
lse_acc_element_func
,
lse_acc
));
}
buffer_load_fence
(
0
);
// rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleard, return it
return
o_acc
;
}
__builtin_amdgcn_sched_barrier
(
0
);
// make sure sched_barrier(0) for this check
}
auto
k_dram_block_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_k_start
,
0
});
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
MakeRandvalDramWindow
<
decltype
(
gemm_0
)
>
(
randval_dram_block_window_tmp
,
seqlen_k_start
);
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
// prefetch K tile
async_load_tile_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})),
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_fence
(
k_dram_window
.
get_num_access
(),
q
.
get_thread_buffer
());
(
void
)
q_element_func
;
// ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kK0BlockLength
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
1
<=
k0_loops
);
static_assert
(
1
<=
k1_loops
);
// main loop
do
{
// STAGE 1, QK gemm
clear_tile
(
s_acc
);
// initialize C
if
constexpr
(
k0_loops
>
1
)
{
static_for
<
0
,
k0_loops
-
1
,
1
>
{}([
&
](
auto
i_k0
)
{
async_load_tile_raw
(
k_lds_store
(
number
<
LdsSeq
.
at
(
number
<
i_k0
+
1
>
{})
>
{}),
k_dram_window
);
if
constexpr
(
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
async_load_fence
(
k_dram_window
.
get_num_access
());
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
gemm_0
(
s_acc
,
get_slice_tile
(
q
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load
[
number
<
LdsSeq
.
at
(
number
<
i_k0
>
{})
>
{}]);
#else
get_slice_tile
(
k_lds_load
,
sequence
<
(
LdsSeq
.
at
(
number
<
i_k0
>
{}))
*
kN0
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
i_k0
>
{})
+
1
)
*
kN0
,
kK0
>
{}));
#endif
});
}
// TODO: this to fix a bug when loop smaller than 2,
// the following fence/barrier will be scheduled inside 1st loop
if
constexpr
(
k0_loops
<=
2
)
__builtin_amdgcn_sched_barrier
(
0
);
async_load_fence
();
__builtin_amdgcn_s_barrier
();
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
auto
v_buf
=
load_tile
(
v_dram_window
,
bool_constant
<
false
>
{});
__builtin_amdgcn_sched_barrier
(
0
);
{
// tail
gemm_0
(
s_acc
,
get_slice_tile
(
q
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load
[
number
<
LdsSeq
.
at
(
number
<
k0_loops
-
1
>
{})
>
{}]);
#else
get_slice_tile
(
k_lds_load
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
-
1
>
{}))
*
kN0
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
-
1
>
{})
+
1
)
*
kN0
,
kK0
>
{}));
#endif
}
__builtin_amdgcn_sched_barrier
(
1
);
// STAGE 2, scale_s, add bias, mask, softmax
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x
+=
type_convert
<
SaccDataType
>
(
bias_element_func
(
y
));
#else
x
+=
log2e_v
<
SaccDataType
>
*
type_convert
<
SaccDataType
>
(
bias_element_func
(
y
));
#endif
},
s_acc
,
bias_tile
);
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
sweep_tile_span
(
s_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
s_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
s_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
s_acc
(
i_j_idx
)
*=
scale_s
;
position_encoding
.
update
(
s_acc
(
i_j_idx
),
row
,
col
);
});
});
}
else
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
#endif
}
move_tile_window
(
bias_dram_window
,
{
0
,
kN0
});
/// TODO: only check in last iteration without increasing code size
if
constexpr
(
kHasUnevenSplits
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
,
seqlen_k_end_
=
seqlen_k_end
](
auto
tile_idx
)
{
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
seqlen_k_end_
<=
col
;
});
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
}
const
auto
s
=
cast_tile
<
SMPLComputeDataType
>
(
s_acc
);
// S{j}
auto
m_local
=
block_tile_reduce
<
SMPLComputeDataType
>
(
s
,
sequence
<
1
>
{},
f_max
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
// m_local = rowmax(S{j})
block_tile_reduce_sync
(
m_local
,
f_max
,
bool_constant
<
false
>
{});
const
auto
m_old
=
m
;
// m{j-1}
tile_elementwise_inout
(
[](
auto
&
e0
,
auto
e1
,
auto
e2
)
{
e0
=
max
(
e1
,
e2
);
},
m
,
m_old
,
m_local
);
// m{j}
auto
p_compute
=
make_static_distributed_tensor
<
SMPLComputeDataType
>
(
s
.
get_tile_distribution
());
// Pcompute{j}
__builtin_amdgcn_sched_barrier
(
0x7F
);
// store & prefetch next v, after the max reduction
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_buf
);
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
>
{})
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
>
{})
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_buf
));
// store the prefetch
}
if
constexpr
(
k1_loops
>
1
)
{
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
// will have scratch if move this right after load_tile(v_dram)...
v_buf
=
load_tile
(
v_dram_window
,
bool_constant
<
false
>
{});
// load next v_buf
}
__builtin_amdgcn_sched_barrier
(
0
);
static
const
auto
get_validated_m
=
[](
SMPLComputeDataType
raw_m
)
{
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration. alibi does not have this problem
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
{
return
raw_m
==
-
numeric
<
SMPLComputeDataType
>::
infinity
()
?
type_convert
<
SMPLComputeDataType
>
(
0.
f
)
:
raw_m
;
}
else
{
return
raw_m
;
}
};
constexpr
auto
p_spans
=
decltype
(
p_compute
)
::
get_distributed_spans
();
sweep_tile_span
(
p_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
#endif
sweep_tile_span
(
p_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
p_compute
(
i_j_idx
)
=
exp2
(
s
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
p_compute
(
i_j_idx
)
=
exp2
(
scale_s
*
s
[
i_j_idx
]
-
row_max
);
}
#else
p_compute
(
i_j_idx
)
=
exp
(
s
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
});
});
auto
rowsum_p
=
block_tile_reduce
<
SMPLComputeDataType
>
(
p_compute
,
sequence
<
1
>
{},
f_sum
,
SMPLComputeDataType
{
0
});
// rowsum(Pcompute{j})
block_tile_reduce_sync
(
rowsum_p
,
f_sum
,
bool_constant
<
false
>
{});
// l{j}, Oacc{j}
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
return
exp2
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
return
exp2
(
scale_s
*
m_old
[
i_idx
]
-
row_max
);
}
}();
#else
const
auto
tmp
=
exp
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
l
(
i_idx
)
=
tmp
*
l
[
i_idx
]
+
rowsum_p
[
i_idx
];
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
if
constexpr
(
kHasDropout
)
{
auto
randval_ptr
=
reinterpret_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeKV
<
Problem
>();
dropout
.
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>
(
randval_ptr
,
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
randval_dram_window
);
}
const
auto
p
=
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
// STAGE 3, KV gemm
if
constexpr
(
k1_loops
>
1
)
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
if
constexpr
(
i_k1
!=
0
&&
i_k1
<
k1_loops
-
1
)
{
v_buf
=
load_tile
(
v_dram_window
,
bool_constant
<
false
>
{});
// load next v_buf
}
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
i_k1
*
kK1
>
{},
sequence
<
kM0
,
(
i_k1
+
1
)
*
kK1
>
{}),
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
>
{})
+
1
)
*
kN1
,
kK1
>
{}));
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_buf
);
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
+
1
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
+
1
>
{})
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
+
1
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
+
1
>
{})
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_buf
));
// store next v_buf
}
if
constexpr
(
i_k1
<
k1_loops
-
1
)
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
});
}
i_total_loops
++
;
if
(
i_total_loops
<
num_total_loop
)
{
// move K tile windows
move_tile_window
(
k_dram_block_window
,
{
kN0
,
0
});
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
if
constexpr
(
k1_loops
>=
2
&&
LdsSeq
.
at
(
number
<
0
>
{})
==
LdsSeq
.
at
(
number
<
k0_loops
+
k1_loops
-
2
>
{}))
__builtin_amdgcn_s_barrier
();
async_load_tile_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})),
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
}
// tail
{
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
(
k1_loops
-
1
)
*
kK1
>
{},
sequence
<
kM0
,
kN0
>
{}),
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
k1_loops
-
1
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
k1_loops
-
1
>
{})
+
1
)
*
kN1
,
kK1
>
{}));
}
}
while
(
i_total_loops
<
num_total_loop
);
// store lse acc
if
constexpr
(
kStoreLSE
)
{
auto
lse_acc
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
constexpr
auto
lse_acc_spans
=
decltype
(
lse_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
lse_acc_spans
[
number
<
0
>
{}],
[
&
,
m_
=
m
,
l_
=
l
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
lse_acc
(
i_idx
)
=
m_
[
i_idx
]
*
R_LOG2E
+
log
(
l_
[
i_idx
]);
}
else
{
lse_acc
(
i_idx
)
=
m_
[
i_idx
]
*
scale_s
*
R_LOG2E
+
log
(
l_
[
i_idx
]);
}
#else
lse_acc
(
i_idx
)
=
m_
[
i_idx
]
+
log
(
l_
[
i_idx
]);
#endif
});
store_tile
(
lse_acc_dram_window_tmp
,
tile_elementwise_in
(
lse_acc_element_func
,
lse_acc
));
}
// finally, O
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
FmhaMask
::
IsMasking
)
{
return
l
[
i_idx
]
==
0.
f
?
0.
f
:
1
/
l
[
i_idx
];
}
else
return
1
/
l
[
i_idx
];
}();
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
return
o_acc
;
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEaccDramBlockWindowTmp
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
// M0*N0 tile
LSEaccDramBlockWindowTmp
&
lse_acc_dram_block_window_tmp
,
// M0*1 tile
index_t
num_splits
,
index_t
i_split
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
void
*
smem_ptr
,
BlockDropout
&
dropout
)
const
{
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
k_dram_block_window_tmp
,
identity
{},
v_dram_block_window_tmp
,
identity
{},
bias_dram_block_window_tmp
,
identity
{},
randval_dram_block_window_tmp
,
lse_acc_dram_block_window_tmp
,
identity
{},
identity
{},
identity
{},
identity
{},
num_splits
,
i_split
,
mask
,
position_encoding
,
scale_s
,
smem_ptr
,
dropout
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp
deleted
100644 → 0
View file @
cbf14ee1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
using
BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy
=
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopyK = */
true
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
3
,
/* NumPrefetchV = */
3
>
;
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
View file @
4885c38a
...
...
@@ -54,38 +54,50 @@ struct BlockFmhaPipelineProblem
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
template
<
typename
QDataType
,
typename
KDataType
,
typename
VDataType
,
typename
SaccDataType
,
typename
SMPLComputeDataType
,
typename
BiasDataType
,
typename
RandValOutputDataType
,
typename
LSEDataType
,
typename
PDataType
,
typename
OaccDataType
,
typename
ODataType
,
typename
BlockFmhaShape
,
bool
kIsGroupMode
,
typename
FmhaMask
,
typename
Traits
>
struct
BlockFmhaFwdSplitKVPipelineProblem
:
BlockFmhaPipelineProblem
<
QDataType
,
KDataType
,
VDataType
,
SaccDataType
,
SMPLComputeDataType
,
BiasDataType
,
RandValOutputDataType
,
LSEDataType
,
PDataType
,
OaccDataType
,
ODataType
,
BlockFmhaShape
,
kIsGroupMode
,
FmhaMask
,
Traits
>
template
<
typename
QDataType_
,
typename
KDataType_
,
typename
VDataType_
,
typename
SaccDataType_
,
typename
SMPLComputeDataType_
,
typename
BiasDataType_
,
typename
LSEDataType_
,
typename
PDataType_
,
typename
OaccDataType_
,
typename
ODataType_
,
typename
BlockFmhaShape_
,
bool
kIsGroupMode_
,
typename
FmhaMask_
,
typename
Traits_
>
struct
BlockFmhaFwdSplitKVPipelineProblem
{
static
constexpr
bool
kHasUnevenSplits
=
kIsGroupMode
||
Traits
::
kHasUnevenSplits
;
using
QDataType
=
remove_cvref_t
<
QDataType_
>
;
using
KDataType
=
remove_cvref_t
<
KDataType_
>
;
using
VDataType
=
remove_cvref_t
<
VDataType_
>
;
using
SaccDataType
=
remove_cvref_t
<
SaccDataType_
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
SMPLComputeDataType_
>
;
using
BiasDataType
=
remove_cvref_t
<
BiasDataType_
>
;
using
LSEDataType
=
remove_cvref_t
<
LSEDataType_
>
;
using
PDataType
=
remove_cvref_t
<
PDataType_
>
;
using
OaccDataType
=
remove_cvref_t
<
OaccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
BlockFmhaShape_
>
;
using
FmhaMask
=
remove_cvref_t
<
FmhaMask_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kBlockSize
=
BlockFmhaShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
// attributes from traits
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Traits
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Traits
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Traits
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Traits
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Traits
::
kStoreLSE
;
static
constexpr
bool
kDoFp8StaticQuant
=
Traits
::
kDoFp8StaticQuant
;
static
constexpr
bool
kIsPagedKV
=
Traits
::
kIsPagedKV
;
static
constexpr
bool
kHasUnevenSplits
=
kIsGroupMode
||
Traits
::
kHasUnevenSplits
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
template
<
typename
LSEDataType_
,
...
...
@@ -119,4 +131,44 @@ struct BlockFmhaSplitKVCombinePipelineProblem
static
constexpr
index_t
kMaxSplits
=
Traits
::
kMaxSplits
;
};
template
<
typename
QDataType_
,
typename
KDataType_
,
typename
VDataType_
,
index_t
kM0_
,
index_t
kN0_
,
index_t
kK0_
,
index_t
kN1_
,
bool
kIsVLayoutRowMajor_
,
RotaryEmbeddingEnum
RotaryEnum_
,
bool
kIsPagedKV_
,
typename
Traits_
>
struct
BlockFmhaFwdAppendKVPipelineProblem
{
using
QDataType
=
remove_cvref_t
<
QDataType_
>
;
using
KDataType
=
remove_cvref_t
<
KDataType_
>
;
using
VDataType
=
remove_cvref_t
<
VDataType_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kBlockSize
=
256
;
static
constexpr
index_t
kM0
=
kM0_
;
static
constexpr
index_t
kN0
=
kN0_
;
static
constexpr
index_t
kK0
=
kK0_
;
static
constexpr
index_t
kN1
=
kN1_
;
using
VLayout
=
std
::
conditional_t
<
kIsVLayoutRowMajor_
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
static
constexpr
auto
RotaryEnum
=
RotaryEnum_
;
static
constexpr
bool
kIsPagedKV
=
kIsPagedKV_
;
// attributes from traits
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Traits
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Traits
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Traits
::
kPadHeadDimV
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
}
// namespace ck_tile
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment