gaoqiong / composable_kernel_ROCM · Commits

Commit 72c9f129, authored Sep 20, 2024 by Jun Liu

    Merge branch 'amd-develop' into 'amd-master'

Parents: 241c261f, ded0d83d

Showing 20 changed files (of 235 changes) with 2293 additions and 445 deletions (+2293 -445).
include/ck_tile/core/numeric/vector_type.hpp                                     +9   -0
include/ck_tile/core/tensor/tile_distribution.hpp                                +3   -2
include/ck_tile/core/tensor/tile_window.hpp                                      +52  -1
include/ck_tile/core/utility/philox_rand.hpp                                     +33  -0
include/ck_tile/core/utility/type_traits.hpp                                     +17  -0
include/ck_tile/host.hpp                                                         +1   -0
include/ck_tile/host/host_tensor.hpp                                             +9   -0
include/ck_tile/host/kernel_launch.hpp                                           +5   -5
include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp   +73  -0
include/ck_tile/ops/fmha.hpp                                                     +9   -10
include/ck_tile/ops/fmha/block/block_dropout.hpp                                 +377 -16
include/ck_tile/ops/fmha/block/block_position_encoding.hpp                       +19  -3
include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp                        +108 -0
include/ck_tile/ops/fmha/block/page_block_navigator.hpp                          +279 -0
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp                              +555 -332
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp                    +0   -54
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp                     +679 -0
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp           +42  -0
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp                              +2   -4
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp              +21  -18
include/ck_tile/core/numeric/vector_type.hpp

@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
 using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
 using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
 
+// u32
+// using uint32_t = ...
+using uint32x2_t  = uint32_t __attribute__((ext_vector_type(2)));
+using uint32x4_t  = uint32_t __attribute__((ext_vector_type(4)));
+using uint32x8_t  = uint32_t __attribute__((ext_vector_type(8)));
+using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
+using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
+using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
+
 // i16
 // using int16_t = ...
 using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
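The aliases in this hunk use clang's ext_vector_type extension, which gives element-wise arithmetic and subscripting without intrinsics. A minimal host-side sketch (compiles with clang or hipcc; the alias is restated locally so the snippet stands alone):

#include <cstdint>
#include <cstdio>

// Same pattern as the header: a 4-lane vector of uint32_t.
using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));

int main()
{
    uint32x4_t a = {1u, 2u, 3u, 4u};
    uint32x4_t b = a + a; // element-wise add, no intrinsics needed
    std::printf("%u %u %u %u\n", b[0], b[1], b[2], b[3]);
    return 0;
}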
include/ck_tile/core/tensor/tile_distribution.hpp

@@ -746,8 +746,9 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
     return make_tuple(
         make_static_tile_distribution(
             tile_distribution_encoding<typename Encoding::RsLengths,
-                                       decltype(sliced_h_lengths), // only need to change the
-                                                                   // h_lengths type
+                                       remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
+                                                                                   // change the
+                                                                                   // h_lengths type
                                        typename Encoding::Ps2RHssMajor,
                                        typename Encoding::Ps2RHssMinor,
                                        typename Encoding::Ys2RHsMajor,
include/ck_tile/core/tensor/tile_window.hpp

@@ -214,6 +214,12 @@ struct tile_window_with_static_distribution
     CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
 
+    CK_TILE_DEVICE constexpr void
+    set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
+    {
+        bottom_tensor_view_.buf_.p_data_ = data;
+    }
+
     // move thread's window adaptor coordinate and bottom tensor coordinate
     // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
     CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(

@@ -393,7 +399,8 @@ struct tile_window_with_static_distribution
             bottom_tensor_thread_coord,
             bool_constant<oob_conditional_check>{},
             pre_nop_);
-#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
-        asm volatile("");
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
+    CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
+        asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag
 #endif

@@ -843,6 +850,17 @@ struct tile_window_with_static_lengths
     CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
 
+    CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
+    {
+        window_origin_ = new_window_origin;
+    }
+
+    CK_TILE_DEVICE constexpr void
+    set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
+    {
+        bottom_tensor_view_.buf_.p_data_ = data;
+    }
+
     // move window-origin
     CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; }

@@ -871,6 +889,39 @@ make_tile_window(const TensorView_& tensor_view,
         tensor_view, window_lengths, origin};
 }
 
+// duplicate tile window and replace its origin
+template <typename TensorView, typename WindowLengths>
+CK_TILE_DEVICE constexpr auto
+make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
+                 const multi_index<TensorView::get_num_of_dimension()>& origin)
+{
+    return tile_window_with_static_lengths<TensorView, WindowLengths>{
+        tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), origin};
+}
+
+template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
+CK_TILE_DEVICE constexpr auto
+make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
+                 const multi_index<TensorView::get_num_of_dimension()>& origin,
+                 const StaticTileDistribution& tile_distribution)
+{
+    return make_tile_window(tile_window.get_bottom_tensor_view(),
+                            tile_window.get_window_lengths(),
+                            origin,
+                            tile_distribution);
+}
+
+template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
+CK_TILE_DEVICE constexpr auto
+make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
+                 const StaticTileDistribution& tile_distribution)
+{
+    return make_tile_window(tile_window.get_bottom_tensor_view(),
+                            tile_window.get_window_lengths(),
+                            tile_window.get_window_origin(),
+                            tile_distribution);
+}
+
 template <typename TensorView_, typename WindowLengths_>
 CK_TILE_DEVICE void move_tile_window(
     tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
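The three make_tile_window overloads added here duplicate an existing window while swapping its origin and/or attaching a distribution, instead of rebuilding it from the raw tensor view. A self-contained sketch of the same rebinding pattern, with Window as a hypothetical stand-in for tile_window_with_static_lengths:

#include <array>

template <typename View>
struct Window
{
    View view;                  // bottom tensor view
    std::array<int, 2> lengths; // window lengths
    std::array<int, 2> origin;  // window origin
};

// duplicate a window and replace its origin, keeping view and lengths
template <typename View>
Window<View> make_window(const Window<View>& w, std::array<int, 2> new_origin)
{
    return Window<View>{w.view, w.lengths, new_origin};
}

int main()
{
    Window<int> a{/*view=*/0, {64, 64}, {0, 0}};
    Window<int> b = make_window(a, {64, 0}); // shifted copy of the same window
    return b.origin[0] == 64 ? 0 : 1;
}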
include/ck_tile/core/utility/philox_rand.hpp

@@ -53,6 +53,39 @@ class philox
         out_tmp[3] = tmp_ph.w;
     }
 
+    CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
+                                            const unsigned long long subsequence,
+                                            const index_t start_idx) const
+    {
+        uint4 tmp_ph;
+        tmp_ph = get_philox_4x32(subsequence);
+
+        uint32x4_t tmp;
+        tmp[0] = tmp_ph.x;
+        tmp[1] = tmp_ph.y;
+        tmp[2] = tmp_ph.z;
+        tmp[3] = tmp_ph.w;
+
+        uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
+        out_tmp[0]        = tmp[start_idx];
+        out_tmp[1]        = tmp[start_idx + 2];
+    }
+
+    CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
+                                            const unsigned long long subsequence,
+                                            const index_t start_idx) const
+    {
+        uint4 tmp_ph;
+        tmp_ph = get_philox_4x32(subsequence);
+
+        uint32x4_t tmp;
+        tmp[0] = tmp_ph.x;
+        tmp[1] = tmp_ph.y;
+        tmp[2] = tmp_ph.z;
+        tmp[3] = tmp_ph.w;
+
+        uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
+        out_tmp[0]        = tmp[start_idx];
+    }
+
     private:
     struct ull2
     {
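The two new helpers slice the 128-bit philox output differently: get_random_8x8 copies two of the four 32-bit words (at start_idx and start_idx + 2) into 8 output bytes, while get_random_4x8 copies just one word into 4 bytes. A standalone sketch of that slicing, with rand4 standing in for the result of get_philox_4x32:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Copy words [start_idx] and [start_idx + 2] of a 4x32-bit random block
// into 8 bytes, mirroring the layout get_random_8x8 produces above.
void slice_8x8(const uint32_t rand4[4], int start_idx, uint8_t out[8])
{
    std::memcpy(out,     &rand4[start_idx],     sizeof(uint32_t));
    std::memcpy(out + 4, &rand4[start_idx + 2], sizeof(uint32_t));
}

int main()
{
    const uint32_t rand4[4] = {0x00010203u, 0x04050607u, 0x08090a0bu, 0x0c0d0e0fu};
    uint8_t out[8];
    slice_8x8(rand4, /*start_idx=*/0, out); // words 0 and 2 -> 8 random bytes
    std::printf("%02x %02x\n", out[0], out[4]);
    return 0;
}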
include/ck_tile/core/utility/type_traits.hpp

@@ -22,6 +22,23 @@ using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
 template <typename T>
 using remove_pointer_t = typename std::remove_pointer<T>::type;
 
+template <typename From, typename To>
+struct copy_const
+{
+    static_assert(!std::is_const_v<From>);
+    using type = To;
+};
+
+template <typename From, typename To>
+struct copy_const<const From, To>
+{
+    using type = std::add_const_t<typename copy_const<From, To>::type>;
+};
+
+template <typename From, typename To>
+using copy_const_t = typename copy_const<From, To>::type;
+
 namespace detail {
 template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
 struct detector
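copy_const propagates the const qualifier of From onto To; the new PageBlockNavigator later in this commit uses copy_const_t<DataType, void>* so that const and non-const block pointers are handled by one signature. A compilable usage sketch (the trait is restated locally):

#include <type_traits>

template <typename From, typename To>
struct copy_const
{
    static_assert(!std::is_const_v<From>);
    using type = To;
};

template <typename From, typename To>
struct copy_const<const From, To>
{
    using type = std::add_const_t<typename copy_const<From, To>::type>;
};

template <typename From, typename To>
using copy_const_t = typename copy_const<From, To>::type;

// const-ness of From is copied onto To:
static_assert(std::is_same_v<copy_const_t<int, float>, float>);
static_assert(std::is_same_v<copy_const_t<const int, float>, const float>);

int main() { return 0; }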
include/ck_tile/host.hpp

@@ -15,6 +15,7 @@
 #include "ck_tile/host/reference/reference_batched_elementwise.hpp"
 #include "ck_tile/host/reference/reference_batched_gemm.hpp"
 #include "ck_tile/host/reference/reference_batched_masking.hpp"
+#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
 #include "ck_tile/host/reference/reference_batched_softmax.hpp"
 #include "ck_tile/host/reference/reference_gemm.hpp"
 #include "ck_tile/host/reference/reference_im2col.hpp"
include/ck_tile/host/host_tensor.hpp

@@ -155,7 +155,12 @@ struct HostTensorDescriptor
         return space;
     }
 
+    std::size_t get_length(std::size_t dim) const { return mLens[dim]; }
+
     const std::vector<std::size_t>& get_lengths() const { return mLens; }
 
+    std::size_t get_stride(std::size_t dim) const { return mStrides[dim]; }
+
     const std::vector<std::size_t>& get_strides() const { return mStrides; }
 
     template <typename... Is>

@@ -325,8 +330,12 @@ struct HostTensor
     {
     }
 
+    std::size_t get_length(std::size_t dim) const { return mDesc.get_length(dim); }
+
     decltype(auto) get_lengths() const { return mDesc.get_lengths(); }
 
+    std::size_t get_stride(std::size_t dim) const { return mDesc.get_stride(dim); }
+
     decltype(auto) get_strides() const { return mDesc.get_strides(); }
 
     std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }
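The new get_length/get_stride accessors are per-dimension conveniences over the existing length and stride vectors. A minimal stand-alone illustration, with Desc as a hypothetical stand-in for HostTensorDescriptor:

#include <cstddef>
#include <vector>

struct Desc
{
    std::vector<std::size_t> mLens;
    std::vector<std::size_t> mStrides;

    std::size_t get_length(std::size_t dim) const { return mLens[dim]; }
    std::size_t get_stride(std::size_t dim) const { return mStrides[dim]; }
};

int main()
{
    Desc d{{4, 8, 16}, {128, 16, 1}}; // row-major 4x8x16 tensor
    return d.get_length(2) == 16 && d.get_stride(0) == 128 ? 0 : 1;
}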
include/ck_tile/host/kernel_launch.hpp

@@ -73,17 +73,17 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
 {
     // clang-format off
     if(!s.time_kernel_) {
-        (callables(s),...); hip_check_error(hipGetLastError());
+        (callables(s),...); HIP_CHECK_ERROR(hipGetLastError());
         return 0;
     }
 
     if(s.is_gpu_timer_) {
         gpu_timer timer {};
 
         // warmup
-        for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
+        for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
 
         timer.start(s.stream_id_);
-        for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
+        for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
         timer.stop(s.stream_id_);
 
         return timer.duration() / s.nrepeat_;

@@ -92,10 +92,10 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
         cpu_timer timer {};
 
         // warmup
-        for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
+        for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
 
         timer.start(s.stream_id_);
-        for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
+        for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
         timer.stop(s.stream_id_);
 
         return timer.duration() / s.nrepeat_;
include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <cassert>
#include <thread>

namespace ck_tile {

template <typename DataType, typename ComputeDataType = float>
CK_TILE_HOST void
reference_batched_rotary_position_embedding(const HostTensor<DataType>& input_bsd,
                                            const HostTensor<DataType>& cos_sd,
                                            const HostTensor<DataType>& sin_sd,
                                            bool interleaved,
                                            HostTensor<DataType>& output_bsd,
                                            bool use_1_row_sin_cos = false)
{
    assert(cos_sd.get_num_of_dimension() == 2 && sin_sd.get_num_of_dimension() == 2);
    assert(cos_sd.get_length(0) == sin_sd.get_length(0) &&
           cos_sd.get_length(1) == sin_sd.get_length(1));

    const index_t rotary_dim = cos_sd.get_length(1) * 2;
    assert(static_cast<std::size_t>(rotary_dim) <= input_bsd.get_length(2));

    output_bsd.ForEach([&](auto& self, auto i) {
        const index_t i_d = i[2];
        if(rotary_dim <= i_d)
        {
            self(i) = input_bsd(i);
            return;
        }
        assert(i_d < rotary_dim);

        const index_t i_s         = i[1];
        const index_t i_s_cos_sin = (use_1_row_sin_cos ? 0 : i_s);

        const ComputeDataType cos = type_convert<ComputeDataType>(
            interleaved ? cos_sd(i_s_cos_sin, i_d / 2)
                        : cos_sd(i_s_cos_sin, i_d % cos_sd.get_length(1)));
        const ComputeDataType sin = type_convert<ComputeDataType>(
            interleaved ? sin_sd(i_s_cos_sin, i_d / 2)
                        : sin_sd(i_s_cos_sin, i_d % sin_sd.get_length(1)));

        const ComputeDataType half_rotated_input = [&] {
            const index_t i_b = i[0];
            if(interleaved)
            {
                const bool is_even        = (i_d % 2 == 0);
                const index_t pos         = i_d + (is_even ? 1 : -1);
                const ComputeDataType sign = (is_even ? -1 : 1);
                return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
            }
            else
            {
                const index_t half_rdim    = (rotary_dim / 2);
                const index_t pos          = (i_d + half_rdim) % rotary_dim;
                const ComputeDataType sign = (pos < half_rdim ? 1 : -1);
                return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
            }
        }();

        ComputeDataType result =
            type_convert<ComputeDataType>(input_bsd(i)) * cos + half_rotated_input * sin;
        self(i) = type_convert<DataType>(result);
    });
}
} // namespace ck_tile
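To make the two rotary layouts concrete: interleaved pairs element d with its neighbor d +/- 1, while half-rotated pairs d with the element half the rotary span away. A scalar sketch of the same math for a single row (hypothetical helper, not part of ck_tile):

#include <cmath>
#include <cstdio>
#include <vector>

// Rotate x[0 .. rotary_dim) in place; rotary_dim == 2 * cos_row.size().
void rope_row(std::vector<float>& x, const std::vector<float>& cos_row,
              const std::vector<float>& sin_row, bool interleaved)
{
    const int rotary_dim = static_cast<int>(cos_row.size()) * 2;
    const std::vector<float> in = x; // read from a copy, write into x
    for(int d = 0; d < rotary_dim; ++d)
    {
        const float c = interleaved ? cos_row[d / 2] : cos_row[d % (rotary_dim / 2)];
        const float s = interleaved ? sin_row[d / 2] : sin_row[d % (rotary_dim / 2)];
        float half; // the "rotated partner" element, with sign
        if(interleaved)
        {
            const bool even = (d % 2 == 0);
            half = (even ? -1.f : 1.f) * in[d + (even ? 1 : -1)];
        }
        else
        {
            const int pos = (d + rotary_dim / 2) % rotary_dim;
            half = (pos < rotary_dim / 2 ? 1.f : -1.f) * in[pos];
        }
        x[d] = in[d] * c + half * s;
    }
}

int main()
{
    std::vector<float> x{1.f, 0.f, 0.f, 1.f};
    std::vector<float> cos_row{std::cos(0.5f), std::cos(1.0f)};
    std::vector<float> sin_row{std::sin(0.5f), std::sin(1.0f)};
    rope_row(x, cos_row, sin_row, /*interleaved=*/true);
    std::printf("%f %f %f %f\n", x[0], x[1], x[2], x[3]);
    return 0;
}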
include/ck_tile/ops/fmha.hpp

@@ -7,30 +7,29 @@
 #include "ck_tile/ops/fmha/block/block_dropout.hpp"
 #include "ck_tile/ops/fmha/block/block_masking.hpp"
 #include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
+#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
+#include "ck_tile/ops/fmha/block/page_block_navigator.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
-#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
include/ck_tile/ops/fmha/block/block_dropout.hpp

@@ -286,11 +286,226 @@ struct BlockDropout
         });
     }
 
     ck_tile::philox ph;
     const float rp_undrop;
     const uint8_t p_undrop_in_uint8_t;
     const bool is_store_randval;
 };
 
+template <bool IsDropout_, bool IsWG32_, bool IsStoreRandval_>
+struct BlockDropoutBwd;
+
+template <bool IsWG32_, bool IsStoreRandval_>
+struct BlockDropoutBwd<false, IsWG32_, IsStoreRandval_>
+{
+    static constexpr bool IsDropout      = false;
+    static constexpr bool IsStoreRandval = IsStoreRandval_;
+
+    template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
+    __host__ __device__ static constexpr auto
+    MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
+                          index_t seqlen_qk_start)
+    {
+        (void)randval_dram_block_window_tmp;
+        (void)seqlen_qk_start;
+        return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
+    }
+};
+
+template <bool IsWG32_, bool IsStoreRandval_>
+struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
+{
+    static constexpr bool IsDropout = true;
+    // true: 32*32 warp gemm
+    // false: 16*16 warp gemm
+    static constexpr bool IsWG32         = IsWG32_;
+    static constexpr bool IsStoreRandval = IsStoreRandval_;
+
+    CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch,
+                                        index_t i_head,
+                                        index_t nheads,
+                                        unsigned long long seed,
+                                        unsigned long long offset,
+                                        float rp_undrop_,
+                                        uint8_t p_undrop_in_uint8_t_)
+        : ph(seed,
+             offset + (i_batch * nheads + i_head) * get_warp_size() +
+                 (IsWG32 ? get_lane_id() : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))),
+          rp_undrop(rp_undrop_),
+          p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
+    {
+    }
+
+    template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
+    CK_TILE_HOST_DEVICE static constexpr auto
+    MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
+                          index_t seqlen_qk_start)
+    {
+        constexpr auto config =
+            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
+        using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
+        using WG             = remove_cvref_t<decltype(config.template at<0>())>;
+        constexpr index_t kMPerBlock = BlockGemmShape::kM;
+        constexpr index_t MWarp      = config.template at<1>();
+        constexpr index_t NWarp      = config.template at<2>();
+        constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
+        constexpr index_t kMPerStep = [&]() {
+            if constexpr(MBwdWG16MultiIterCheck) { return MWarp * WG::kM * 2; }
+            else { return MWarp * WG::kM; }
+        }();
+        constexpr index_t kNPerStep = NWarp * WG::kN;
+
+        const auto block_origin  = randval_dram_block_window_tmp.get_window_origin();
+        auto randval_dram_window = [&]() {
+            if constexpr(IsFwd)
+            {
+                return make_tile_window(
+                    randval_dram_block_window_tmp.get_bottom_tensor_view(),
+                    ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
+                    {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
+            }
+            else
+            {
+                return make_tile_window(
+                    randval_dram_block_window_tmp.get_bottom_tensor_view(),
+                    ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
+                    {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
+            }
+        }();
+        return randval_dram_window;
+    }
+
+    template <typename BlockGemm>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor()
+    {
+        constexpr auto config =
+            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
+        using WG = remove_cvref_t<decltype(config.template at<0>())>;
+        constexpr index_t MWarp     = config.template at<1>();
+        constexpr index_t kMPerStep = MWarp * WG::kM;
+        constexpr index_t kNPerStep = WG::kN;
+        constexpr index_t kN1       = 8;
+        constexpr index_t kN0       = kNPerStep / kN1;
+
+        constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
+            ck_tile::make_tuple(number<kN0>{}, number<kMPerStep>{}, number<kN1>{}),
+            ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
+            number<kN1>{},
+            number<1>{});
+
+        constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
+            randval_lds_block_desc_0,
+            ck_tile::make_tuple(
+                make_pass_through_transform(number<kMPerStep>{}),
+                make_merge_transform(ck_tile::make_tuple(number<kN0>{}, number<kN1>{}))),
+            ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}),
+            ck_tile::make_tuple(sequence<0>{}, sequence<1>{}));
+
+        return randval_lds_block_desc;
+    }
+
+    template <typename BlockGemm, bool IsFwd = true>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
+    {
+        constexpr auto config =
+            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
+        using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
+        constexpr index_t kMPerBlock = BlockGemmShape::kM;
+        constexpr index_t MWarp      = config.template at<1>();
+        constexpr index_t NWarp      = config.template at<2>();
+        constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
+        constexpr index_t MIterPerWarp = [&]() {
+            if constexpr(MBwdWG16MultiIterCheck) { return 2; }
+            else { return 1; }
+        }();
+        constexpr index_t NIterPerWarp = 1;
+
+        constexpr auto randval_block_outer_part_dstr_encoding =
+            tile_distribution_encoding<sequence<>,
+                                       tuple<sequence<MIterPerWarp, MWarp>,
+                                             sequence<NIterPerWarp, NWarp>>,
+                                       tuple<sequence<1, 2>>,
+                                       tuple<sequence<1, 1>>,
+                                       sequence<1, 2>,
+                                       sequence<0, 0>>{};
+
+        // Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd,
+        // except headdim256.
+        constexpr auto randval_block_inner_part_dstr_encoding = []() {
+            if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
+                         std::is_same_v<typename BlockGemm::BDataType, half_t> &&
+                         std::is_same_v<typename BlockGemm::CDataType, float>)
+            {
+                if constexpr(IsWG32)
+                    return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
+                else
+                    return typename WarpGemmMfmaF16F16F32M16N16K16::CWarpDstrEncoding{};
+            }
+            else
+            {
+                if constexpr(IsWG32)
+                    return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
+                else
+                    return typename WarpGemmMfmaBf16Bf16F32M16N16K16::CWarpDstrEncoding{};
+            }
+        }();
+
+        constexpr auto randval_block_part_dstr_encode =
+            detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
+                                                          randval_block_inner_part_dstr_encoding);
+        return make_static_tile_distribution(randval_block_part_dstr_encode);
+    }
+
+    template <typename BlockGemm>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution()
+    {
+        constexpr auto config =
+            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
+        using WG = remove_cvref_t<decltype(config.template at<0>())>;
+        constexpr index_t MWarp        = config.template at<1>();
+        constexpr index_t NWarp        = config.template at<2>();
+        constexpr index_t MIterPerWarp = 1;
+        constexpr index_t NIterPerWarp = 1;
+
+        constexpr auto randval_block_outer_part_dstr_encoding =
+            tile_distribution_encoding<sequence<>,
+                                       tuple<sequence<MIterPerWarp, MWarp>,
+                                             sequence<NIterPerWarp, NWarp>>,
+                                       tuple<sequence<1, 2>>,
+                                       tuple<sequence<1, 1>>,
+                                       sequence<1, 2>,
+                                       sequence<0, 0>>{};
+
+        constexpr auto randval_block_part_dstr_encode =
+            detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
+                                                          typename WG::CWarpDstrEncoding{});
+        return make_static_tile_distribution(randval_block_part_dstr_encode);
+    }
+
     template <typename BlockGemm,
               typename PComputeDataType,
               typename RandValOutputDataType,
               typename PComputeWindow,
               typename RandValDramWindow>
-    CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
+    CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
+                                 const index_t start_m0_idx,
+                                 const index_t start_n0_idx,
                                  PComputeWindow& p_compute,
                                  RandValDramWindow& randval_dram_window) const
     {

@@ -305,30 +520,177 @@ struct BlockDropout
         constexpr index_t kMPerStep = MWarp * WG::kM;
         constexpr index_t kNPerStep = NWarp * WG::kN;
 
+        // randval tile in LDS
+        auto randval_lds = make_tensor_view<address_space_enum::lds>(
+            reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
+
+        auto randval_lds_window = make_tile_window(
+            randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
+
         // register distribute
-        auto randval =
+        auto randval_dist_generated =
             make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
-        static_assert(randval.kThreadElementSpaceSize == 16);
+        static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
 
-        const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{});
-        static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
-            static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
-                int block_row_start = (start_m0_idx / WG::kM) + i_m0;
-                int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
+        auto randval_lds_read_window =
+            make_tile_window(randval_lds_window.get_bottom_tensor_view(),
+                             randval_lds_window.get_window_lengths(),
+                             randval_lds_window.get_window_origin(),
+                             MakeRandValLdsShuffleTileDistribution<BlockGemm>());
+
+        static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
+            static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
+                int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
+                int block_col_start = (start_n0_idx / WG::kN) + i_n0;
                 uint2 rowcol = make_uint2(block_row_start, block_col_start);
 
                 // generate random number
                 uint8_t random_uint8_t[16];
                 ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
 
+                constexpr auto randval_dist_generated_spans =
+                    decltype(randval_dist_generated)::get_distributed_spans();
+                int i_random_idx = 0;
+                sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
+                    sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
+                        constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
+                        randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
+                    });
+                });
+
+                // save to LDS
+                store_tile(randval_lds_window, randval_dist_generated);
+                block_sync_lds();
+
+                // read from LDS to register
+                auto randval = load_tile(randval_lds_read_window);
                 constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
-                int i_random_idx = 0;
                 sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
                     sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
-                        constexpr auto p_idx0 =
-                            tile_distributed_index<i_m0, idx0.impl_.at(1), idx0.impl_.at(2)>{};
+                        constexpr auto p_idx0 = tile_distributed_index<i_m0>{};
+                        constexpr auto p_idx1 = tile_distributed_index<i_n0,
+                                                                       idx1.impl_.at(1),
+                                                                       idx1.impl_.at(2)>{};
                         constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
                         constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
-                        randval(r_idx)   = random_uint8_t[i_random_idx++];
                         p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
                                                ? p_compute[p_idx] * rp_undrop
                                                : PComputeDataType(0);
                     });
                 });
+
+                // save to Global
+                if constexpr(IsStoreRandval)
+                {
+                    const auto randval_store = cast_tile<RandValOutputDataType>(randval);
+                    store_tile(randval_dram_window, randval_store);
+                    move_tile_window(randval_dram_window, {0, kNPerStep});
+                }
+            });
+            if constexpr(IsStoreRandval)
+            {
+                move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
+            }
+        });
+        if constexpr(IsStoreRandval)
+        {
+            move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
+        }
+    }
+
+    template <typename BlockGemm,
+              typename RandValOutputDataType,
+              typename PComputeWindow,
+              typename RandValDramWindow>
+    CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
+                                 const index_t start_n0_idx,
+                                 PComputeWindow& p_compute,
+                                 RandValDramWindow& randval_dram_window) const
+    {
+        constexpr auto config =
+            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
+        using WG = remove_cvref_t<decltype(config.template at<0>())>;
+        constexpr index_t MWarp = config.template at<1>();
+        constexpr index_t NWarp = config.template at<2>();
+        using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
+        constexpr index_t kMPerBlock = BlockGemmShape::kM;
+        constexpr index_t kNPerBlock = BlockGemmShape::kN;
+        constexpr bool MBwdWG16MultiIterCheck  = (!IsWG32) && (kMPerBlock > 16);
+        constexpr bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16);
+        constexpr index_t kMPerStep = [&]() {
+            if constexpr(MBwdWG16MultiIterCheck) { return MWarp * WG::kM * 2; }
+            else { return MWarp * WG::kM; }
+        }();
+        constexpr index_t kNPerStep = NWarp * WG::kN;
+
+        // register distribute
+        auto randval = make_static_distributed_tensor<uint8_t>(
+            MakeRandValTileDistribution<BlockGemm, false>());
+        if constexpr(IsWG32)
+            static_assert(randval.kThreadElementSpaceSize == 16);
+        else
+            static_assert(randval.kThreadElementSpaceSize == 4 ||
+                          randval.kThreadElementSpaceSize == 8);
+
+        static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
+            static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
+                int block_row_start, block_col_start;
+                if constexpr(IsWG32)
+                {
+                    block_row_start = (start_m0_idx / WG::kM) + i_m0;
+                    block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
+                }
+                else
+                {
+                    block_row_start = start_m0_idx / 32 + i_m0;
+                    block_col_start = (start_n0_idx / 32) + get_warp_id() / 2 + i_n0 * 2;
+                }
+                uint2 rowcol = make_uint2(block_row_start, block_col_start);
+
+                // generate random number
+                uint8_t* random_uint8_t_;
+                if constexpr(MBwdWG16SingleIterCheck)
+                {
+                    uint8_t random_uint8_t[4];
+                    // m0t0 ~m0t15/m0t32~m0t47: 0
+                    // m0t16~m0t31/m0t48~m0t63: 1
+                    // m1t0 ~m1t15/m1t32~m1t47: 2
+                    // m1t16~m1t31/m1t48~m1t63: 3
+                    const index_t start_idx =
+                        ((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
+                    ph.get_random_4x8(
+                        random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
+                    random_uint8_t_ = random_uint8_t;
+                }
+                else if constexpr(MBwdWG16MultiIterCheck)
+                {
+                    uint8_t random_uint8_t[8];
+                    // t0 ~t15/t32~t47: 0
+                    // t16~t31/t48~t63: 1
+                    const index_t start_idx = (get_lane_id() >> 4) & 1;
+                    ph.get_random_8x8(
+                        random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
+                    random_uint8_t_ = random_uint8_t;
+                }
+                else
+                {
+                    uint8_t random_uint8_t[16];
+                    ph.get_random_16x8(random_uint8_t,
+                                       reinterpret_cast<unsigned long long&>(rowcol));
+                    random_uint8_t_ = random_uint8_t;
+                }
+
+                constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
+                int i_random_idx = 0;
+                sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
+                    sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
+                        constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
+                        randval(r_idx) = random_uint8_t_[i_random_idx++];
+                        constexpr auto p_idx0 = tile_distributed_index<i_m0 + idx0.impl_.at(0),
+                                                                       idx0.impl_.at(1),
+                                                                       idx0.impl_.at(2)>{};
+                        constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
+                        constexpr auto p_idx  = ck_tile::make_tuple(p_idx0, p_idx1);
+                        p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t

@@ -337,19 +699,19 @@ struct BlockDropout
                     });
                 });
                 // save to Global
-                if(is_store_randval)
+                if constexpr(IsStoreRandval)
                 {
                     const auto randval_store = cast_tile<RandValOutputDataType>(randval);
                     store_tile(randval_dram_window, randval_store);
                     move_tile_window(randval_dram_window, {kMPerStep, 0});
                 }
             });
-            if(is_store_randval)
+            if constexpr(IsStoreRandval)
            {
                 move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
             }
         });
-        if(is_store_randval)
+        if constexpr(IsStoreRandval)
         {
             move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
         }

@@ -358,7 +720,6 @@ struct BlockDropout
     ck_tile::philox ph;
     const float rp_undrop;
     const uint8_t p_undrop_in_uint8_t;
-    const bool is_store_randval;
 };
 } // namespace ck_tile
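Both Run overloads end with the same per-element keep/scale rule once the random bytes are in registers: keep the value when its random byte is at or below the quantized keep-probability threshold, and rescale kept values so the expectation is unchanged. A scalar sketch of that rule:

#include <cstdint>
#include <cstdio>

// Keep when rand_byte <= p_undrop_in_uint8_t, then rescale by rp_undrop
// (the reciprocal of the keep probability); drop to zero otherwise.
float apply_dropout(float p, uint8_t rand_byte, uint8_t p_undrop_in_uint8_t,
                    float rp_undrop)
{
    return rand_byte <= p_undrop_in_uint8_t ? p * rp_undrop : 0.0f;
}

int main()
{
    // keep probability 0.8 -> threshold ~ 0.8 * 255, rescale by 1.25
    const uint8_t threshold = static_cast<uint8_t>(0.8f * 255.0f);
    std::printf("%f\n", apply_dropout(0.5f, 100, threshold, 1.25f)); // kept: 0.625
    std::printf("%f\n", apply_dropout(0.5f, 250, threshold, 1.25f)); // dropped: 0
    return 0;
}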
include/ck_tile/ops/fmha/block/block_position_encoding.hpp

@@ -43,9 +43,12 @@ enum struct AlibiMode
     FROM_BOTTOM_RIGHT = 2,
 };
 
-template <typename DataType, bool RowMajor = true>
+template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
 struct Alibi
 {
+    static_assert(1 <= LogMaxSadOprndSize && LogMaxSadOprndSize <= 32,
+                  "for LogMaxSadOprndSize <= 16, we use SAD uint16_t, otherwise, use SAD uint32_t");
     // RowMajor here means if pixel within the same thread are along the row, or col
     // this may impact the performance of update(), while the result are the same.
     // e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false

@@ -79,6 +82,19 @@ struct Alibi
         mode = mode_;
     }
 
+    CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
+    {
+        return sad_u32(x, y, acc);
+    }
+
+    CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
+    {
+        if constexpr(LogMaxSadOprndSize <= 16)
+        {
+            return sad_u16(
+                static_cast<uint16_t>(x), static_cast<uint16_t>(y), static_cast<uint16_t>(acc));
+        }
+        return sad_u32(x, y, acc);
+    }
+
     CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
     {
         if constexpr(RowMajor)

@@ -128,7 +144,7 @@ struct EmptyPositionEncoding
 // can convert from the FA style left/right to our generic coordinate
 // if left_size < 0 && right_size = 0, it is normal causal mask
 // local is left_size >=0 or right_size >=0
-template <typename DataType, bool RowMajor = true>
+template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
 CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
                                                  index_t window_left_size,
                                                  index_t window_right_size,

@@ -142,7 +158,7 @@ CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
     AlibiMode alibi_mode =
         is_causal ? AlibiMode::VERTICAL
                   : static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
-    return Alibi<DataType, RowMajor>{slope, y_total, x_total, alibi_mode};
+    return Alibi<DataType, RowMajor, LogMaxSadOprndSize>{slope, y_total, x_total, alibi_mode};
 }
 
 // https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
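sad here computes acc + |x - y| (sum of absolute differences); ALiBi uses it for the distance term of its bias, and the new LogMaxSadOprndSize parameter lets the device path take the 16-bit SAD when row/column indices are known to fit in 16 bits. A host-side sketch of the operation itself:

#include <cstdint>
#include <cstdio>

// acc + |x - y|, the scalar equivalent of the hardware SAD used above.
uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
{
    return acc + (x > y ? x - y : y - x);
}

int main()
{
    // |row - col| is the distance term in the ALiBi bias.
    std::printf("%u\n", sad_u32(7, 3, 0)); // prints 4
    return 0;
}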
include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>

namespace ck_tile {

// This class is used for codegen pattern matching
enum class RotaryEmbeddingEnum
{
    NONE         = 0,
    INTERLEAVED  = 1, // combine dimensions 0 & 1, 2 & 3, etc
    HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
};

template <RotaryEmbeddingEnum>
struct RotaryEmbeddingEnumToStr;

template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::NONE>
{
    static constexpr const char* name = "";
};
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::INTERLEAVED>
{
    static constexpr const char* name = "inter";
};
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::HALF_ROTATED>
{
    static constexpr const char* name = "half";
};

template <RotaryEmbeddingEnum RotaryEnum, typename ComputeDataType = float>
struct BlockRotaryEmbedding
{
    template <typename DistributedTensor,
              typename OtherDramBlockWindow,
              typename RotaryCosDramBlockWindow,
              typename RotarySinDramBlockWindow>
    CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile,
                                          OtherDramBlockWindow other_window,
                                          RotaryCosDramBlockWindow rotary_cos_window,
                                          RotarySinDramBlockWindow rotary_sin_window,
                                          index_t rotary_dim,
                                          index_t thread_end)
    {
        using DataType = typename remove_cvref_t<DistributedTensor>::DataType;

        if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
        {
            auto rotary_cos_tile = load_tile(rotary_cos_window);
            auto rotary_sin_tile = load_tile(rotary_sin_window);

            if(thread_end <= rotary_dim)
            {
                constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
                static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
                    const auto left  = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
                    const auto right = type_convert<ComputeDataType>(tile.thread_buf_[idx + 1]);

                    const auto cos =
                        type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx / 2]);
                    const auto sin =
                        type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx / 2]);

                    tile.thread_buf_[idx]     = type_convert<DataType>(left * cos - right * sin);
                    tile.thread_buf_[idx + 1] = type_convert<DataType>(right * cos + left * sin);
                });
            }
        }
        else if constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
        {
            if(thread_end <= rotary_dim)
            {
                const bool is_left = (thread_end <= (rotary_dim / 2));

                move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
                auto other_tile = load_tile(other_window);

                move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
                auto rotary_cos_tile = load_tile(rotary_cos_window);

                move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
                auto rotary_sin_tile = load_tile(rotary_sin_window);

                constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
                static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
                    const auto curr  = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
                    const auto other = type_convert<ComputeDataType>(other_tile.thread_buf_[idx]);

                    const auto cos =
                        type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx]);
                    const auto sin =
                        type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx]);

                    tile.thread_buf_[idx] =
                        type_convert<DataType>(curr * cos + other * (is_left ? -sin : sin));
                });
            }
        }
    }
};
} // namespace ck_tile
include/ck_tile/ops/fmha/block/page_block_navigator.hpp
0 → 100644
View file @
72c9f129
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
namespace
ck_tile
{
// assume that we have only 1 page-block/tensor view
template
<
typename
TensorView
>
struct
TrivialPageBlockNavigator
{
using
DataType
=
typename
TensorView
::
DataType
;
using
WindowOrigin
=
multi_index
<
2
>
;
CK_TILE_HOST_DEVICE
constexpr
TrivialPageBlockNavigator
(
const
TensorView
&
tensor_view_
)
:
tensor_view
(
tensor_view_
)
{
}
template
<
typename
WindowLengths
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_tile_window
(
const
WindowLengths
&
window_lengths
,
const
WindowOrigin
&
window_origin
)
const
{
return
make_tuple
(
/*block_index=*/
0
,
ck_tile
::
make_tile_window
(
tensor_view
,
window_lengths
,
window_origin
));
}
template
<
typename
WindowLengths
,
typename
TileDistribution
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_tile_window
(
const
WindowLengths
&
window_lengths
,
const
WindowOrigin
&
window_origin
,
const
TileDistribution
&
tile_distribution
)
const
{
return
make_tuple
(
/*block_index=*/
0
,
ck_tile
::
make_tile_window
(
tensor_view
,
window_lengths
,
window_origin
,
tile_distribution
));
}
template
<
typename
TileWindow
>
CK_TILE_HOST_DEVICE
static
index_t
move_tile_window
(
index_t
/*block_index*/
,
TileWindow
&
tile_window
,
const
typename
remove_cvref_t
<
TileWindow
>::
BottomTensorIndex
&
step
)
{
ck_tile
::
move_tile_window
(
tile_window
,
step
);
return
/*block_index=*/
0
;
}
CK_TILE_HOST_DEVICE
static
constexpr
WindowOrigin
to_local_window_origin
(
const
WindowOrigin
&
global_window_origin
)
{
return
global_window_origin
;
}
CK_TILE_HOST_DEVICE
static
constexpr
WindowOrigin
to_global_window_origin
(
index_t
/*block_index*/
,
const
WindowOrigin
&
local_window_origin
)
{
return
local_window_origin
;
}
private:
TensorView
tensor_view
;
};
// default page-block navigator, assume that tensor view size is same as page-block size or smaller
// if tile window on last page-block
template
<
typename
DataType_
,
index_t
VirtualDim
,
typename
TensorView
>
struct
PageBlockNavigator
{
using
DataType
=
DataType_
;
static_assert
(
std
::
is_same_v
<
DataType
,
typename
TensorView
::
DataType
>
);
static_assert
(
VirtualDim
==
0
||
VirtualDim
==
1
,
"only support 2d tile window"
);
using
WindowOrigin
=
multi_index
<
2
>
;
CK_TILE_HOST_DEVICE
constexpr
PageBlockNavigator
(
copy_const_t
<
DataType
,
void
>*
physical_blocks_
,
long_index_t
block_stride_
,
long_index_t
fixed_offset_
,
const
int32_t
*
physical_block_indices_
,
index_t
num_blocks_
,
index_t
page_block_size_
,
const
TensorView
&
complete_view_
,
const
TensorView
&
last_view_
)
:
physical_blocks
(
reinterpret_cast
<
DataType
*>
(
physical_blocks_
)),
block_stride
(
block_stride_
),
fixed_offset
(
fixed_offset_
),
physical_block_indices
(
physical_block_indices_
),
num_blocks
(
num_blocks_
),
page_block_size
(
page_block_size_
),
complete_view
(
complete_view_
),
last_view
(
last_view_
)
{
}
template
<
typename
WindowLengths
>
CK_TILE_HOST_DEVICE
auto
make_tile_window
(
const
WindowLengths
&
window_lengths
,
const
WindowOrigin
&
window_origin
)
const
{
const
index_t
block_index
=
get_block_index
(
window_origin
);
const
WindowOrigin
local_window_origin
=
to_local_window_origin
(
window_origin
);
auto
new_tile_window
=
ck_tile
::
make_tile_window
(
is_last_block
(
block_index
)
?
last_view
:
complete_view
,
window_lengths
,
local_window_origin
);
new_tile_window
.
set_bottom_tensor_view_data_ptr
(
get_block_ptr
(
block_index
));
return
make_tuple
(
block_index
,
new_tile_window
);
}
template
<
typename
WindowLengths
,
typename
TileDistribution
>
CK_TILE_HOST_DEVICE
auto
make_tile_window
(
const
WindowLengths
&
window_lengths
,
const
WindowOrigin
&
window_origin
,
const
TileDistribution
&
tile_distribution
)
const
{
const
index_t
block_index
=
get_block_index
(
window_origin
);
const
WindowOrigin
local_window_origin
=
to_local_window_origin
(
window_origin
);
auto
new_tile_window
=
ck_tile
::
make_tile_window
(
is_last_block
(
block_index
)
?
last_view
:
complete_view
,
window_lengths
,
local_window_origin
,
tile_distribution
);
new_tile_window
.
set_bottom_tensor_view_data_ptr
(
get_block_ptr
(
block_index
));
return
make_tuple
(
block_index
,
new_tile_window
);
}
template
<
typename
TileWindow
>
CK_TILE_HOST_DEVICE
index_t
move_tile_window
(
index_t
block_index
,
TileWindow
&
tile_window
,
const
typename
remove_cvref_t
<
TileWindow
>::
BottomTensorIndex
&
step
)
const
{
ck_tile
::
move_tile_window
(
tile_window
,
step
);
const
WindowOrigin
global_window_origin
=
to_global_window_origin
(
block_index
,
tile_window
.
get_window_origin
());
const
WindowOrigin
local_window_origin
=
to_local_window_origin
(
global_window_origin
);
const
index_t
new_block_index
=
get_block_index
(
global_window_origin
);
/// TODO: only update necessary attributes
tile_window
.
bottom_tensor_view_
.
desc_
=
(
is_last_block
(
new_block_index
)
?
                                                               last_view : complete_view).get_tensor_descriptor();
        tile_window.set_window_origin(local_window_origin);
        tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));

        return new_block_index;
    }

    CK_TILE_HOST_DEVICE
    bool is_last_block(index_t block_index) const { return block_index == num_blocks - 1; }

    template <typename TileWindow>
    CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index, const TileWindow& tile_window) const
    {
        const index_t origin = tile_window.get_window_origin().at(number<VirtualDim>{});
        const index_t length = tile_window.get_window_lengths().at(number<VirtualDim>{});

        return (block_index < num_blocks - 1) && (page_block_size < origin + length);
    }

    template <typename TileWindow>
    CK_TILE_HOST_DEVICE void
    move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const
    {
        const multi_index<2> step = [&]() {
            const index_t origin_diff = (block_index - new_block_index) * page_block_size;
            if constexpr(VirtualDim == 0)
            {
                return make_multi_index(origin_diff, 0);
            }
            else
            {
                return make_multi_index(0, origin_diff);
            }
        }();

        /// TODO: only update necessary attributes
        tile_window.bottom_tensor_view_.desc_ =
            (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
        tile_window.set_window_origin(tile_window.get_window_origin() + step);
        tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
    }

    CK_TILE_HOST_DEVICE
    WindowOrigin to_local_window_origin(const WindowOrigin& global_window_origin) const
    {
        if constexpr(VirtualDim == 0)
        {
            const index_t length              = global_window_origin.at(number<0>{});
            const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
            return make_multi_index(length - page_block_size * num_complete_blocks,
                                    global_window_origin.at(number<1>{}));
        }
        else
        {
            const index_t length              = global_window_origin.at(number<1>{});
            const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
            return make_multi_index(global_window_origin.at(number<0>{}),
                                    length - page_block_size * num_complete_blocks);
        }
    }

    CK_TILE_HOST_DEVICE
    WindowOrigin to_global_window_origin(index_t block_index,
                                         const WindowOrigin& local_window_origin) const
    {
        if constexpr(VirtualDim == 0)
        {
            return make_multi_index(block_index * page_block_size +
                                        local_window_origin.at(number<0>{}),
                                    local_window_origin.at(number<1>{}));
        }
        else
        {
            return make_multi_index(local_window_origin.at(number<0>{}),
                                    block_index * page_block_size +
                                        local_window_origin.at(number<1>{}));
        }
    }

    private:
    CK_TILE_HOST_DEVICE
    DataType* get_block_ptr(index_t block_index) const
    {
        return physical_blocks + physical_block_indices[block_index] * block_stride + fixed_offset;
    }

    CK_TILE_HOST_DEVICE
    int32_t get_block_index(const WindowOrigin& global_window_origin) const
    {
        return integer_divide_floor(global_window_origin.at(number<VirtualDim>{}), page_block_size);
    }

    DataType* physical_blocks;
    long_index_t block_stride;
    long_index_t fixed_offset;
    const int32_t* physical_block_indices;
    index_t num_blocks;
    index_t page_block_size;

    TensorView complete_view;
    TensorView last_view;
};

template <typename TensorView>
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView& tensor_view)
{
    return TrivialPageBlockNavigator<TensorView>(tensor_view);
}

template <typename DataType, index_t VirtualDim, typename TensorView>
CK_TILE_HOST_DEVICE auto make_page_block_navigator(copy_const_t<DataType, void>* physical_blocks,
                                                   long_index_t block_stride,
                                                   long_index_t fixed_offset,
                                                   const int32_t* physical_block_indices,
                                                   index_t num_blocks,
                                                   index_t page_block_size,
                                                   const TensorView& complete_view,
                                                   const TensorView& last_view)
{
    return PageBlockNavigator<DataType, VirtualDim, TensorView>(physical_blocks,
                                                                block_stride,
                                                                fixed_offset,
                                                                physical_block_indices,
                                                                num_blocks,
                                                                page_block_size,
                                                                complete_view,
                                                                last_view);
}

} // namespace ck_tile
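The navigator above hides paged-KV-cache indirection behind the tile-window interface: a coordinate along VirtualDim is split into a page-block index plus an in-block offset, and the block index is remapped through physical_block_indices before any pointer is formed. A minimal standalone sketch of the same address math (plain C++ with invented names for illustration; this is not the ck_tile API):

#include <cassert>
#include <cstdint>
#include <vector>

// One "virtual" row of the cache lives at physical block
// physical_block_indices[row / page_block_size], at (row % page_block_size)
// rows into that block -- the same split performed by get_block_index() and
// to_local_window_origin() above.
inline const float* paged_row_ptr(const float* physical_blocks,
                                  const std::vector<int32_t>& physical_block_indices,
                                  int64_t block_stride, // elements per page block
                                  int64_t row_stride,   // elements per row
                                  int32_t virtual_row,
                                  int32_t page_block_size)
{
    const int32_t block = virtual_row / page_block_size; // cf. get_block_index()
    const int32_t local = virtual_row % page_block_size; // cf. local window origin
    assert(block < static_cast<int32_t>(physical_block_indices.size()));
    return physical_blocks + physical_block_indices[block] * block_stride +
           static_cast<int64_t>(local) * row_stride;
}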
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
View file @ 72c9f129
@@ -23,13 +23,9 @@
 namespace ck_tile {
 
-template <typename TilePartitioner_,
-          typename FmhaPipeline_,
+template <typename FmhaPipeline_,
           typename KGradEpiloguePipeline_,
           typename VGradEpiloguePipeline_>
 struct FmhaBwdDQDKDVKernel
 {
-    using TilePartitioner       = ck_tile::remove_cvref_t<TilePartitioner_>;
     using FmhaPipeline          = ck_tile::remove_cvref_t<FmhaPipeline_>;
     using KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>;
     using VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>;
@@ -59,9 +55,12 @@ struct FmhaBwdDQDKDVKernel
     static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
     static constexpr auto BiasEnum     = FmhaPipeline::BiasEnum;
     static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
-    static constexpr bool kHasDropout  = FmhaPipeline::kHasDropout;
     using FmhaMask    = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
+    using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
     static constexpr bool kHasMask         = FmhaMask::IsMasking;
+    static constexpr bool kHasDropout      = FmhaDropout::IsDropout;
+    static constexpr bool kIsStoreRandval  = FmhaDropout::IsStoreRandval;
+    static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
 
     // clang-format off
     template <typename T> struct t2s;
@@ -73,9 +72,12 @@ struct FmhaBwdDQDKDVKernel
     {
         // sync with generate.py
         // clang-format off
         using bfs = typename FmhaPipeline::BlockFmhaShape;
-        using gbr = typename bfs::Gemm0BlockWarps;
-        using gwt = typename bfs::Gemm0WarpTile;
+        using gbr0 = typename bfs::Gemm0BlockWarps;
+        using gbr1 = typename bfs::Gemm1BlockWarps;
+        using gbr4 = typename bfs::Gemm4BlockWarps;
+        using gwt0 = typename bfs::Gemm0WarpTile;
+        using gwt1 = typename bfs::Gemm1WarpTile;
         #define _SS_ std::string
         #define _TS_ std::to_string
         auto pn = [&] () {
@@ -88,13 +90,17 @@ struct FmhaBwdDQDKDVKernel
         return
             _SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
             "_" + (kIsGroupMode ? "group" : "batch") + "_" +
             "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
             _TS_(bfs::kK1) + "x" + _TS_(bfs::kK2) + "x" + _TS_(bfs::kK3) + "x" +
             _TS_(bfs::kK4) + "x" + _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" +
-            "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
-            "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
+            "r" + _TS_(gbr0::at(ck_tile::number<0>{})) + "x" + _TS_(gbr0::at(ck_tile::number<1>{})) + "x" + _TS_(gbr0::at(ck_tile::number<2>{})) + "_" +
+            "r" + _TS_(gbr1::at(ck_tile::number<0>{})) + "x" + _TS_(gbr1::at(ck_tile::number<1>{})) + "x" + _TS_(gbr1::at(ck_tile::number<2>{})) + "_" +
+            "r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
+            "w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
+            "w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
             ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) +
             (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
             (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") +
-            (kHasDropout ? "_dropout" : "");
+            (kHasDropout ? "_dropout" : "") + (kIsStoreRandval ? "_storerandval" : "") +
+            (kIsDeterministic ? "_deterministic" : "");
 #undef _SS_
 #undef _TS_
         // clang-format on
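Put together, the concatenation above yields kernel names of the following shape (a schematic reading of the string being built, not an example generated from a real tile configuration):

fmha_bwd_d<hdim>_<dtype>_<batch|group>_b<kM0>x<kN0>x<kK0>x<kK1>x<kK2>x<kK3>x<kK4>x<kQKHeaddim>x<kVHeaddim>_r<gemm0 block warps>_r<gemm1 block warps>_r<gemm4 block warps>_w<gemm0 warp tile>_w<gemm1 warp tile>_o<kBlockPerCu>_<pipeline>[_p<pads>][_<bias>][_dbias][_<mask>][_dropout][_storerandval][_deterministic]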
@@ -117,7 +123,7 @@ struct FmhaBwdDQDKDVKernel
         const void* lse_ptr;
         const void* do_ptr;
         const void* d_ptr;
-        void* dq_ptr;
+        void* dq_acc_ptr;
         void* dk_ptr;
         void* dv_ptr;
@@ -131,14 +137,13 @@ struct FmhaBwdDQDKDVKernel
         ck_tile::index_t num_head_q;
         ck_tile::index_t nhead_ratio_qk;
         float raw_scale;
 #if CK_TILE_FMHA_FWD_FAST_EXP2
         float scale;
 #endif
         ck_tile::index_t stride_q;
         ck_tile::index_t stride_k;
         ck_tile::index_t stride_v;
         ck_tile::index_t stride_do;
+        ck_tile::index_t stride_dq_acc;
         ck_tile::index_t stride_dk;
         ck_tile::index_t stride_dv;
@@ -147,8 +152,9 @@ struct FmhaBwdDQDKDVKernel
         ck_tile::index_t nhead_stride_v;
         ck_tile::index_t nhead_stride_do;
         ck_tile::index_t nhead_stride_lsed;
-        ck_tile::index_t batch_stride_lsed;
+        ck_tile::index_t nhead_stride_dq_acc;
+        ck_tile::index_t nhead_stride_dk;
+        ck_tile::index_t nhead_stride_dv;
     };
 
     struct FmhaBwdCommonBiasKargs
@@ -206,7 +212,6 @@ struct FmhaBwdDQDKDVKernel
         float rp_undrop             = 1;
         float scale_rp_undrop       = 1;
         uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
-        bool is_store_randval       = false;
         uint64_t drop_seed          = 1;
         uint64_t drop_offset        = 0;
         void* rand_val_ptr          = nullptr;
@@ -218,6 +223,10 @@ struct FmhaBwdDQDKDVKernel
...
@@ -218,6 +223,10 @@ struct FmhaBwdDQDKDVKernel
{
{
ck_tile
::
index_t
batch_stride_randval
=
0
;
ck_tile
::
index_t
batch_stride_randval
=
0
;
};
};
struct
FmhaBwdDeterministicKargs
{
ck_tile
::
index_t
split_stride_dq_acc
=
0
;
};
struct
FmhaBwdBatchModeKargs
struct
FmhaBwdBatchModeKargs
:
FmhaBwdCommonKargs
,
:
FmhaBwdCommonKargs
,
@@ -228,12 +237,15 @@ struct FmhaBwdDQDKDVKernel
                              FmhaBwdEmptyKargs<0>>>,
           std::conditional_t<kHasBiasGrad, FmhaBwdBatchModeBiasGradKargs, FmhaBwdEmptyKargs<1>>,
           std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
-          std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>
+          std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>,
+          std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
     {
         ck_tile::index_t batch_stride_q;
         ck_tile::index_t batch_stride_k;
         ck_tile::index_t batch_stride_v;
         ck_tile::index_t batch_stride_do;
+        ck_tile::index_t batch_stride_lsed;
+        ck_tile::index_t batch_stride_dq_acc;
         ck_tile::index_t batch_stride_dk;
         ck_tile::index_t batch_stride_dv;
     };
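The std::conditional_t bases above are the classic empty-base trick: each optional argument group is mixed in as a base class, and when a feature is compiled out its slot collapses to a distinct FmhaBwdEmptyKargs<I> so the kernel-argument struct carries no dead fields. A self-contained illustration of the mechanism (generic C++, not the ck_tile types):

#include <type_traits>

template <int I>
struct EmptyKargs // distinct per slot, so bases never collide
{
};

struct CommonKargs
{
    int seqlen_q;
};

struct DropoutKargs
{
    float p_drop;
};

template <bool kHasDropout>
struct Kargs : CommonKargs, std::conditional_t<kHasDropout, DropoutKargs, EmptyKargs<0>>
{
};

// With the empty-base optimization, Kargs<false> is typically the size of
// CommonKargs alone; Kargs<true> additionally carries the dropout fields.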
@@ -247,7 +259,8 @@ struct FmhaBwdDQDKDVKernel
                              FmhaBwdEmptyKargs<0>>>,
           std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>,
           std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
-          std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>
+          std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>,
+          std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
     {
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
@@ -266,10 +279,10 @@ struct FmhaBwdDQDKDVKernel
                   const void* do_ptr,
                   const void* d_ptr,
                   void* rand_val_ptr,
-                  void* dq_ptr,
                   void* dk_ptr,
                   void* dv_ptr,
                   void* dbias_ptr,
+                  void* dq_acc_ptr,
                   ck_tile::index_t seqlen_q,
                   ck_tile::index_t seqlen_k,
                   ck_tile::index_t hdim_q,
@@ -283,6 +296,7 @@ struct FmhaBwdDQDKDVKernel
                   ck_tile::index_t stride_bias,
                   ck_tile::index_t stride_randval,
                   ck_tile::index_t stride_do,
+                  ck_tile::index_t stride_dq_acc,
                   ck_tile::index_t stride_dk,
                   ck_tile::index_t stride_dv,
                   ck_tile::index_t stride_dbias,
@@ -293,6 +307,9 @@ struct FmhaBwdDQDKDVKernel
                   ck_tile::index_t nhead_stride_randval,
                   ck_tile::index_t nhead_stride_do,
                   ck_tile::index_t nhead_stride_lsed,
+                  ck_tile::index_t nhead_stride_dq_acc,
+                  ck_tile::index_t nhead_stride_dk,
+                  ck_tile::index_t nhead_stride_dv,
                   ck_tile::index_t nhead_stride_dbias,
                   ck_tile::index_t batch_stride_q,
                   ck_tile::index_t batch_stride_k,
@@ -301,14 +318,15 @@ struct FmhaBwdDQDKDVKernel
                   ck_tile::index_t batch_stride_randval,
                   ck_tile::index_t batch_stride_do,
                   ck_tile::index_t batch_stride_lsed,
+                  ck_tile::index_t batch_stride_dq_acc,
                   ck_tile::index_t batch_stride_dk,
                   ck_tile::index_t batch_stride_dv,
                   ck_tile::index_t batch_stride_dbias,
+                  ck_tile::index_t split_stride_dq_acc,
                   ck_tile::index_t window_size_left,
                   ck_tile::index_t window_size_right,
                   ck_tile::index_t mask_type,
                   float p_drop,
-                  bool s_randval,
                   const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
     {
         Kargs kargs{{q_ptr,
@@ -317,7 +335,7 @@ struct FmhaBwdDQDKDVKernel
                      lse_ptr,
                      do_ptr,
                      d_ptr,
-                     dq_ptr,
+                     dq_acc_ptr,
                      dk_ptr,
                      dv_ptr,
                      seqlen_q,
@@ -327,13 +345,12 @@ struct FmhaBwdDQDKDVKernel
                      num_head_q,
                      nhead_ratio_qk,
                      scale,
 #if CK_TILE_FMHA_FWD_FAST_EXP2
                      static_cast<float>(scale * ck_tile::log2e_v<>),
 #endif
                      stride_q,
                      stride_k,
                      stride_v,
                      stride_do,
+                     stride_dq_acc,
                      stride_dk,
                      stride_dv,
                      nhead_stride_q,
@@ -341,15 +358,20 @@ struct FmhaBwdDQDKDVKernel
                      nhead_stride_v,
                      nhead_stride_do,
                      nhead_stride_lsed,
-                     batch_stride_lsed}, // args for common karg
+                     nhead_stride_dq_acc,
+                     nhead_stride_dk,
+                     nhead_stride_dv}, // args for common karg
                     {}, // placeholder for bias
                     {}, // placeholder for dbias
                     {}, // placeholder for mask
                     {}, // placeholder for dropout
+                    {}, // placeholder for deterministic
                     batch_stride_q,
                     batch_stride_k,
                     batch_stride_v,
                     batch_stride_do,
+                    batch_stride_lsed,
+                    batch_stride_dq_acc,
                     batch_stride_dk,
                     batch_stride_dv};
@@ -384,11 +406,18 @@ struct FmhaBwdDQDKDVKernel
         if constexpr(kHasDropout)
         {
             kargs.init_dropout(p_drop, drop_seed_offset, scale);
-            kargs.rand_val_ptr         = rand_val_ptr;
-            kargs.stride_randval       = stride_randval;
-            kargs.nhead_stride_randval = nhead_stride_randval;
-            kargs.batch_stride_randval = batch_stride_randval;
-            kargs.is_store_randval     = s_randval;
+            if constexpr(kIsStoreRandval)
+            {
+                kargs.rand_val_ptr         = rand_val_ptr;
+                kargs.stride_randval       = stride_randval;
+                kargs.nhead_stride_randval = nhead_stride_randval;
+                kargs.batch_stride_randval = batch_stride_randval;
+            }
+        }
+        if constexpr(kIsDeterministic)
+        {
+            kargs.split_stride_dq_acc = split_stride_dq_acc;
         }
 
         return kargs;
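The body of init_dropout() is not part of this diff; below is a plausible sketch of what such an initializer has to derive from p_drop and the softmax scale. The formulas are assumptions for illustration, not the confirmed ck_tile implementation:

#include <cstdint>
#include <limits>

struct DropoutDerived
{
    float rp_undrop;             // assumed: 1 / keep-probability, rescales survivors
    float scale_rp_undrop;       // assumed: softmax scale folded with that rescale
    uint8_t p_undrop_in_uint8_t; // assumed: keep threshold quantized to the randval domain
};

inline DropoutDerived derive_dropout(float p_drop, float scale)
{
    const float p_undrop = 1.0f - p_drop;
    return {1.0f / p_undrop,
            scale / p_undrop,
            static_cast<uint8_t>(
                static_cast<float>(std::numeric_limits<uint8_t>::max()) * p_undrop)};
}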
@@ -404,10 +433,10 @@ struct FmhaBwdDQDKDVKernel
                   const void* do_ptr,
                   const void* d_ptr,
                   void* rand_val_ptr,
-                  void* dq_ptr,
                   void* dk_ptr,
                   void* dv_ptr,
                   void* dbias_ptr,
+                  void* dq_acc_ptr,
                   const void* seqstart_q_ptr,
                   const void* seqstart_k_ptr,
                   const void* seqlen_k_ptr,
@@ -422,6 +451,7 @@ struct FmhaBwdDQDKDVKernel
                   ck_tile::index_t stride_bias,
                   ck_tile::index_t stride_randval,
                   ck_tile::index_t stride_do,
+                  ck_tile::index_t stride_dq_acc,
                   ck_tile::index_t stride_dk,
                   ck_tile::index_t stride_dv,
                   ck_tile::index_t stride_dbias,
@@ -432,13 +462,15 @@ struct FmhaBwdDQDKDVKernel
                   ck_tile::index_t nhead_stride_randval,
                   ck_tile::index_t nhead_stride_do,
                   ck_tile::index_t nhead_stride_lsed,
+                  ck_tile::index_t nhead_stride_dq_acc,
+                  ck_tile::index_t nhead_stride_dk,
+                  ck_tile::index_t nhead_stride_dv,
                   ck_tile::index_t nhead_stride_dbias,
-                  ck_tile::index_t batch_stride_lsed,
+                  ck_tile::index_t split_stride_dq_acc,
                   ck_tile::index_t window_size_left,
                   ck_tile::index_t window_size_right,
                   ck_tile::index_t mask_type,
                   float p_drop,
-                  bool s_randval,
                   const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
     {
         Kargs kargs{{q_ptr,
@@ -447,7 +479,7 @@ struct FmhaBwdDQDKDVKernel
                      lse_ptr,
                      do_ptr,
                      d_ptr,
-                     dq_ptr,
+                     dq_acc_ptr,
                      dk_ptr,
                      dv_ptr,
                      -1, // seqlen will be updated by another pointer
@@ -457,13 +489,12 @@ struct FmhaBwdDQDKDVKernel
                      num_head_q,
                      nhead_ratio_qk,
                      scale,
 #if CK_TILE_FMHA_FWD_FAST_EXP2
                      static_cast<float>(scale * ck_tile::log2e_v<>),
 #endif
                      stride_q,
                      stride_k,
                      stride_v,
                      stride_do,
+                     stride_dq_acc,
                      stride_dk,
                      stride_dv,
                      nhead_stride_q,
@@ -471,11 +502,14 @@ struct FmhaBwdDQDKDVKernel
                      nhead_stride_v,
                      nhead_stride_do,
                      nhead_stride_lsed,
-                     batch_stride_lsed}, // args for common karg
+                     nhead_stride_dq_acc,
+                     nhead_stride_dk,
+                     nhead_stride_dv}, // args for common karg
                     {}, // placeholder for bias
                     {}, // placeholder for dbias
                     {}, // placeholder for mask
                     {}, // placeholder for dropout
+                    {}, // placeholder for deterministic
                     reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                     reinterpret_cast<const int32_t*>(seqstart_k_ptr),
                     reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
@@ -506,10 +540,16 @@ struct FmhaBwdDQDKDVKernel
         if constexpr(kHasDropout)
         {
             kargs.init_dropout(p_drop, drop_seed_offset, scale);
-            kargs.rand_val_ptr         = rand_val_ptr;
-            kargs.stride_randval       = stride_randval;
-            kargs.nhead_stride_randval = nhead_stride_randval;
-            kargs.is_store_randval     = s_randval;
+            if constexpr(kIsStoreRandval)
+            {
+                kargs.rand_val_ptr         = rand_val_ptr;
+                kargs.stride_randval       = stride_randval;
+                kargs.nhead_stride_randval = nhead_stride_randval;
+            }
+        }
+        if constexpr(kIsDeterministic)
+        {
+            kargs.split_stride_dq_acc = split_stride_dq_acc;
         }
 
         return kargs;
@@ -518,7 +558,17 @@ struct FmhaBwdDQDKDVKernel
     CK_TILE_HOST static constexpr auto
     GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
     {
-        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_k_);
+        return dim3(ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0), nhead_, batch_size_);
+    }
+
+    CK_TILE_DEVICE static constexpr auto GetTileIndex()
+    {
+        const index_t i_block = blockIdx.x;
+        const index_t i_nhead = blockIdx.y;
+        const index_t i_batch = blockIdx.z;
+        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
     }
 
     CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
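With the partitioner folded in, the launch geometry is now self-describing: one workgroup per kN0-wide seqlen_k tile, per head, per batch, and GetTileIndex() simply reads the coordinates back from blockIdx. A standalone restatement of that mapping (plain C++, illustrative only):

#include <cstdint>

struct GridDim
{
    uint32_t x, y, z;
};

inline GridDim bwd_grid(int batch_size, int nhead, int seqlen_k, int kN0)
{
    const int n_tiles = (seqlen_k + kN0 - 1) / kN0; // integer_divide_ceil(seqlen_k, kN0)
    return {static_cast<uint32_t>(n_tiles),         // blockIdx.x -> i_tile_n
            static_cast<uint32_t>(nhead),           // blockIdx.y -> i_nhead
            static_cast<uint32_t>(batch_size)};     // blockIdx.z -> i_batch
}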
@@ -536,7 +586,7 @@ struct FmhaBwdDQDKDVKernel
         __shared__ char smem_ptr[GetSmemSize()];
 
         // divide problem
-        const auto [i_tile_n, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_k);
+        const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();
 
         const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0);
@@ -547,6 +597,7 @@ struct FmhaBwdDQDKDVKernel
         long_index_t batch_offset_randval = 0;
         long_index_t batch_offset_do      = 0;
         long_index_t batch_offset_lsed    = 0;
+        long_index_t batch_offset_dq_acc  = 0;
         long_index_t batch_offset_dk      = 0;
         long_index_t batch_offset_dv      = 0;
         long_index_t batch_offset_dbias   = 0;
@@ -557,13 +608,14 @@ struct FmhaBwdDQDKDVKernel
             const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
             const long_index_t key_start   = kargs.seqstart_k_ptr[i_batch];
 
             batch_offset_q    = query_start * kargs.stride_q;
             batch_offset_k    = key_start * kargs.stride_k;
             batch_offset_v    = key_start * kargs.stride_v;
             batch_offset_do   = query_start * kargs.stride_do;
-            batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
-            batch_offset_dk   = key_start * kargs.stride_dk;
-            batch_offset_dv   = key_start * kargs.stride_dv;
+            batch_offset_lsed   = query_start;
+            batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
+            batch_offset_dk     = key_start * kargs.stride_dk;
+            batch_offset_dv     = key_start * kargs.stride_dv;
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 batch_offset_bias = query_start * kargs.stride_bias;
@@ -576,7 +628,7 @@ struct FmhaBwdDQDKDVKernel
             {
                 batch_offset_dbias = key_start;
             }
-            if constexpr(kHasDropout)
+            if constexpr(kIsStoreRandval)
             {
                 batch_offset_randval = query_start * kargs.stride_randval;
             }
@@ -603,13 +655,14 @@ struct FmhaBwdDQDKDVKernel
         }
         else
         {
             batch_offset_q    = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
             batch_offset_k    = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
             batch_offset_v    = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
             batch_offset_do   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
             batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
-            batch_offset_dk   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
-            batch_offset_dv   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
+            batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
+            batch_offset_dk     = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
+            batch_offset_dv     = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
@@ -618,7 +671,7 @@ struct FmhaBwdDQDKDVKernel
             {
                 batch_offset_dbias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dbias;
             }
-            if constexpr(kHasDropout)
+            if constexpr(kIsStoreRandval)
             {
                 batch_offset_randval =
                     static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
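Both branches reduce to the same idea: every tensor pointer is advanced by a per-batch and a per-head offset before any tile window is created; batch mode uses whole-tensor strides, while group (varlen) mode derives the batch offset from a seqstart prefix-sum. A compact standalone restatement (illustrative, not the ck_tile API):

#include <cstdint>

// Batch mode: dense [batch, nhead, seq, dim] layout.
inline int64_t batch_mode_offset(int64_t i_batch, int64_t i_nhead,
                                 int64_t batch_stride, int64_t nhead_stride)
{
    return i_batch * batch_stride + i_nhead * nhead_stride;
}

// Group mode: sequences are packed back-to-back; seqstart[i] is the prefix sum
// of sequence lengths, so batch i starts seqstart[i] rows into the packed tensor.
inline int64_t group_mode_offset(const int32_t* seqstart, int i_batch, int64_t row_stride)
{
    return static_cast<int64_t>(seqstart[i_batch]) * row_stride;
}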
@@ -646,14 +699,11 @@ struct FmhaBwdDQDKDVKernel
         const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
                                       static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
                                       batch_offset_do;
-        QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
-                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
-                                batch_offset_q;
         KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
-                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_k +
+                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk +
                                 batch_offset_dk;
         VGradDataType* dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
-                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_v +
+                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv +
                                 batch_offset_dv;
 
         // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
@@ -663,45 +713,10 @@ struct FmhaBwdDQDKDVKernel
                                          make_tuple(kargs.stride_q, 1),
                                          number<FmhaPipeline::kAlignmentQ>{},
                                          number<1>{});
-        const auto q_dram = [&]() {
-            if constexpr(FmhaPipeline::kQLoadOnce)
-            {
-                return pad_tensor_view(
-                    q_dram_naive,
-                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    q_dram_naive,
-                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
-            }
-        }();
-        const auto qt_dram_naive = transform_tensor_view(
-            q_dram_naive,
-            make_tuple(make_pass_through_transform(kargs.hdim_q),
-                       make_pass_through_transform(kargs.seqlen_q)),
-            make_tuple(sequence<1>{}, sequence<0>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
-        const auto qt_dram = [&]() {
-            if constexpr(FmhaPipeline::kQTLoadOnce)
-            {
-                return pad_tensor_view(
-                    qt_dram_naive,
-                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kM0>{}),
-                    sequence<kPadHeadDimQ, kPadSeqLenQ>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    qt_dram_naive,
-                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK3>{}),
-                    sequence<kPadHeadDimQ, kPadSeqLenQ>{});
-            }
-        }();
+        const auto q_dram = pad_tensor_view(
+            q_dram_naive,
+            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+            sequence<kPadSeqLenQ, kPadHeadDimQ>{});
 
         const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
             k_ptr,
@@ -709,45 +724,10 @@ struct FmhaBwdDQDKDVKernel
                                          make_tuple(kargs.stride_k, 1),
                                          number<FmhaPipeline::kAlignmentK>{},
                                          number<1>{});
-        const auto k_dram = [&]() {
-            if constexpr(FmhaPipeline::kKLoadOnce)
-            {
-                return pad_tensor_view(
-                    k_dram_naive,
-                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-                    sequence<kPadSeqLenK, kPadHeadDimQ>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    k_dram_naive,
-                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
-                    sequence<kPadSeqLenK, kPadHeadDimQ>{});
-            }
-        }();
-        const auto kt_dram_naive = transform_tensor_view(
-            k_dram_naive,
-            make_tuple(make_pass_through_transform(kargs.hdim_q),
-                       make_pass_through_transform(kargs.seqlen_k)),
-            make_tuple(sequence<1>{}, sequence<0>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
-        const auto kt_dram = [&]() {
-            if constexpr(FmhaPipeline::kKTLoadOnce)
-            {
-                return pad_tensor_view(
-                    kt_dram_naive,
-                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kN0>{}),
-                    sequence<kPadHeadDimQ, kPadSeqLenK>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    kt_dram_naive,
-                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK4>{}),
-                    sequence<kPadHeadDimQ, kPadSeqLenK>{});
-            }
-        }();
+        const auto k_dram = pad_tensor_view(
+            k_dram_naive,
+            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+            sequence<kPadSeqLenK, kPadHeadDimQ>{});
 
         const auto v_dram = [&]() {
             const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -756,20 +736,10 @@ struct FmhaBwdDQDKDVKernel
                 make_tuple(kargs.stride_v, 1),
                 number<FmhaPipeline::kAlignmentV>{},
                 number<1>{});
-            if constexpr(FmhaPipeline::kVLoadOnce)
-            {
-                return pad_tensor_view(
-                    v_dram_naive,
-                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
-                    sequence<kPadSeqLenK, kPadHeadDimV>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    v_dram_naive,
-                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK2>{}),
-                    sequence<kPadSeqLenK, kPadHeadDimV>{});
-            }
+            return pad_tensor_view(
+                v_dram_naive,
+                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
+                sequence<kPadSeqLenK, kPadHeadDimV>{});
         }();
 
         const auto lse_dram = [&]() {
@@ -792,145 +762,89 @@ struct FmhaBwdDQDKDVKernel
                                           make_tuple(kargs.stride_do, 1),
                                           number<FmhaPipeline::kAlignmentOGrad>{},
                                           number<1>{});
-        const auto do_dram = [&]() {
-            if constexpr(FmhaPipeline::kOGradLoadOnce)
-            {
-                return pad_tensor_view(
-                    do_dram_naive,
-                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimV>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    do_dram_naive,
-                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK2>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimV>{});
-            }
-        }();
-        const auto dot_dram_naive = transform_tensor_view(
-            do_dram_naive,
-            make_tuple(make_pass_through_transform(kargs.hdim_v),
-                       make_pass_through_transform(kargs.seqlen_q)),
-            make_tuple(sequence<1>{}, sequence<0>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
-        const auto dot_dram = [&]() {
-            if constexpr(FmhaPipeline::kOGradTLoadOnce)
-            {
-                return pad_tensor_view(
-                    dot_dram_naive,
-                    make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kM0>{}),
-                    sequence<kPadHeadDimV, kPadSeqLenQ>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    dot_dram_naive,
-                    make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kK1>{}),
-                    sequence<kPadHeadDimV, kPadSeqLenQ>{});
-            }
-        }();
-        auto dq_dram = [&]() {
-            const auto dq_dram_naive =
-                make_naive_tensor_view<address_space_enum::global,
-                                       memory_operation_enum::atomic_add>(
-                    dq_ptr,
-                    make_tuple(kargs.seqlen_q, kargs.hdim_q),
-                    make_tuple(kargs.stride_q, 1),
-                    number<FmhaPipeline::kAlignmentQGrad>{},
-                    number<1>{});
-            return pad_tensor_view(
-                dq_dram_naive,
-                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-                sequence<kPadSeqLenQ, kPadHeadDimQ>{});
-        }();
+        const auto do_dram = pad_tensor_view(
+            do_dram_naive,
+            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
+            sequence<kPadSeqLenQ, kPadHeadDimV>{});
 
-        auto q_dram_window = make_tile_window(
-            q_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kQLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kM0>{},
-                                      number<FmhaPipeline::kQKHeaddim>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
-            }(),
-            {0, 0});
-        auto qt_dram_window = make_tile_window(
-            qt_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kQTLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
-                                      number<FmhaPipeline::kM0>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
-                                      number<FmhaPipeline::kK3>{});
-            }(),
-            {0, 0});
-        auto k_dram_window = make_tile_window(
-            k_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kKLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kN0>{},
-                                      number<FmhaPipeline::kQKHeaddim>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
-            }(),
-            {i_n0, 0});
-        auto kt_dram_window = make_tile_window(
-            kt_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kKTLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
-                                      number<FmhaPipeline::kN0>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
-                                      number<FmhaPipeline::kK4>{});
-            }(),
-            {0, i_n0});
-        auto v_dram_window = make_tile_window(
-            v_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kVLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kN0>{},
-                                      number<FmhaPipeline::kVHeaddim>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK2>{});
-            }(),
-            {i_n0, 0});
-        auto do_dram_window = make_tile_window(
-            do_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kOGradLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kM0>{},
-                                      number<FmhaPipeline::kVHeaddim>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK2>{});
-            }(),
-            {0, 0});
-        auto dot_dram_window = make_tile_window(
-            dot_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kOGradTLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kVHeaddim>{},
-                                      number<FmhaPipeline::kM0>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kVHeaddim>{},
-                                      number<FmhaPipeline::kK1>{});
-            }(),
-            {0, 0});
-        auto dq_dram_window = make_tile_window(
-            dq_dram,
-            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-            {0, 0});
+        auto q_dram_window = make_tile_window(
+            q_dram,
+            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+            {0, 0});
+        auto k_dram_window = make_tile_window(
+            k_dram,
+            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+            {i_n0, 0});
+        auto v_dram_window = make_tile_window(
+            v_dram,
+            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
+            {i_n0, 0});
+        auto do_dram_window = make_tile_window(
+            do_dram,
+            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
+            {0, 0});
+        auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
+            if constexpr(kIsDeterministic)
+            {
+                AccDataType* dq_acc_ptr =
+                    reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
+                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
+                    static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
+                    batch_offset_dq_acc;
+
+                auto dq_acc_dram = [&]() {
+                    const auto dq_acc_dram_naive =
+                        make_naive_tensor_view<address_space_enum::global>(
+                            dq_acc_ptr,
+                            make_tuple(kargs.seqlen_q, kargs.hdim_q),
+                            make_tuple(kargs.stride_dq_acc, 1),
+                            number<FmhaPipeline::kAlignmentQGrad>{},
+                            number<1>{});
+                    return pad_tensor_view(
+                        dq_acc_dram_naive,
+                        make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+                        sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+                }();
+
+                return make_tile_window(
+                    dq_acc_dram,
+                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+                    {0, 0});
+            }
+            else
+            {
+                AccDataType* dq_acc_ptr =
+                    reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
+                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
+                    batch_offset_dq_acc;
+
+                auto dq_acc_dram = [&]() {
+                    const auto dq_acc_dram_naive =
+                        make_naive_tensor_view<address_space_enum::global,
+                                               memory_operation_enum::atomic_add>(
+                            dq_acc_ptr,
+                            make_tuple(kargs.seqlen_q, kargs.hdim_q),
+                            make_tuple(kargs.stride_dq_acc, 1),
+                            number<FmhaPipeline::kAlignmentQGrad>{},
+                            number<1>{});
+                    return pad_tensor_view(
+                        dq_acc_dram_naive,
+                        make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+                        sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+                }();
+
+                return make_tile_window(
+                    dq_acc_dram,
+                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+                    {0, 0});
+            }
+        }();
 
         auto lse_dram_window =
             make_tile_window(lse_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});
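All of these views follow one pattern: a naive global-memory view is built from runtime extents and strides, padded up to the compile-time tile shape with pad_tensor_view, and then a fixed-size tile window is opened at the block's starting coordinate. The padding is what lets a single kernel handle ragged seqlen/hdim; a sketch of the bound logic it implies (illustrative, not ck_tile internals):

// Conceptually, padding rounds each extent up to a tile multiple, and the
// window masks off accesses that fall past the true extent.
inline int padded_extent(int extent, int tile) { return ((extent + tile - 1) / tile) * tile; }
inline bool in_bounds(int row, int col, int seqlen, int hdim) { return row < seqlen && col < hdim; }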
@@ -1008,9 +922,7 @@ struct FmhaBwdDQDKDVKernel
                 // TODO: how to use s_read?
                 AccDataType slope = *(reinterpret_cast<const AccDataType*>(kargs.alibi_slope_ptr) +
                                       i_batch_ * kargs.alibi_slope_stride + i_nhead_);
-#if CK_TILE_FMHA_FWD_FAST_EXP2
                 slope *= ck_tile::log2e_v<>;
-#endif
                 if constexpr(kHasMask)
                 {
                     return make_alibi_from_lr_mask<AccDataType, false>(slope,
@@ -1033,35 +945,34 @@ struct FmhaBwdDQDKDVKernel
         }();
 
         // dropout
         float rp_undrop       = 1;
         float scale_rp_undrop = 1;
-        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
-        uint64_t drop_seed          = 0;
-        uint64_t drop_offset        = 0;
-        bool is_store_randval       = false;
 
         if constexpr(kHasDropout)
         {
             rp_undrop       = kargs.rp_undrop;
             scale_rp_undrop = kargs.scale_rp_undrop;
-            p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
-            drop_seed           = kargs.drop_seed;
-            drop_offset         = kargs.drop_offset;
-            is_store_randval    = kargs.is_store_randval;
         }
-        BlockDropout dropout(i_batch,
-                             i_nhead,
-                             kargs.num_head_q,
-                             drop_seed,
-                             drop_offset,
-                             rp_undrop,
-                             p_undrop_in_uint8_t,
-                             is_store_randval);
+        auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
+            if constexpr(kHasDropout)
+            {
+                return FmhaDropout{i_batch_,
+                                   i_nhead_,
+                                   kargs.num_head_q,
+                                   kargs.drop_seed,
+                                   kargs.drop_offset,
+                                   kargs.rp_undrop,
+                                   kargs.p_undrop_in_uint8_t};
+            }
+            else
+            {
+                return FmhaDropout{};
+            }
+        }();
 
         auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
             constexpr auto randval_dram_window_lengths =
                 make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
-            if constexpr(kHasDropout)
+            if constexpr(kIsStoreRandval)
             {
                 RandValOutputDataType* rand_val_ptr =
                     reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
@@ -1103,14 +1014,11 @@ struct FmhaBwdDQDKDVKernel
         }();
 
         auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
-                                                         qt_dram_window,
                                                          k_dram_window,
-                                                         kt_dram_window,
                                                          v_dram_window,
                                                          bias_dram_window,
                                                          randval_dram_window,
                                                          do_dram_window,
-                                                         dot_dram_window,
                                                          lse_dram_window,
                                                          d_dram_window,
                                                          dq_dram_window,
@@ -1118,9 +1026,7 @@ struct FmhaBwdDQDKDVKernel
                                                          mask,
                                                          position_encoding,
                                                          kargs.raw_scale,
-#if CK_TILE_FMHA_FWD_FAST_EXP2
                                                          kargs.scale,
-#endif
                                                          rp_undrop,
                                                          scale_rp_undrop,
                                                          smem_ptr,
@@ -1169,10 +1075,9 @@ struct FmhaBwdDQDKDVKernel
     }
 };
 
-template <typename TilePartitioner_, typename FmhaBwdOGradDotO_>
+template <typename FmhaBwdOGradDotO_>
 struct FmhaBwdOGradDotOKernel
 {
-    using TilePartitioner  = ck_tile::remove_cvref_t<TilePartitioner_>;
     using FmhaBwdOGradDotO = ck_tile::remove_cvref_t<FmhaBwdOGradDotO_>;
 
     static constexpr ck_tile::index_t kBlockSize  = FmhaBwdOGradDotO::kBlockSize;
     static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu;
@@ -1234,13 +1139,13 @@ struct FmhaBwdOGradDotOKernel
         ck_tile::index_t nhead_stride_do;
         ck_tile::index_t nhead_stride_o;
         ck_tile::index_t nhead_stride_d;
-        ck_tile::index_t batch_stride_d;
     };
 
     struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs
     {
         ck_tile::index_t batch_stride_do;
         ck_tile::index_t batch_stride_o;
+        ck_tile::index_t batch_stride_d;
     };
 
     struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
@@ -1278,10 +1183,10 @@ struct FmhaBwdOGradDotOKernel
                      stride_o,
                      nhead_stride_do,
                      nhead_stride_o,
-                     nhead_stride_d,
-                     batch_stride_d},
+                     nhead_stride_d},
                     batch_stride_do,
-                    batch_stride_o};
+                    batch_stride_o,
+                    batch_stride_d};
 
         return kargs;
     }
@@ -1298,8 +1203,7 @@ struct FmhaBwdOGradDotOKernel
                   ck_tile::index_t stride_o,
                   ck_tile::index_t nhead_stride_do,
                   ck_tile::index_t nhead_stride_o,
-                  ck_tile::index_t nhead_stride_d,
-                  ck_tile::index_t batch_stride_d)
+                  ck_tile::index_t nhead_stride_d)
     {
         Kargs kargs{{o_ptr,
                      do_ptr,
@@ -1311,8 +1215,7 @@ struct FmhaBwdOGradDotOKernel
                      stride_o,
                      nhead_stride_do,
                      nhead_stride_o,
-                     nhead_stride_d,
-                     batch_stride_d},
+                     nhead_stride_d},
                     reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
 
         return kargs;
@@ -1321,7 +1224,16 @@ struct FmhaBwdOGradDotOKernel
     CK_TILE_HOST static constexpr auto
     GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
     {
-        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_);
+        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
+    }
+
+    CK_TILE_DEVICE static constexpr auto GetTileIndex()
+    {
+        const index_t i_block = blockIdx.x;
+        const index_t i_nhead = blockIdx.y;
+        const index_t i_batch = blockIdx.z;
+        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
     }
 
     CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
@@ -1331,7 +1243,7 @@ struct FmhaBwdOGradDotOKernel
     CK_TILE_DEVICE void operator()(Kargs kargs) const
     {
         // divide problem
-        const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q);
+        const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
 
         const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
@@ -1346,7 +1258,7 @@ struct FmhaBwdOGradDotOKernel
             batch_offset_o  = query_start * kargs.stride_o;
             batch_offset_do = query_start * kargs.stride_do;
-            batch_offset_d  = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;
+            batch_offset_d  = query_start;
 
             // get real # queries & # keys under group mode
             const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
...
@@ -1418,4 +1330,315 @@ struct FmhaBwdOGradDotOKernel
...
@@ -1418,4 +1330,315 @@ struct FmhaBwdOGradDotOKernel
}
}
};
};
template
<
typename
FmhaBwdConvertQGrad_
>
struct
FmhaBwdConvertQGradKernel
{
using
FmhaBwdConvertQGrad
=
ck_tile
::
remove_cvref_t
<
FmhaBwdConvertQGrad_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaBwdConvertQGrad
::
kBlockSize
;
static
constexpr
ck_tile
::
index_t
kBlockPerCu
=
FmhaBwdConvertQGrad
::
kBlockPerCu
;
static
constexpr
ck_tile
::
index_t
kM0
=
FmhaBwdConvertQGrad
::
kM0
;
static
constexpr
ck_tile
::
index_t
kN0
=
FmhaBwdConvertQGrad
::
kN0
;
static
constexpr
ck_tile
::
index_t
kQKHeaddim
=
FmhaBwdConvertQGrad
::
kQKHeaddim
;
using
AccDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaBwdConvertQGrad
::
AccDataType
>
;
using
QGradDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaBwdConvertQGrad
::
QGradDataType
>
;
static
constexpr
bool
kIsGroupMode
=
FmhaBwdConvertQGrad
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
FmhaBwdConvertQGrad
::
kPadSeqLenQ
;
static
constexpr
bool
kPadHeadDimQ
=
FmhaBwdConvertQGrad
::
kPadHeadDimQ
;
static
constexpr
bool
kIsDeterministic
=
FmhaBwdConvertQGrad
::
kIsDeterministic
;
// clang-format off
template
<
typename
T
>
struct
t2s
;
template
<
>
struct
t2s
<
ck_tile
::
fp16_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
// clang-format on
CK_TILE_HOST
static
std
::
string
GetName
()
{
// sync with generate.py
// clang-format off
#define _SS_ std::string
#define _TS_ std::to_string
auto
pn
=
[
&
]
()
{
std
::
string
n
;
if
(
kPadSeqLenQ
)
n
+=
"s"
;
if
(
kPadHeadDimQ
)
n
+=
"d"
;
return
n
.
empty
()
?
n
:
std
::
string
(
"p"
)
+
n
;
}();
return
_SS_
(
"fmha_bwd_convert_dq_d"
)
+
_TS_
(
kQKHeaddim
)
+
"_"
+
_SS_
(
t2s
<
QGradDataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
(
kIsDeterministic
?
"_deterministic"
:
""
)
+
"_"
+
(
"o"
+
_TS_
(
kBlockPerCu
))
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
);
#undef _SS_
#undef _TS_
// clang-format on
}
// to avoid duplicated base class prblem, introduce an template arg
template
<
ck_tile
::
index_t
I
>
struct
FmhaBwdConvertQGradEmptyKargs
{
};
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
struct
FmhaBwdConvertQGradCommonKargs
{
const
void
*
dq_acc_ptr
;
void
*
dq_ptr
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_k
;
ck_tile
::
index_t
hdim_q
;
ck_tile
::
index_t
stride_dq
;
ck_tile
::
index_t
stride_dq_acc
;
ck_tile
::
index_t
nhead_stride_dq
;
ck_tile
::
index_t
nhead_stride_dq_acc
;
};
struct
FmhaBwdConvertQGradDeterministicKargs
{
ck_tile
::
index_t
split_stride_dq_acc
=
0
;
};
struct
FmhaBwdConvertQGradBatchModeKargs
:
FmhaBwdConvertQGradCommonKargs
,
std
::
conditional_t
<
kIsDeterministic
,
FmhaBwdConvertQGradDeterministicKargs
,
FmhaBwdConvertQGradEmptyKargs
<
0
>>
{
ck_tile
::
index_t
batch_stride_dq
;
ck_tile
::
index_t
batch_stride_dq_acc
;
};
struct
FmhaBwdConvertQGradGroupModeKargs
:
FmhaBwdConvertQGradCommonKargs
,
std
::
conditional_t
<
kIsDeterministic
,
FmhaBwdConvertQGradDeterministicKargs
,
FmhaBwdConvertQGradEmptyKargs
<
0
>>
{
const
int32_t
*
seqstart_q_ptr
;
const
int32_t
*
seqstart_k_ptr
;
};
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
FmhaBwdConvertQGradGroupModeKargs
,
FmhaBwdConvertQGradBatchModeKargs
>
;
template
<
bool
Cond
=
!
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
dq_acc_ptr
,
void
*
dq_ptr
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::index_t seqlen_k,
                                             ck_tile::index_t hdim_q,
                                             ck_tile::index_t stride_dq,
                                             ck_tile::index_t stride_dq_acc,
                                             ck_tile::index_t nhead_stride_dq,
                                             ck_tile::index_t nhead_stride_dq_acc,
                                             ck_tile::index_t batch_stride_dq,
                                             ck_tile::index_t batch_stride_dq_acc,
                                             ck_tile::index_t split_stride_dq_acc)
    {
        Kargs kargs{{dq_acc_ptr,
                     dq_ptr,
                     seqlen_q,
                     seqlen_k,
                     hdim_q,
                     stride_dq,
                     stride_dq_acc,
                     nhead_stride_dq,
                     nhead_stride_dq_acc},
                    {},
                    batch_stride_dq,
                    batch_stride_dq_acc};

        if constexpr(kIsDeterministic)
        {
            kargs.split_stride_dq_acc = split_stride_dq_acc;
        }

        return kargs;
    }

    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* dq_acc_ptr,
              void* dq_ptr,
              const void* seqstart_q_ptr,
              const void* seqstart_k_ptr,
              ck_tile::index_t hdim_q,
              ck_tile::index_t stride_dq,
              ck_tile::index_t stride_dq_acc,
              ck_tile::index_t nhead_stride_dq,
              ck_tile::index_t nhead_stride_dq_acc,
              ck_tile::index_t split_stride_dq_acc)
    {
        Kargs kargs{{dq_acc_ptr,
                     dq_ptr,
                     -1, // seqlen_q will be updated from the seqstart pointer
                     -1, // seqlen_k likewise
                     hdim_q,
                     stride_dq,
                     stride_dq_acc,
                     nhead_stride_dq,
                     nhead_stride_dq_acc},
                    {},
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                    reinterpret_cast<const int32_t*>(seqstart_k_ptr)};

        if constexpr(kIsDeterministic)
        {
            kargs.split_stride_dq_acc = split_stride_dq_acc;
        }

        return kargs;
    }

    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
    {
        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
    }

    CK_TILE_DEVICE static constexpr auto GetTileIndex()
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;
        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // divide the problem
        const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();

        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);

        long_index_t batch_offset_dq     = 0;
        long_index_t batch_offset_dq_acc = 0;
        if constexpr(kIsGroupMode)
        {
            // get the starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
            batch_offset_dq     = query_start * kargs.stride_dq;
            batch_offset_dq_acc = query_start * kargs.stride_dq_acc;

            // get the real # of queries & # of keys under group mode
            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
            if constexpr(kIsDeterministic)
            {
                const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
                kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
            }

            // the # of required blocks differs between groups; terminate unnecessary
            // blocks early
            if(kargs.seqlen_q <= i_m0)
            {
                return;
            }
        }
        else
        {
            batch_offset_dq     = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
            batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
        }

        // for simplicity, we fold the batch stride into the pointer
        QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dq +
                                batch_offset_dq;

        // dQAcc/dQ DRAM and DRAM windows
        const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() {
            if constexpr(kIsDeterministic)
            {
                const AccDataType* dq_acc_ptr =
                    reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
                    static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
                    batch_offset_dq_acc;

                const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);

                auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    dq_acc_ptr,
                    make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
                    make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq_acc, 1),
                    number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
                    number<1>{});
                return pad_tensor_view(
                    dq_acc_dram_naive,
                    make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
                    sequence<false, kPadSeqLenQ, kPadHeadDimQ>{});
            }
            else
            {
                const AccDataType* dq_acc_ptr =
                    reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
                    static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
                    batch_offset_dq_acc;

                auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    dq_acc_ptr,
                    make_tuple(kargs.seqlen_q, kargs.hdim_q),
                    make_tuple(kargs.stride_dq_acc, 1),
                    number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
                    number<1>{});
                return pad_tensor_view(dq_acc_dram_naive,
                                       make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                                       sequence<kPadSeqLenQ, kPadHeadDimQ>{});
            }
        }();

        auto dq_dram = [&]() {
            auto dq_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                dq_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_q),
                make_tuple(kargs.stride_dq, 1),
                number<FmhaBwdConvertQGrad::kAlignmentQGrad>{},
                number<1>{});
            return pad_tensor_view(dq_dram_naive,
                                   make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                                   sequence<kPadSeqLenQ, kPadHeadDimQ>{});
        }();

        auto dq_acc_dram_window = [&]() {
            if constexpr(kIsDeterministic)
            {
                return make_tile_window(
                    dq_acc_dram,
                    make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
                    {0, i_m0, 0});
            }
            else
            {
                return make_tile_window(
                    dq_acc_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
            }
        }();

        auto dq_dram_window = make_tile_window(
            dq_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});

        if constexpr(kIsDeterministic)
        {
            const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
            FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits);
        }
        else
        {
            FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window);
        }
    }
};

} // namespace ck_tile
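The deterministic path above views the dQ accumulator as a 3-D tensor of shape (nsplits, seqlen_q, hdim_q), where nsplits = ceil(seqlen_k / kN0), and FmhaBwdConvertQGrad reduces the split axis into the final dQ. A minimal host-side sketch of that reduction, with illustrative sizes (the real kernel does this tile-by-tile through the windows above):

// Host-side sketch: reduce the split dimension of dq_acc into dq.
// Sizes and the fp32 element type are illustrative, not the kernel's config.
#include <cstdio>
#include <vector>

int main()
{
    const int seqlen_k = 300, kN0 = 128;
    const int nsplits  = (seqlen_k + kN0 - 1) / kN0; // integer_divide_ceil
    const int seqlen_q = 4, hdim_q = 8;

    std::vector<float> dq_acc(nsplits * seqlen_q * hdim_q, 1.0f); // accumulator
    std::vector<float> dq(seqlen_q * hdim_q, 0.0f);               // converted output

    for(int s = 0; s < nsplits; ++s)      // split_stride_dq_acc walks this axis
        for(int m = 0; m < seqlen_q; ++m)
            for(int d = 0; d < hdim_q; ++d)
                dq[m * hdim_q + d] += dq_acc[(s * seqlen_q + m) * hdim_q + d];

    std::printf("nsplits=%d dq[0]=%f\n", nsplits, dq[0]); // one contribution per split
}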
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp
deleted 100644 → 0 (View file @ 241c261f)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename BlockFmhaShape_>
struct FmhaBwdTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_k_)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;
        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }
};

template <ck_tile::index_t kBlockSize>
struct FmhaBwdOGradDotOTilePartitioner
{
    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;
        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }
};

} // namespace ck_tile
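Both partitioners removed here flatten the problem onto a 3-D grid: grid.x tiles the sequence, grid.y indexes heads, grid.z indexes batches, and operator() recovers the tile from blockIdx. A small host-side sketch of that correspondence, with made-up sizes:

// Sketch of the 3-D grid the removed partitioners produced; each simulated
// "block" recovers its tile the same way operator() does from blockIdx.{x,y,z}.
#include <cstdio>

int main()
{
    const int batch = 2, nhead = 3, seqlen_k = 300, kN0 = 128;
    const int grid_x = (seqlen_k + kN0 - 1) / kN0; // integer_divide_ceil(seqlen_k, kN0)

    for(int i_batch = 0; i_batch < batch; ++i_batch)
        for(int i_nhead = 0; i_nhead < nhead; ++i_nhead)
            for(int i_block = 0; i_block < grid_x; ++i_block)
                std::printf("tile(i_block=%d, i_nhead=%d, i_batch=%d)\n",
                            i_block, i_nhead, i_batch);
}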
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp
0 → 100644 (View file @ 72c9f129)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

#include <string>
#include <type_traits>

namespace ck_tile {

template <typename TilePartitioner_, typename FmhaPipeline_>
struct FmhaFwdAppendKVKernel
{
    using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
    using FmhaPipeline    = ck_tile::remove_cvref_t<FmhaPipeline_>;
    static constexpr ck_tile::index_t kBlockSize  = FmhaPipeline::kBlockSize;
    static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
    static_assert(kBlockPerCu > 0);
    static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;

    using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
    using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
    using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
    using VLayout   = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;

    static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != RotaryEmbeddingEnum::NONE;
    static constexpr bool kIsPagedKV = FmhaPipeline::kIsPagedKV;

    static constexpr bool kPadSeqLenQ  = FmhaPipeline::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK  = FmhaPipeline::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float>           { static constexpr const char* name = "fp32"; };
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char* name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char* name = "bf16"; };
    template <> struct t2s<ck_tile::fp8_t>  { static constexpr const char* name = "fp8"; };
    template <> struct t2s<ck_tile::bf8_t>  { static constexpr const char* name = "bf8"; };
    // clang-format on

    __host__ static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        #define _SS_ std::string
        #define _TS_ std::to_string
        auto pn = [&] () {
            std::string n;
            if (kPadSeqLenQ) n += "s";
            if (kPadSeqLenK) n += "sk";
            if (kPadHeadDimQ) n += "d";
            if (kPadHeadDimV) n += "dv";
            return n.empty() ? n : std::string("p") + n; }();
        return
            _SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s<QDataType>::name) +
            "_" "b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" +
            _TS_(FmhaPipeline::kK0) + "x" + _TS_(FmhaPipeline::kN1) + "_" +
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
            "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") +
            (pn.empty() ? "" : "_" + pn) +
            (!kApplyRoPE ? _SS_("") : (_SS_("_") + RotaryEmbeddingEnumToStr<FmhaPipeline::RotaryEnum>::name)) +
            (kIsPagedKV ? "_pagedkv" : "");
        #undef _SS_
        #undef _TS_
        // clang-format on
    }
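For a concrete feel of the naming scheme, here is a sketch that re-assembles the same fragments with hypothetical pipeline constants (kK0 = 128, fp16, 64x64x128x128 tiles, default occupancy, row-major V, no padding, an assumed "rhalf" RoPE token, paged KV); the real kernel pulls all of these from FmhaPipeline:

// Illustration only: hypothetical constants stand in for FmhaPipeline members.
#include <iostream>
#include <string>

int main()
{
    const int  kK0 = 128, kM0 = 64, kN0 = 64, kN1 = 128; // assumed tile shape
    const bool row_major_v = true, paged_kv = true;
    const std::string dtype = "fp16", pad = "", rope = "_rhalf"; // assumed tokens

    std::string name = "fmha_fwd_appendkv_d" + std::to_string(kK0) + "_" + dtype +
                       "_b" + std::to_string(kM0) + "x" + std::to_string(kN0) + "x" +
                       std::to_string(kK0) + "x" + std::to_string(kN1) +
                       "_v" + (row_major_v ? "r" : "c") +
                       (pad.empty() ? "" : "_" + pad) + rope +
                       (paged_kv ? "_pagedkv" : "");
    std::cout << name << "\n"; // fmha_fwd_appendkv_d128_fp16_b64x64x128x128_vr_rhalf_pagedkv
}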
    template <ck_tile::index_t I> // to avoid the duplicated-base-class problem, introduce a template arg
    struct EmptyKargs
    {
    };

    // kargs use an aggregate initializer, so no constructor is provided;
    // inheritance is used to minimize the karg size.
    // users need to use the MakeKargs() function to create kargs.
    struct BasicKargs
    {
        void* q_ptr;
        void* k_ptr;
        const void* knew_ptr;
        void* v_ptr;
        const void* vnew_ptr;

        const int32_t* seqlen_k_ptr;
        ck_tile::index_t seqlen_q;
        ck_tile::index_t seqlen_k;
        ck_tile::index_t seqlen_knew;
        ck_tile::index_t hdim_q;
        ck_tile::index_t hdim_v;

        ck_tile::index_t num_head_q;
        // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k;
        // if it is larger than 1, it indicates an MQA/GQA case
        ck_tile::index_t nhead_ratio_qk;

        ck_tile::index_t stride_q;
        ck_tile::index_t stride_k;
        ck_tile::index_t stride_knew;
        ck_tile::index_t stride_v;
        ck_tile::index_t stride_vnew;

        ck_tile::index_t nhead_stride_q;
        ck_tile::index_t nhead_stride_k;
        ck_tile::index_t nhead_stride_knew;
        ck_tile::index_t nhead_stride_v;
        ck_tile::index_t nhead_stride_vnew;

        ck_tile::index_t batch_stride_q;
        ck_tile::index_t batch_stride_k;
        ck_tile::index_t batch_stride_knew;
        ck_tile::index_t batch_stride_v;
        ck_tile::index_t batch_stride_vnew;
    };

    struct RoPEKargs
    {
        const void* rotary_cos_ptr;
        const void* rotary_sin_ptr;
        ck_tile::index_t rotary_dim;
        bool has_mask;
    };

    struct PageBlockTableKargs
    {
        const int32_t* block_table_ptr;
        ck_tile::index_t batch_stride_block_table;
        ck_tile::index_t page_block_size;
    };

    struct CacheBatchIdxKargs
    {
        const int32_t* cache_batch_idx;
    };

    struct Kargs : BasicKargs,
                   std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>,
                   std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
    {
    };
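Kargs mixes in the RoPE and paged-KV argument blocks only when those features are compiled in; the disabled alternatives collapse to distinct empty bases, so they add no bytes to the kernel-argument struct. A standalone sketch of the pattern, with illustrative field sets:

// Sketch of the conditional-inheritance trick used by Kargs: disabled feature
// blocks become distinct empty bases and cost nothing (empty base optimization).
#include <cstdio>
#include <type_traits>

struct Basic { float* p; int n; };
struct RoPE  { const float* cos_p; const float* sin_p; int dim; };
template <int I> struct Empty {}; // distinct types avoid duplicated-base errors

template <bool WithRoPE>
struct Args : Basic, std::conditional_t<WithRoPE, RoPE, Empty<0>> {};

int main()
{
    std::printf("with RoPE: %zu bytes, without: %zu bytes\n",
                sizeof(Args<true>), sizeof(Args<false>)); // empty base adds nothing
}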
    __host__ static constexpr Kargs MakeKargs(void* q_ptr,
                                              void* k_ptr,
                                              const void* knew_ptr,
                                              void* v_ptr,
                                              const void* vnew_ptr,
                                              ck_tile::index_t seqlen_q,
                                              const void* seqlen_k_ptr,
                                              ck_tile::index_t seqlen_knew,
                                              ck_tile::index_t hdim_q,
                                              ck_tile::index_t hdim_v,
                                              ck_tile::index_t num_head_q,
                                              ck_tile::index_t nhead_ratio_qk,
                                              const void* rotary_cos_ptr,
                                              const void* rotary_sin_ptr,
                                              ck_tile::index_t rotary_dim,
                                              bool has_mask,
                                              const void* block_table_ptr,
                                              ck_tile::index_t batch_stride_block_table,
                                              ck_tile::index_t page_block_size,
                                              const void* cache_batch_idx,
                                              ck_tile::index_t stride_q,
                                              ck_tile::index_t stride_k,
                                              ck_tile::index_t stride_knew,
                                              ck_tile::index_t stride_v,
                                              ck_tile::index_t stride_vnew,
                                              ck_tile::index_t nhead_stride_q,
                                              ck_tile::index_t nhead_stride_k,
                                              ck_tile::index_t nhead_stride_knew,
                                              ck_tile::index_t nhead_stride_v,
                                              ck_tile::index_t nhead_stride_vnew,
                                              ck_tile::index_t batch_stride_q,
                                              ck_tile::index_t batch_stride_k,
                                              ck_tile::index_t batch_stride_knew,
                                              ck_tile::index_t batch_stride_v,
                                              ck_tile::index_t batch_stride_vnew)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
                     knew_ptr,
                     v_ptr,
                     vnew_ptr,
                     reinterpret_cast<const int32_t*>(seqlen_k_ptr),
                     seqlen_q,
                     -1, // seqlen_k will be updated by content of seqlen_k_ptr
                     seqlen_knew,
                     hdim_q,
                     hdim_v,
                     num_head_q,
                     nhead_ratio_qk,
                     stride_q,
                     stride_k,
                     stride_knew,
                     stride_v,
                     stride_vnew,
                     nhead_stride_q,
                     nhead_stride_k,
                     nhead_stride_knew,
                     nhead_stride_v,
                     nhead_stride_vnew,
                     batch_stride_q,
                     batch_stride_k,
                     batch_stride_knew,
                     batch_stride_v,
                     batch_stride_vnew}, // args for common karg
                    {},                  // placeholder for rope
                    {}};                 // placeholder for paged-block table or cache_batch_idx

        if constexpr(kApplyRoPE)
        {
            kargs.rotary_cos_ptr = rotary_cos_ptr;
            kargs.rotary_sin_ptr = rotary_sin_ptr;
            kargs.rotary_dim     = rotary_dim;
            kargs.has_mask       = has_mask;
        }

        if constexpr(kIsPagedKV)
        {
            kargs.block_table_ptr          = reinterpret_cast<const int32_t*>(block_table_ptr);
            kargs.batch_stride_block_table = batch_stride_block_table;
            kargs.page_block_size          = page_block_size;
        }
        else
        {
            kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
        }

        return kargs;
    }
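nhead_ratio_qk encodes the MQA/GQA head grouping: the kernel divides a query-head index by it to locate the shared KV head, as in the `i_nhead / kargs.nhead_ratio_qk` pointer arithmetic below. A sketch of the mapping, with illustrative head counts:

// Sketch of the MQA/GQA head mapping implied by nhead_ratio_qk: several query
// heads share one KV head.
#include <cstdio>

int main()
{
    const int nhead_q = 8, nhead_k = 2;
    const int nhead_ratio_qk = nhead_q / nhead_k; // 4 query heads per KV head

    for(int i_nhead = 0; i_nhead < nhead_q; ++i_nhead)
        std::printf("q head %d -> kv head %d\n", i_nhead, i_nhead / nhead_ratio_qk);
}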
    __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
                                            ck_tile::index_t nhead,
                                            ck_tile::index_t seqlen_q,
                                            ck_tile::index_t seqlen_knew)
    {
        return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, seqlen_knew);
    }

    __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // divide the problem
        const auto [i_tile, i_nhead, i_batch] = TilePartitioner{}();

        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0);
        const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0);

        const index_t i_cache_batch = [&, i_batch_ = i_batch] {
            if constexpr(kIsPagedKV)
            {
                return i_batch_;
            }
            else
            {
                return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
                                                         : i_batch_);
            }
        }();

        const long_index_t batch_offset_q =
            static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
        const long_index_t batch_offset_k =
            static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
        const long_index_t batch_offset_knew =
            static_cast<long_index_t>(i_batch) * kargs.batch_stride_knew;
        const long_index_t batch_offset_v =
            static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
        const long_index_t batch_offset_vnew =
            static_cast<long_index_t>(i_batch) * kargs.batch_stride_vnew;

        kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];

        // for simplicity, we fold the batch stride into the pointers
        QDataType* q_ptr = reinterpret_cast<QDataType*>(kargs.q_ptr) +
                           static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
                           batch_offset_q;
        KDataType* k_ptr =
            reinterpret_cast<KDataType*>(kargs.k_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
            batch_offset_k;
        const KDataType* knew_ptr =
            reinterpret_cast<const KDataType*>(kargs.knew_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_knew +
            batch_offset_knew;
        VDataType* v_ptr =
            reinterpret_cast<VDataType*>(kargs.v_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
            batch_offset_v;
        const VDataType* vnew_ptr =
            reinterpret_cast<const VDataType*>(kargs.vnew_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_vnew +
            batch_offset_vnew;

        // Q/K/V DRAM and DRAM windows
        const auto q_dram = [&]() {
            const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                q_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_q),
                make_tuple(kargs.stride_q, 1),
                number<FmhaPipeline::kAlignmentQ>{},
                number<1>{});
            return pad_tensor_view(
                q_dram_naive,
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
                sequence<kPadSeqLenQ, kPadHeadDimQ>{});
        }();
        const auto make_k_dram = [&](KDataType* data, index_t height) {
            const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                data, // will update this pointer if using paged-kvcache
                make_tuple(height, kargs.hdim_q),
                make_tuple(kargs.stride_k, 1),
                number<FmhaPipeline::kAlignmentK>{},
                number<1>{});
            return pad_tensor_view(
                k_dram_naive,
                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
                sequence<kPadSeqLenK, kPadHeadDimQ>{});
        };
        const auto k_dram = [&]() {
            if constexpr(kIsPagedKV)
            {
                return make_k_dram(nullptr, kargs.page_block_size);
            }
            else
            {
                return make_k_dram(k_ptr, kargs.seqlen_k + kargs.seqlen_knew);
            }
        }();
        const auto knew_dram = [&]() {
            const auto knew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                knew_ptr,
                make_tuple(kargs.seqlen_knew, kargs.hdim_q),
                make_tuple(kargs.stride_knew, 1),
                number<FmhaPipeline::kAlignmentK>{},
                number<1>{});
            return pad_tensor_view(
                knew_dram_naive,
                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
                sequence<kPadSeqLenK, kPadHeadDimQ>{});
        }();
        const auto make_v_dram = [&](VDataType* data, index_t length) {
            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            {
                const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    data, // will update this pointer if using paged-kvcache
                    make_tuple(length, kargs.hdim_v),
                    make_tuple(kargs.stride_v, 1),
                    number<FmhaPipeline::kAlignmentV>{},
                    number<1>{});

                const auto v_dram_transposed = transform_tensor_view(
                    v_dram_naive,
                    make_tuple(make_pass_through_transform(kargs.hdim_v),
                               make_pass_through_transform(length)),
                    make_tuple(sequence<1>{}, sequence<0>{}),
                    make_tuple(sequence<0>{}, sequence<1>{}));

                return pad_tensor_view(
                    v_dram_transposed,
                    make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
                    sequence<kPadHeadDimV, kPadSeqLenK>{});
            }
            else
            {
                const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    data, // will update this pointer if using paged-kvcache
                    make_tuple(kargs.hdim_v, length),
                    make_tuple(kargs.stride_v, 1),
                    number<FmhaPipeline::kAlignmentV>{},
                    number<1>{});
                return pad_tensor_view(
                    v_dram_naive,
                    make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
                    sequence<kPadHeadDimV, kPadSeqLenK>{});
            }
        };
        const auto v_dram = [&]() {
            if constexpr(kIsPagedKV)
            {
                return make_v_dram(nullptr, kargs.page_block_size);
            }
            else
            {
                return make_v_dram(v_ptr, kargs.seqlen_k + kargs.seqlen_knew);
            }
        }();
        const auto vnew_dram = [&]() {
            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            {
                const auto vnew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    vnew_ptr,
                    make_tuple(kargs.seqlen_knew, kargs.hdim_v),
                    make_tuple(kargs.stride_vnew, 1),
                    number<FmhaPipeline::kAlignmentV>{},
                    number<1>{});
                const auto vnew_dram_transposed = transform_tensor_view(
                    vnew_dram_naive,
                    make_tuple(make_pass_through_transform(kargs.hdim_v),
                               make_pass_through_transform(kargs.seqlen_knew)),
                    make_tuple(sequence<1>{}, sequence<0>{}),
                    make_tuple(sequence<0>{}, sequence<1>{}));
                return pad_tensor_view(
                    vnew_dram_transposed,
                    make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
                    sequence<kPadHeadDimV, kPadSeqLenK>{});
            }
            else
            {
                const auto vnew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    vnew_ptr,
                    make_tuple(kargs.hdim_v, kargs.seqlen_knew),
                    make_tuple(kargs.stride_vnew, 1),
                    number<FmhaPipeline::kAlignmentV>{},
                    number<1>{});
                return pad_tensor_view(
                    vnew_dram_naive,
                    make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
                    sequence<kPadHeadDimV, kPadSeqLenK>{});
            }
        }();

        constexpr auto q_rotary_cos_sin_dram_window_lengths =
            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0 / 2>{});
        const auto q_rotary_cos_dram_window = [&]() {
            if constexpr(kApplyRoPE)
            {
                const auto rotary_cos_dram_native =
                    make_naive_tensor_view<address_space_enum::global>(
                        reinterpret_cast<const QDataType*>(kargs.rotary_cos_ptr) +
                            kargs.seqlen_k * (kargs.rotary_dim / 2),
                        make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
                        make_tuple(kargs.has_mask * (kargs.rotary_dim / 2), 1),
                        number<8>{},
                        number<1>{});

                const auto rotary_cos_dram = [&]() {
                    return pad_tensor_view(rotary_cos_dram_native,
                                           q_rotary_cos_sin_dram_window_lengths,
                                           sequence<kPadSeqLenQ, kPadHeadDimQ>{});
                }();

                return make_tile_window(
                    rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
            }
            else
            {
                return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths);
            }
        }();
        const auto q_rotary_sin_dram_window = [&]() {
            if constexpr(kApplyRoPE)
            {
                const auto rotary_sin_dram_native =
                    make_naive_tensor_view<address_space_enum::global>(
                        reinterpret_cast<const QDataType*>(kargs.rotary_sin_ptr) +
                            kargs.seqlen_k * (kargs.rotary_dim / 2),
                        make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
                        make_tuple(kargs.has_mask * (kargs.rotary_dim / 2), 1),
                        number<8>{},
                        number<1>{});

                const auto rotary_sin_dram = [&]() {
                    return pad_tensor_view(rotary_sin_dram_native,
                                           q_rotary_cos_sin_dram_window_lengths,
                                           sequence<kPadSeqLenQ, kPadHeadDimQ>{});
                }();

                return make_tile_window(
                    rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
            }
            else
            {
                return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths);
            }
        }();

        constexpr auto knew_rotary_cos_sin_dram_window_lengths =
            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0 / 2>{});
        const auto knew_rotary_cos_dram_window = [&]() {
            if constexpr(kApplyRoPE)
            {
                const auto rotary_cos_dram_native =
                    make_naive_tensor_view<address_space_enum::global>(
                        reinterpret_cast<const KDataType*>(kargs.rotary_cos_ptr) +
                            kargs.seqlen_k * (kargs.rotary_dim / 2),
                        make_tuple(kargs.seqlen_knew, kargs.rotary_dim / 2),
                        make_tuple(kargs.rotary_dim / 2, 1),
                        number<8>{},
                        number<1>{});

                const auto rotary_cos_dram = [&]() {
                    return pad_tensor_view(rotary_cos_dram_native,
                                           knew_rotary_cos_sin_dram_window_lengths,
                                           sequence<kPadSeqLenK, kPadHeadDimQ>{});
                }();

                return make_tile_window(
                    rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
            }
            else
            {
                return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths);
            }
        }();
        const auto knew_rotary_sin_dram_window = [&]() {
            if constexpr(kApplyRoPE)
            {
                const auto rotary_sin_dram_native =
                    make_naive_tensor_view<address_space_enum::global>(
                        reinterpret_cast<const KDataType*>(kargs.rotary_sin_ptr) +
                            kargs.seqlen_k * (kargs.rotary_dim / 2),
                        make_tuple(kargs.seqlen_knew, kargs.rotary_dim / 2),
                        make_tuple(kargs.rotary_dim / 2, 1),
                        number<8>{},
                        number<1>{});

                const auto rotary_sin_dram = [&]() {
                    return pad_tensor_view(rotary_sin_dram_native,
                                           knew_rotary_cos_sin_dram_window_lengths,
                                           sequence<kPadSeqLenK, kPadHeadDimQ>{});
                }();

                return make_tile_window(
                    rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
            }
            else
            {
                return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths);
            }
        }();
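The Q-side rotary windows above use `kargs.has_mask * (kargs.rotary_dim / 2)` as the row stride after advancing the table pointer to row seqlen_k. Read literally, a zero stride (has_mask == false) broadcasts a single cos/sin row to every query row, while a nonzero stride walks one table row per query row. A sketch of that addressing, with illustrative sizes:

// Sketch of the zero-stride broadcast trick in the Q rotary cos/sin views.
#include <cstdio>

int main()
{
    const int rotary_dim = 8, half = rotary_dim / 2;
    const int seqlen_k = 5, seqlen_q = 3;
    float table[16 * 4]; // pretend rotary table: (position, rotary_dim / 2)
    for(int p = 0; p < 16; ++p)
        for(int d = 0; d < half; ++d)
            table[p * half + d] = static_cast<float>(p);

    for(int has_mask = 0; has_mask <= 1; ++has_mask)
    {
        const float* base   = table + seqlen_k * half; // advance to the append position
        const int    stride = has_mask * half;         // 0 => broadcast one row
        std::printf("has_mask=%d positions used:", has_mask);
        for(int m = 0; m < seqlen_q; ++m)
            std::printf(" %g", base[m * stride]);      // position used by query row m
        std::printf("\n");
    }
}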
        auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
            if constexpr(kIsPagedKV)
            {
                const auto* block_indices =
                    reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
                    i_batch_ * kargs.batch_stride_block_table;
                const index_t num_blocks = integer_divide_ceil(
                    kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size);

                const long_index_t fixed_offset =
                    static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
                    kargs.nhead_stride_k;

                return make_page_block_navigator<KDataType, 0>(
                    kargs.k_ptr,
                    kargs.batch_stride_k,
                    fixed_offset,
                    block_indices,
                    num_blocks,
                    kargs.page_block_size,
                    k_dram,
                    make_k_dram(nullptr,
                                (kargs.seqlen_k + kargs.seqlen_knew) -
                                    (num_blocks - 1) * kargs.page_block_size));
            }
            else
            {
                return make_page_block_navigator(k_dram);
            }
        }();
        auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
            if constexpr(kIsPagedKV)
            {
                const auto* block_indices =
                    reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
                    i_batch_ * kargs.batch_stride_block_table;
                const index_t num_blocks = integer_divide_ceil(
                    kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size);

                const long_index_t fixed_offset =
                    static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
                    kargs.nhead_stride_v;

                return make_page_block_navigator<VDataType, 1>(
                    kargs.v_ptr,
                    kargs.batch_stride_v,
                    fixed_offset,
                    block_indices,
                    num_blocks,
                    kargs.page_block_size,
                    v_dram,
                    make_v_dram(nullptr,
                                (kargs.seqlen_k + kargs.seqlen_knew) -
                                    (num_blocks - 1) * kargs.page_block_size));
            }
            else
            {
                return make_page_block_navigator(v_dram);
            }
        }();
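The navigators above resolve a logical K/V row into a physical page: the row index splits into (block, in-block offset), and the block is looked up in the per-batch block table; the trailing block may be partial, which is why a second, shorter DRAM view is passed for it. A host-side sketch of the index math with a made-up block table:

// Sketch of paged-KV addressing: logical position -> (physical block, row).
#include <cstdio>

int main()
{
    const int page_block_size = 4;
    const int seqlen_total    = 10;        // seqlen_k + seqlen_knew
    const int block_table[]   = {7, 2, 9}; // physical block ids for this batch
    const int num_blocks      = (seqlen_total + page_block_size - 1) / page_block_size;

    for(int pos = 0; pos < seqlen_total; ++pos)
    {
        const int i_block = pos / page_block_size; // index into block_table
        const int in_blk  = pos % page_block_size; // row inside the page
        std::printf("pos %d -> physical block %d, row %d (of %d blocks)\n",
                    pos, block_table[i_block], in_blk, num_blocks);
    }
}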
        auto q_dram_window = make_tile_window(
            q_dram,
            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
            {i_m0, 0});

        const bool skip_append_kv = kargs.seqlen_knew <= i_n0;

        // window origin = (0, 0) if there is no work to do for the current block
        auto [i_page_block_k, k_dram_window] = k_page_block_navigator.make_tile_window(
            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
            {!skip_append_kv * (kargs.seqlen_k + i_n0), 0});

        auto knew_dram_window = make_tile_window(
            knew_dram,
            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
            {i_n0, 0});

        // window origin = (0, 0) if there is no work to do for the current block
        auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
            make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
            {0, !skip_append_kv * (kargs.seqlen_k + i_n0)});

        auto vnew_dram_window = make_tile_window(
            vnew_dram,
            make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
            {0, i_n0});

        if constexpr(kApplyRoPE)
        {
            FmhaPipeline{}(q_dram_window,
                           k_dram_window,
                           i_page_block_k,
                           k_page_block_navigator,
                           knew_dram_window,
                           v_dram_window,
                           i_page_block_v,
                           v_page_block_navigator,
                           vnew_dram_window,
                           q_rotary_cos_dram_window,
                           q_rotary_sin_dram_window,
                           knew_rotary_cos_dram_window,
                           knew_rotary_sin_dram_window,
                           kargs.rotary_dim,
                           kargs.seqlen_q <= i_m0,
                           skip_append_kv);
        }
        else
        {
            FmhaPipeline{}(q_dram_window,
                           k_dram_window,
                           i_page_block_k,
                           k_page_block_navigator,
                           knew_dram_window,
                           v_dram_window,
                           i_page_block_v,
                           v_page_block_navigator,
                           vnew_dram_window,
                           q_rotary_cos_dram_window,
                           q_rotary_sin_dram_window,
                           knew_rotary_cos_dram_window,
                           knew_rotary_sin_dram_window,
                           0, // rotary_dim not used
                           kargs.seqlen_q <= i_m0,
                           skip_append_kv);
        }
    }
};

} // namespace ck_tile
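One detail worth calling out in the window setup near the end of operator(): a block whose K/V tile starts at or beyond seqlen_knew has no append work, and the `!skip_append_kv * (kargs.seqlen_k + i_n0)` expression parks its window origin at 0 rather than pointing past the cache. A sketch with illustrative sizes:

// Sketch of the skip_append_kv window-origin guard.
#include <cstdio>

int main()
{
    const int kN0 = 64, seqlen_k = 100, seqlen_knew = 50;
    for(int i_tile = 0; i_tile < 2; ++i_tile)
    {
        const int  i_n0           = i_tile * kN0;
        const bool skip_append_kv = seqlen_knew <= i_n0;
        const int  window_origin  = !skip_append_kv * (seqlen_k + i_n0);
        std::printf("tile %d: i_n0=%d skip=%d origin=%d\n",
                    i_tile, i_n0, skip_append_kv, window_origin); // tile 1 collapses to 0
    }
}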
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp
0 → 100644 (View file @ 72c9f129)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <index_t kM0_, index_t kN0_, index_t kK0_, index_t kN1_>
struct FmhaFwdAppendKVTilePartitioner
{
    static constexpr ck_tile::index_t kM0 = kM0_;
    static constexpr ck_tile::index_t kN0 = kN0_;
    static constexpr ck_tile::index_t kK0 = kK0_;
    static constexpr ck_tile::index_t kN1 = kN1_;

    static_assert(kK0 == kN1);

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
                                                ck_tile::index_t nhead,
                                                ck_tile::index_t seqlen_q,
                                                ck_tile::index_t seqlen_knew)
    {
        // TODO: this may need tuning
        return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kM0),
                             ck_tile::integer_divide_ceil(seqlen_knew, kN0)),
                    nhead,
                    batch_size);
    }

    CK_TILE_DEVICE auto operator()()
    {
        const index_t i_tile  = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;
        return ck_tile::make_tuple(i_tile, i_nhead, i_batch);
    }
};

} // namespace ck_tile
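The max() in GridSize reflects that one grid serves two jobs: RoPE on Q tiles (kM0) and appending Knew/Vnew tiles (kN0); a block out of range for one job skips it via the `seqlen_q <= i_m0` and skip_append_kv guards in the kernel. A sketch with illustrative shapes:

// Sketch: grid.x covers whichever job needs more tiles.
#include <algorithm>
#include <cstdio>

int main()
{
    const int kM0 = 64, kN0 = 64;
    const int seqlen_q = 1, seqlen_knew = 130;          // decode with a long append
    const int tiles_q  = (seqlen_q + kM0 - 1) / kM0;    // 1
    const int tiles_kn = (seqlen_knew + kN0 - 1) / kN0; // 3
    std::printf("grid.x = max(%d, %d) = %d\n", tiles_q, tiles_kn,
                std::max(tiles_q, tiles_kn));
}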
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @ 72c9f129
...
@@ -86,7 +86,7 @@ struct FmhaFwdKernel
            "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" +
            _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
            "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") +
            (pn.empty() ? "" : "_" + pn) +
            (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
            (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "") +
            (kHasDropout ? "_dropout" : "") + (kDoFp8StaticQuant ? "_squant" : "");
        #undef _SS_
        #undef _TS_
...
@@ -387,7 +387,6 @@ struct FmhaFwdKernel
                                          ck_tile::index_t nhead_stride_randval,
                                          ck_tile::index_t nhead_stride_lse,
                                          ck_tile::index_t nhead_stride_o,
-                                         ck_tile::index_t batch_stride_lse,
                                          ck_tile::index_t window_size_left,
                                          ck_tile::index_t window_size_right,
                                          ck_tile::index_t mask_type,
...
@@ -448,7 +447,6 @@ struct FmhaFwdKernel
        {
            kargs.lse_ptr          = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
-           kargs.batch_stride_lse = batch_stride_lse;
        }
        if constexpr(kDoFp8StaticQuant)
        {
...
@@ -524,7 +522,7 @@ struct FmhaFwdKernel
        }
        if constexpr(kStoreLSE)
        {
-           batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
+           batch_offset_lse = query_start;
        }
        if constexpr(kHasDropout)
        {
...
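The hunks above drop batch_stride_lse from this MakeKargs overload and compute the group-mode LSE offset as query_start, which reads as LSE being addressed as one packed (nhead, total query length) buffer rather than via a per-batch stride. A sketch of the new offset rule under that reading (seqstart values are made up):

// Sketch of the group-mode LSE offset after this change.
#include <cstdio>

int main()
{
    const int seqstart_q[] = {0, 3, 8, 12}; // cumulative query starts per group
    for(int i_batch = 0; i_batch < 3; ++i_batch)
    {
        const long query_start      = seqstart_q[i_batch];
        const long batch_offset_lse = query_start; // replaces i_batch * batch_stride_lse
        std::printf("group %d: batch_offset_lse=%ld\n", i_batch, batch_offset_lse);
    }
}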
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
View file @ 72c9f129
...
@@ -55,7 +55,7 @@ struct FmhaFwdSplitKVCombineKernel
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
            _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) +
            (kStoreLSE ? "_lse" : "") + (kDoFp8StaticQuant ? "_squant" : "");
        #undef _SS_
        #undef _TS_
...
@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel
        ck_tile::index_t nhead_stride_o_acc;
        ck_tile::index_t nhead_stride_o;
-       ck_tile::index_t batch_stride_lse_acc;
        ck_tile::index_t batch_stride_o_acc;
        ck_tile::index_t split_stride_lse_acc;
...
@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel
                      std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
    {
        ck_tile::index_t batch_stride_o;
+       ck_tile::index_t batch_stride_lse_acc;
    };

    struct GroupModeKargs
...
@@ -166,13 +166,13 @@ struct FmhaFwdSplitKVCombineKernel
                    nhead_stride_lse_acc,
                    nhead_stride_o_acc,
                    nhead_stride_o,
-                   batch_stride_lse_acc,
                    batch_stride_o_acc,
                    split_stride_lse_acc,
                    split_stride_o_acc}, // args for common karg
                   {}, // placeholder for lse
                   {}, // placeholder for fp8_static_quant args
-                  batch_stride_o};
+                  batch_stride_o,
+                  batch_stride_lse_acc};
        if constexpr(kStoreLSE)
        {
...
@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel
                                          ck_tile::index_t nhead_stride_o_acc,
                                          ck_tile::index_t nhead_stride_lse,
                                          ck_tile::index_t nhead_stride_o,
-                                         ck_tile::index_t batch_stride_lse_acc,
                                          ck_tile::index_t batch_stride_o_acc,
-                                         ck_tile::index_t batch_stride_lse,
                                          ck_tile::index_t split_stride_lse_acc,
                                          ck_tile::index_t split_stride_o_acc)
...
@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel
                    nhead_stride_lse_acc,
                    nhead_stride_o_acc,
                    nhead_stride_o,
-                   batch_stride_lse_acc,
                    batch_stride_o_acc,
                    split_stride_lse_acc,
                    split_stride_o_acc}, // args for common karg
...
@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel
        {
            kargs.lse_ptr          = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
-           kargs.batch_stride_lse = batch_stride_lse;
        }
        if constexpr(kDoFp8StaticQuant)
        {
...
@@ -274,24 +270,25 @@ struct FmhaFwdSplitKVCombineKernel
        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
        const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);

-       const long_index_t batch_offset_lse_acc =
-           static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
        const long_index_t batch_offset_o_acc =
            static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
+       long_index_t batch_offset_lse_acc = 0;
        long_index_t batch_offset_lse = 0;
        long_index_t batch_offset_o   = 0;
-       if constexpr(kStoreLSE)
-       {
-           batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
-       }
        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
            batch_offset_o       = query_start * kargs.row_stride_o;
+           batch_offset_lse_acc = query_start;
+           if constexpr(kStoreLSE)
+           {
+               batch_offset_lse = query_start;
+           }
            // get real # queries & # keys under group mode
            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
...
@@ -306,7 +303,13 @@ struct FmhaFwdSplitKVCombineKernel
        }
        else
        {
            batch_offset_o       = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
+           batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
+           if constexpr(kStoreLSE)
+           {
+               batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
+           }
        }

        // for simplicity, batch stride we just modify the pointer
...