gaoqiong / composable_kernel_ROCM · Commits · 4947639c

Commit 4947639c, authored Jun 19, 2024 by Jun Liu
Merge branch 'amd-develop' into amd-master
Parents: 17cf8179, d39c3f5d

Changes: 150 changed files in the merge; showing 20 changed files with 2401 additions and 173 deletions (+2401 −173).
Files shown in this excerpt:

include/ck_tile/core/tensor/buffer_view.hpp                      +7    −6
include/ck_tile/core/tensor/store_tile.hpp                       +1    −1
include/ck_tile/core/tensor/tensor_view.hpp                      +26   −4
include/ck_tile/core/tensor/tile_distribution.hpp                +1    −0
include/ck_tile/core/tensor/tile_window.hpp                      +60   −0
include/ck_tile/core/tensor/update_tile.hpp                      +55   −0
include/ck_tile/core/utility/philox_rand.hpp                     +89   −0
include/ck_tile/host.hpp                                         +2    −0
include/ck_tile/host/device_memory.hpp                           +41   −18
include/ck_tile/host/host_tensor.hpp                             +32   −4
include/ck_tile/host/kernel_launch.hpp                           +71   −131
include/ck_tile/host/reference/reference_batched_dropout.hpp     +33   −0
include/ck_tile/host/stream_config.hpp                           +17   −0
include/ck_tile/host/timer.hpp                                   +79   −0
include/ck_tile/ops/fmha.hpp                                     +14   −0
include/ck_tile/ops/fmha/block/block_dropout.hpp                 +329  −0
include/ck_tile/ops/fmha/block/block_masking.hpp                 +66   −6
include/ck_tile/ops/fmha/block/block_position_encoding.hpp       +3    −3
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp              +1421 −0
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp    +54   −0
include/ck_tile/core/tensor/buffer_view.hpp

@@ -6,6 +6,7 @@
 #include "ck_tile/core/config.hpp"
 #include "ck_tile/core/arch/arch.hpp"
 #include "ck_tile/core/arch/amd_buffer_addressing.hpp"
+#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
 #include "ck_tile/core/container/array.hpp"
 #include "ck_tile/core/numeric/integer.hpp"
 #include "ck_tile/core/numeric/integral_constant.hpp"

@@ -507,10 +508,10 @@ struct buffer_view<address_space_enum::global,
         constexpr bool use_amd_buffer_addressing = false;
 #endif
+        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

         if constexpr(use_amd_buffer_addressing)
         {
-            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

             amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
                 x, p_data_, i, is_valid_element, buffer_size_);
         }

@@ -518,7 +519,7 @@ struct buffer_view<address_space_enum::global,
         {
             if(is_valid_element)
             {
-                atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
+                atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
             }
         }
     }

@@ -547,16 +548,16 @@ struct buffer_view<address_space_enum::global,
         constexpr bool use_amd_buffer_addressing = false;
 #endif
+        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

         if constexpr(use_amd_buffer_addressing)
         {
-            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

             amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
                 x, p_data_, i, is_valid_element, buffer_size_);
         }
         else if(is_valid_element)
         {
-            atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
+            atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
         }
     }
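The non-buffer fallback now goes through the "_g" helpers from the newly included generic_memory_space_atomic.hpp, operating on t_per_x scalar elements instead of one casted vector. As a rough, hypothetical illustration of that shape of helper (not the ck_tile implementation, just a sketch of an element-wise global atomic add in HIP):

// hypothetical illustration only, not the ck_tile code
template <typename T, int t_per_x>
__device__ void atomic_add_elementwise(T* dst, const T* src)
{
#pragma unroll
    for(int k = 0; k < t_per_x; ++k)
    {
        atomicAdd(dst + k, src[k]); // HIP built-in atomic, valid for int/unsigned/float/double
    }
}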
include/ck_tile/core/tensor/store_tile.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
include/ck_tile/core/tensor/tensor_view.hpp

@@ -16,7 +16,9 @@
 namespace ck_tile {

-template <typename BufferView_, typename TensorDesc_>
+template <typename BufferView_,
+          typename TensorDesc_,
+          memory_operation_enum DstInMemOp_ = memory_operation_enum::set>
 struct tensor_view
 {
     using buffer_view = remove_reference_t<BufferView_>;

@@ -24,6 +26,7 @@ struct tensor_view
     using TensorDesc  = remove_cvref_t<TensorDesc_>;
     using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
     using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
+    static constexpr auto DstInMemOp = DstInMemOp_;

     CK_TILE_HOST_DEVICE constexpr tensor_view() = default;

@@ -140,6 +143,23 @@ struct tensor_view
             x);
     }

+    // X is vector of DataType.
+    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
+    template <typename X,
+              bool oob_conditional_check = true,
+              typename std::enable_if<
+                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                  bool>::type = false>
+    CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements(
+        const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
+    {
+        buf_.template update<DstInMemOp, X, oob_conditional_check>(
+            coord.get_offset(),
+            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
+            x);
+    }
+
     CK_TILE_HOST_DEVICE void print() const
     {
         printf("tensor_view{");

@@ -178,6 +198,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
 }

 template <address_space_enum BufferAddressSpace = address_space_enum::generic,
+          memory_operation_enum DstInMemOp      = memory_operation_enum::set,
           typename DataType,
           typename... Lengths,
           typename... Strides,

@@ -198,7 +219,7 @@ make_naive_tensor_view(DataType* p,
     auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());

-    return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
+    return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
 }

 template <address_space_enum BufferAddressSpace = address_space_enum::generic,

@@ -232,8 +253,9 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& ol
         NewLowerDimensionOldVisibleIdss{},
         NewUpperDimensionNewVisibleIdss{});

-    return tensor_view<typename OldTensorView::buffer_view, remove_cvref_t<decltype(new_desc)>>{
-        old_tensor_view.buf_, new_desc};
+    return tensor_view<typename OldTensorView::buffer_view,
+                       remove_cvref_t<decltype(new_desc)>,
+                       remove_cvref_t<OldTensorView>::DstInMemOp>{old_tensor_view.buf_, new_desc};
 }

 template <typename TensorView,
include/ck_tile/core/tensor/tile_distribution.hpp

@@ -9,6 +9,7 @@
 #include "ck_tile/core/container/sequence.hpp"
 #include "ck_tile/core/container/tuple.hpp"
 #include "ck_tile/core/container/container_helper.hpp"
+#include "ck_tile/core/container/meta_data_buffer.hpp"
 #include "ck_tile/core/tensor/tensor_adaptor.hpp"
 #include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
 #include "ck_tile/core/utility/functional.hpp"
include/ck_tile/core/tensor/tile_window.hpp

@@ -594,6 +594,66 @@ struct tile_window_with_static_distribution
         });
     }

+    template <bool oob_conditional_check = true>
+    CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
+                               bool_constant<oob_conditional_check> = {}) const
+    {
+        using Traits = load_store_traits;
+
+        using vector_t = typename Traits::vector_t;
+        using SFC_Ys   = typename Traits::SFC_Ys;
+
+        constexpr auto tile_dstr = TileDstr{};
+
+        // loop over thread tensor space [y0, y1, ...]
+        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
+            /// TODO: use structure binding (to be captured later) if compiled in C++20
+            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
+            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];
+
+            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
+                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+
+                // data index [y0, y1, ...]
+                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
+
+                // read from distributed tensor
+                vector_t vec_value;
+
+                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
+                    constexpr auto idx_ys = generate_array(
+                        [&](auto jj) {
+                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
+                                                            : idx_ys_start[jj];
+                        },
+                        number<NDimY>{});
+
+                    constexpr index_t d =
+                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+
+                    vec_value.template get_as<DataType>()(j) =
+                        dstr_tensor.get_thread_buffer().template at<d>();
+                });
+
+                // write into bottom tensor
+                get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
+                    bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
+
+                // move thread coordinate
+                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
+                {
+                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
+
+                    constexpr auto idx_diff_ps_ys =
+                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
+
+                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+                }
+            });
+        });
+    }
+
     // move thread's botom tensor coordiante
     // [x0', x1', ... ] ==> [offset]
     // also move window-origin
include/ck_tile/core/tensor/update_tile.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename DataType_>
CK_TILE_DEVICE void
update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
            const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
    using TileDstr = remove_cvref_t<TileDistribution_>;

    static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");

    constexpr auto tile_dstr = TileDstr{};

    auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
                                        tile_window_tmp.get_window_lengths(),
                                        tile_window_tmp.get_window_origin(),
                                        tile_dstr);

    tile_window.update(dstr_tensor);
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          typename DataType_>
CK_TILE_DEVICE void
update_tile(tile_window_with_static_distribution<BottomTensorView_,
                                                 WindowLengths_,
                                                 TileDistribution_,
                                                 NumCoord>& tile_window,
            const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    tile_window.update(dstr_tensor);
}

} // namespace ck_tile
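A minimal device-side usage sketch (the window and tensor variable names below are assumed, not from this commit): update_tile() is the free-function counterpart of store_tile(), routing the per-thread values through tensor_view::update_vectorized_elements so the destination view's DstInMemOp (e.g. atomic_add) is honored instead of a plain store.

// inside a kernel, with a destination view created via
// make_naive_tensor_view<..., memory_operation_enum::atomic_add>(...)
auto dst_window = make_tile_window(dst_view, window_lengths, window_origin); // assumed inputs
update_tile(dst_window, acc_tensor); // acc_tensor: static_distributed_tensor<DataType, Dstr>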
include/ck_tile/core/utility/philox_rand.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"

namespace ck_tile {

// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh
class philox
{
    public:
    CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_)
        : seed(reinterpret_cast<const uint2&>(seed_))
    {
        ull2* tmp = reinterpret_cast<ull2*>(&counter);
        tmp->x    = offset_;
    }

    CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const
    {
        uint4 counter_ = counter;
        ull2* tmp      = reinterpret_cast<ull2*>(&counter_);
        tmp->y         = subsequence;

        uint2 key_ = seed;
        // 7-round philox
#pragma unroll
        for(int i = 0; i < 6; i++)
        {
            counter_ = philox_single_round(counter_, key_);
            key_.x += kPhilox10A;
            key_.y += kPhilox10B;
        }
        uint4 output = philox_single_round(counter_, key_);
        return output;
    }

    CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t* out,
                                             const unsigned long long subsequence) const
    {
        uint4 tmp_ph;
        tmp_ph = get_philox_4x32(subsequence);

        uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);

        out_tmp[0] = tmp_ph.x;
        out_tmp[1] = tmp_ph.y;
        out_tmp[2] = tmp_ph.z;
        out_tmp[3] = tmp_ph.w;
    }

    private:
    struct ull2
    {
        uint64_t x;
        uint64_t y;
    };

    uint4 counter;
    const uint2 seed;

    CK_TILE_HOST_DEVICE uint2 mulhilo32(const unsigned int a, const unsigned int b) const
    {
        uint2* res;
        unsigned long long tmp;
        tmp = static_cast<unsigned long long>(a) * b;
        res = reinterpret_cast<uint2*>(&tmp);
        return *res;
    }

    CK_TILE_HOST_DEVICE uint4 philox_single_round(const uint4 ctr, const uint2 key) const
    {
        uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
        uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
        uint4 ret  = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
        return ret;
    }

    static const unsigned long kPhilox10A = 0x9E3779B9;
    static const unsigned long kPhilox10B = 0xBB67AE85;
    static const unsigned long kPhiloxSA  = 0xD2511F53;
    static const unsigned long kPhiloxSB  = 0xCD9E8D57;
};

} // namespace ck_tile
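The generator is counter-based: every call derives its 16 output bytes purely from (seed, offset, subsequence), so threads can reproduce each other's streams. A minimal usage sketch (seed/offset/subsequence values are placeholders; uint2/uint4 come from the HIP runtime):

ck_tile::philox ph(/*seed=*/0x1234ULL, /*offset=*/0ULL);
uint8_t rnd[16];
ph.get_random_16x8(rnd, /*subsequence=*/42ULL); // fills 16 random bytes for this subsequence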
include/ck_tile/host.hpp

@@ -11,6 +11,7 @@
 #include "ck_tile/host/host_tensor.hpp"
 #include "ck_tile/host/kernel_launch.hpp"
 #include "ck_tile/host/ranges.hpp"
+#include "ck_tile/host/reference/reference_batched_dropout.hpp"
 #include "ck_tile/host/reference/reference_batched_elementwise.hpp"
 #include "ck_tile/host/reference/reference_batched_gemm.hpp"
 #include "ck_tile/host/reference/reference_batched_masking.hpp"

@@ -20,3 +21,4 @@
 #include "ck_tile/host/reference/reference_reduce.hpp"
 #include "ck_tile/host/reference/reference_softmax.hpp"
 #include "ck_tile/host/stream_config.hpp"
+#include "ck_tile/host/timer.hpp"
include/ck_tile/host/device_memory.hpp

@@ -27,7 +27,14 @@ struct DeviceMem
     DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
     DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
     {
-        HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+        if(mMemSize != 0)
+        {
+            HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+        }
+        else
+        {
+            mpDeviceBuf = nullptr;
+        }
     }
     void Realloc(std::size_t mem_size)
     {

@@ -36,7 +43,14 @@ struct DeviceMem
             HIP_CHECK_ERROR(hipFree(mpDeviceBuf));
         }
         mMemSize = mem_size;
-        HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+        if(mMemSize != 0)
+        {
+            HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+        }
+        else
+        {
+            mpDeviceBuf = nullptr;
+        }
     }
     void* GetDeviceBuffer() const { return mpDeviceBuf; }
     std::size_t GetBufferSize() const { return mMemSize; }

@@ -47,15 +61,18 @@ struct DeviceMem
             HIP_CHECK_ERROR(
                 hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
         }
-        else
-        {
-            throw std::runtime_error("ToDevice with an empty pointer");
-        }
+        // else
+        // {
+        //     throw std::runtime_error("ToDevice with an empty pointer");
+        // }
     }
     void ToDevice(const void* p, const std::size_t cpySize) const
     {
-        HIP_CHECK_ERROR(
-            hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
+        if(mpDeviceBuf)
+        {
+            HIP_CHECK_ERROR(
+                hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
+        }
     }
     void FromDevice(void* p) const
     {

@@ -63,14 +80,17 @@ struct DeviceMem
         {
             HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
         }
-        else
-        {
-            throw std::runtime_error("FromDevice with an empty pointer");
-        }
+        // else
+        // {
+        //     throw std::runtime_error("FromDevice with an empty pointer");
+        // }
     }
     void FromDevice(void* p, const std::size_t cpySize) const
     {
-        HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
+        if(mpDeviceBuf)
+        {
+            HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
+        }
     }
     void SetZero() const
     {

@@ -82,13 +102,16 @@ struct DeviceMem
     template <typename T>
     void SetValue(T x) const
     {
-        if(mMemSize % sizeof(T) != 0)
+        if(mpDeviceBuf)
         {
-            throw std::runtime_error("wrong! not entire DeviceMem will be set");
-        }
+            if(mMemSize % sizeof(T) != 0)
+            {
+                throw std::runtime_error("wrong! not entire DeviceMem will be set");
+            }
-        // TODO: call a gpu kernel to set the value (?)
-        set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
+            // TODO: call a gpu kernel to set the value (?)
+            set_buffer_value<T>
+                <<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
+        }
     }
     ~DeviceMem()
     {
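With this change a zero-sized DeviceMem no longer calls hipMalloc, and the copy/SetValue helpers silently become no-ops when the buffer is null instead of throwing. A minimal host-side sketch (namespace qualification and buffer contents assumed for illustration):

std::vector<float> h(1024, 1.0f);
ck_tile::DeviceMem d(h.size() * sizeof(float)); // hipMalloc only because size != 0
d.ToDevice(h.data());                           // host -> device copy of mMemSize bytes
d.SetValue(0.0f);                               // fills the buffer on the GPU
d.FromDevice(h.data());                         // device -> host copy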
include/ck_tile/host/host_tensor.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -156,7 +156,7 @@ struct HostTensorDescriptor
     }

     const std::vector<std::size_t>& get_lengths() const { return mLens; }
-    const std::vector<std::size_t>& GetStrides() const { return mStrides; }
+    const std::vector<std::size_t>& get_strides() const { return mStrides; }

     template <typename... Is>
     std::size_t GetOffsetFromMultiIndex(Is... is) const

@@ -188,7 +188,7 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
     for(std::size_t i = 0; i < a.get_num_of_dimension(); i++)
     {
         new_lengths[i] = a.get_lengths()[new2old[i]];
-        new_strides[i] = a.GetStrides()[new2old[i]];
+        new_strides[i] = a.get_strides()[new2old[i]];
     }

     return HostTensorDescriptor(new_lengths, new_strides);

@@ -327,7 +327,7 @@ struct HostTensor
     decltype(auto) get_lengths() const { return mDesc.get_lengths(); }
-    decltype(auto) GetStrides() const { return mDesc.GetStrides(); }
+    decltype(auto) get_strides() const { return mDesc.get_strides(); }

     std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }

@@ -481,6 +481,34 @@ struct HostTensor
         return mData[mDesc.GetOffsetFromMultiIndex(idx)];
     }

+    HostTensor<T> transpose(std::vector<size_t> axes = {}) const
+    {
+        if(axes.empty())
+        {
+            axes.resize(this->get_num_of_dimension());
+            std::iota(axes.rbegin(), axes.rend(), 0);
+        }
+        if(axes.size() != mDesc.get_num_of_dimension())
+        {
+            throw std::runtime_error(
+                "HostTensor::transpose(): size of axes must match tensor dimension");
+        }
+        std::vector<size_t> tlengths, tstrides;
+        for(const auto& axis : axes)
+        {
+            tlengths.push_back(get_lengths()[axis]);
+            tstrides.push_back(get_strides()[axis]);
+        }
+        HostTensor<T> ret(*this);
+        ret.mDesc = HostTensorDescriptor(tlengths, tstrides);
+        return ret;
+    }
+
+    HostTensor<T> transpose(std::vector<size_t> axes = {})
+    {
+        return const_cast<HostTensor<T> const*>(this)->transpose(axes);
+    }
+
     typename Data::iterator begin() { return mData.begin(); }
     typename Data::iterator end() { return mData.end(); }
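A short usage sketch of the new transpose() (the construct-from-lengths constructor and the concrete shape are assumed here): with no axes given the dimension order is fully reversed, and only the descriptor (lengths/strides) is permuted; the data layout in memory is untouched.

ck_tile::HostTensor<float> t({2, 3, 4}); // lengths assumed for illustration
auto t_rev  = t.transpose();             // logical shape becomes {4, 3, 2}
auto t_swap = t.transpose({0, 2, 1});    // logical shape becomes {2, 4, 3}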
include/ck_tile/host/kernel_launch.hpp

@@ -6,6 +6,7 @@
 #include "ck_tile/core/config.hpp"
 #include "ck_tile/host/stream_config.hpp"
 #include "ck_tile/host/hip_check_error.hpp"
+#include "ck_tile/host/timer.hpp"
 #include <hip/hip_runtime.h>
 #include <cstddef>

@@ -14,153 +15,92 @@ template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename...
 #if CK_TILE_USE_LAUNCH_BOUNDS
 __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
 #endif
-    __global__ void kentry(Kernel f, Args... args)
+    __global__ void kentry(Args... args)
 {
-    f(args...);
+    Kernel{}(args...);
 }

-template <typename... Args, typename F>
-CK_TILE_HOST float launch_and_time_kernel(const stream_config& s,
-                                          F kernel,
-                                          dim3 grid_dim,
-                                          dim3 block_dim,
-                                          std::size_t lds_byte,
-                                          Args... args)
-{
-#if CK_TILE_TIME_KERNEL
-    if(s.time_kernel_)
-    {
-        // warm up
-        for(int i = 0; i < s.cold_niters_; ++i)
-        {
-            kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-            hip_check_error(hipGetLastError());
-        }
-
-        const int nrepeat = s.nrepeat_;
-        hipEvent_t start, stop;
-
-        HIP_CHECK_ERROR(hipEventCreate(&start));
-        HIP_CHECK_ERROR(hipEventCreate(&stop));
-
-        HIP_CHECK_ERROR(hipDeviceSynchronize());
-        HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_));
-
-        for(int i = 0; i < nrepeat; ++i)
-        {
-            kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-            hip_check_error(hipGetLastError());
-        }
-
-        HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_));
-        HIP_CHECK_ERROR(hipEventSynchronize(stop));
-
-        float total_time = 0;
-        HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop));
-
-        return total_time / nrepeat;
-    }
-    else
-    {
-        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-        hip_check_error(hipGetLastError());
-
-        return 0;
-    }
-#else
-    kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-    hip_check_error(hipGetLastError());
-
-    return 0;
-#endif
-}
-
-template <typename... Args, typename F, typename PreProcessFunc>
-CK_TILE_HOST float launch_and_time_kernel_with_preprocess(const stream_config& s,
-                                                          PreProcessFunc preprocess,
-                                                          F kernel,
-                                                          dim3 grid_dim,
-                                                          dim3 block_dim,
-                                                          std::size_t lds_byte,
-                                                          Args... args)
-{
-#if CK_TILE_TIME_KERNEL
-    if(s.time_kernel_)
-    {
-#if CK_TILE_DEBUG_LOG
-        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
-               __func__,
-               grid_dim.x, grid_dim.y, grid_dim.z,
-               block_dim.x, block_dim.y, block_dim.z);
-        printf("Warm up 1 time \n");
-#endif
-        // warm up
-        preprocess();
-        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-        hip_check_error(hipGetLastError());
-
-        const int nrepeat = 10;
-#if CK_TILE_DEBUG_LOG
-        printf("Start running %d times... \n", nrepeat);
-#endif
-        hipEvent_t start, stop;
-
-        HIP_CHECK_ERROR(hipEventCreate(&start));
-        HIP_CHECK_ERROR(hipEventCreate(&stop));
-
-        HIP_CHECK_ERROR(hipDeviceSynchronize());
-        HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_));
-
-        for(int i = 0; i < nrepeat; ++i)
-        {
-            preprocess();
-            kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-            hip_check_error(hipGetLastError());
-        }
-
-        HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_));
-        HIP_CHECK_ERROR(hipEventSynchronize(stop));
-
-        float total_time = 0;
-        HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop));
-
-        return total_time / nrepeat;
-    }
-    else
-    {
-        preprocess();
-        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-        hip_check_error(hipGetLastError());
-
-        return 0;
-    }
-#else
-    kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-    hip_check_error(hipGetLastError());
-
-    return 0;
-#endif
-}
-
-template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
-          int MinBlockPerCu     = CK_TILE_MIN_BLOCK_PER_CU,
-          typename KernelImpl,
-          typename... Args>
-CK_TILE_HOST float launch_kernel(const stream_config& s,
-                                 KernelImpl kernel_impl,
-                                 dim3 grid_dim,
-                                 dim3 block_dim,
-                                 std::size_t dynamic_smem_byte,
-                                 Args... args)
-{
-    const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
-
-    return launch_and_time_kernel(
-        s, kernel, grid_dim, block_dim, dynamic_smem_byte, kernel_impl, args...);
-}
+//
+// return a anonymous functor(lambda) to be called later
+// the KernelImpl should be a class without non-static data member, or let's say
+// can be instantiate with "KernelImpl{}"
+//
+// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
+//
+template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
+          int MinBlockPerCu     = CK_TILE_MIN_BLOCK_PER_CU,
+          typename KernelImpl,
+          typename... Args>
+CK_TILE_HOST auto make_kernel(KernelImpl /*f*/,
+                              dim3 grid_dim,
+                              dim3 block_dim,
+                              std::size_t lds_byte,
+                              Args... args)
+{
+    const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
+
+    return [=](const stream_config& s) {
+        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
+    };
+}
+
+// clang-format off
+/*
+ * launch_kernel()
+ *
+ * this is the function to launch arbitrary number of kernels with optional timer(selected by stream_config)
+ * the callables should have signature as "operator()(const stream_config& s){ ... }" to call
+ *
+ * the simplest way is pass in a lambda function, with "[=](const stream_config& s){ call_your_kernel_here() }"
+ * as signature, for the callable (pay attention to the capture list)
+ *
+ * e.g.
+ *  ck_tile::launch_kernel(s,
+ *                  [=](const stream_config& s){ hipMemset(ptr, 0, size) },
+ *                  [=](const stream_config& s){ some_kernel<<<grids, blocks>>>(arg); }
+ *                  );
+ *
+ * if you use ck_tile kernel, or similiar to this style (structure with "static __device__ operator()(...){}")
+ * you can pass your kernel to ck_tile::make_kernel(), which will create a anonymous functor for you,
+ * then pass it to ck_tile::launch_kernel()
+ *
+ * e.g.
+ * ck_tile::launch_kernel(s,
+ *                       ck_tile::make_kernel<T0, B0>(kernel_0{}, grids0, blocks0, 0, kargs0),
+ *                       ck_tile::make_kernel<T0, B1>(kernel_1{}, grids1, blocks1, 0, kargs1),
+ *                       ...);
+ **/
+// clang-format on
+template <typename... Callables>
+CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
+{
+    // clang-format off
+    if(!s.time_kernel_)
+    {
+        (callables(s), ...);
+        hip_check_error(hipGetLastError());
+        return 0;
+    }
+
+    if(s.is_gpu_timer_)
+    {
+        gpu_timer timer{};
+
+        // warmup
+        for(int i = 0; i < s.cold_niters_; i++)
+        {
+            (callables(s), ...);
+        }
+        hip_check_error(hipGetLastError());
+
+        timer.start(s.stream_id_);
+        for(int i = 0; i < s.nrepeat_; i++)
+        {
+            (callables(s), ...);
+        }
+        hip_check_error(hipGetLastError());
+        timer.stop(s.stream_id_);
+
+        return timer.duration() / s.nrepeat_;
+    }
+    else
+    {
+        cpu_timer timer{};
+
+        // warmup
+        for(int i = 0; i < s.cold_niters_; i++)
+        {
+            (callables(s), ...);
+        }
+        hip_check_error(hipGetLastError());
+
+        timer.start(s.stream_id_);
+        for(int i = 0; i < s.nrepeat_; i++)
+        {
+            (callables(s), ...);
+        }
+        hip_check_error(hipGetLastError());
+        timer.stop(s.stream_id_);
+
+        return timer.duration() / s.nrepeat_;
+    }
+    // clang-format on
+}
 } // namespace ck_tile
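A minimal host-side sketch combining the new pieces, following the doc comment above (the kernel type, grid/block sizes, kargs and workspace names are placeholders, not from this commit):

ck_tile::stream_config s{nullptr, true}; // default stream, time the kernels with default warmup/repeat
float ave_ms = ck_tile::launch_kernel(
    s,
    [=](const ck_tile::stream_config& cfg) { // any callable taking stream_config is allowed
        hipMemsetAsync(workspace_ptr, 0, workspace_bytes, cfg.stream_id_);
    },
    ck_tile::make_kernel(MyKernel{}, grids, blocks, 0, kargs)); // MyKernel: static __device__ operator()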
include/ck_tile/host/reference/reference_batched_dropout.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>

namespace ck_tile {

template <typename DataType, typename RandValOutputDataType>
CK_TILE_HOST void reference_batched_dropout(HostTensor<DataType>& in_out_b_m_n,
                                            const HostTensor<RandValOutputDataType>& randval_b_m_n,
                                            const uint8_t& p_undrop_in_uint8_t,
                                            const float scale)
{
    const int N = in_out_b_m_n.mDesc.get_lengths()[2];

    auto f = [&](auto batch, auto m) {
        for(int n = 0; n < N; ++n)
        {
            float tmp = ck_tile::type_convert<float>(in_out_b_m_n(batch, m, n)) * scale;
            in_out_b_m_n(batch, m, n) = randval_b_m_n(batch, m, n) <= p_undrop_in_uint8_t
                                            ? ck_tile::type_convert<DataType>(tmp)
                                            : DataType(0);
        }
    };

    make_ParallelTensorFunctor(f,
                               randval_b_m_n.mDesc.get_lengths()[0],
                               randval_b_m_n.mDesc.get_lengths()[1])(
        std::thread::hardware_concurrency());
}

} // namespace ck_tile
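In math form this is the usual inverted-dropout rule applied per element, with the caller-supplied scale expected to be the reciprocal keep probability (that is how init_dropout later in this commit sets rp_undrop):

out[b,m,n] = scale * in[b,m,n]   if randval[b,m,n] <= p_undrop_in_uint8_t
           = 0                   otherwise,        with scale ≈ 1 / (1 - p_drop)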
include/ck_tile/host/stream_config.hpp

@@ -6,6 +6,22 @@
 #include <hip/hip_runtime.h>

 namespace ck_tile {

+/*
+ * construct this structure with behavior as:
+ *
+ * // create stream config with default stream(NULL), and not timing the kernel
+ * stream_config s = stream_config{};
+ *
+ * // create stream config with _some_stream_id_, and not timing the kernel
+ * stream_config s = stream_config{_some_stream_id_};
+ *
+ * // create stream config with _some_stream_id_, and benchmark with warmup/repeat as default
+ * stream_config s = stream_config{_some_stream_id_, true};
+ *
+ * // create stream config with _some_stream_id_, and benchmark using cpu timer
+ * stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, false};
+ **/
 struct stream_config
 {
     hipStream_t stream_id_ = nullptr;

@@ -13,5 +29,6 @@ struct stream_config
     int log_level_   = 0;
     int cold_niters_ = 3;
     int nrepeat_     = 10;
+    bool is_gpu_timer_ = true; // keep compatible
 };
 } // namespace ck_tile
include/ck_tile/host/timer.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <hip/hip_runtime.h>
#include <cstddef>
#include <chrono>

namespace ck_tile {

struct gpu_timer
{
    CK_TILE_HOST gpu_timer()
    {
        HIP_CHECK_ERROR(hipEventCreate(&start_evt));
        HIP_CHECK_ERROR(hipEventCreate(&stop_evt));
    }

    CK_TILE_HOST ~gpu_timer() noexcept(false)
    {
        HIP_CHECK_ERROR(hipEventDestroy(start_evt));
        HIP_CHECK_ERROR(hipEventDestroy(stop_evt));
    }

    CK_TILE_HOST void start(const hipStream_t& s)
    {
        HIP_CHECK_ERROR(hipDeviceSynchronize());
        HIP_CHECK_ERROR(hipEventRecord(start_evt, s));
    }

    CK_TILE_HOST void stop(const hipStream_t& s)
    {
        HIP_CHECK_ERROR(hipEventRecord(stop_evt, s));
        HIP_CHECK_ERROR(hipEventSynchronize(stop_evt));
    }

    // return in ms
    CK_TILE_HOST float duration() const
    {
        float ms = 0;
        HIP_CHECK_ERROR(hipEventElapsedTime(&ms, start_evt, stop_evt));
        return ms;
    }

    private:
    hipEvent_t start_evt, stop_evt;
};

struct cpu_timer
{
    // torch.utils.benchmark.Timer(), there is a sync inside each timer callback
    CK_TILE_HOST void start(const hipStream_t&)
    {
        HIP_CHECK_ERROR(hipDeviceSynchronize());
        start_tick = std::chrono::high_resolution_clock::now();
    }

    // torch.utils.benchmark.Timer(), there is a sync inside each timer callback
    CK_TILE_HOST void stop(const hipStream_t&)
    {
        HIP_CHECK_ERROR(hipDeviceSynchronize());
        stop_tick = std::chrono::high_resolution_clock::now();
    }

    // return in ms
    CK_TILE_HOST float duration() const
    {
        double sec =
            std::chrono::duration_cast<std::chrono::duration<double>>(stop_tick - start_tick)
                .count();
        return static_cast<float>(sec * 1e3);
    }

    private:
    std::chrono::time_point<std::chrono::high_resolution_clock> start_tick;
    std::chrono::time_point<std::chrono::high_resolution_clock> stop_tick;
};

} // namespace ck_tile
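Both timers expose the same start/stop/duration interface that the new launch_kernel() relies on, so they can also be used directly. A minimal sketch (the stream handle and the enqueued work are placeholders):

ck_tile::gpu_timer timer{};
timer.start(stream);     // device sync, then records the start hipEvent on `stream`
// ... enqueue work on `stream` ...
timer.stop(stream);      // records and synchronizes the stop event
float ms = timer.duration();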
include/ck_tile/ops/fmha.hpp

@@ -4,10 +4,24 @@
 #pragma once

 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
+#include "ck_tile/ops/fmha/block/block_dropout.hpp"
 #include "ck_tile/ops/fmha/block/block_masking.hpp"
 #include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
include/ck_tile/ops/fmha/block/block_dropout.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"

namespace ck_tile {

struct BlockDropout
{
    CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch,
                                     index_t i_head,
                                     index_t nheads,
                                     unsigned long long seed,
                                     unsigned long long offset,
                                     float rp_undrop_,
                                     uint8_t p_undrop_in_uint8_t_,
                                     bool is_store_randval_)
        : ph(seed, offset + (i_batch * nheads + i_head) * get_warp_size() + get_lane_id()),
          rp_undrop(rp_undrop_),
          p_undrop_in_uint8_t(p_undrop_in_uint8_t_),
          is_store_randval(is_store_randval_)
    {
    }

    template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE static constexpr auto
    MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
                          index_t seqlen_qk_start)
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        using WG                    = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp     = config.template at<1>();
        constexpr index_t NWarp     = config.template at<2>();
        constexpr index_t kMPerStep = MWarp * WG::kM;
        constexpr index_t kNPerStep = NWarp * WG::kN;

        const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
        auto randval_dram_window = [&]() {
            if constexpr(IsFwd)
            {
                return make_tile_window(
                    randval_dram_block_window_tmp.get_bottom_tensor_view(),
                    ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
                    {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
            }
            else
            {
                return make_tile_window(
                    randval_dram_block_window_tmp.get_bottom_tensor_view(),
                    ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
                    {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
            }
        }();

        return randval_dram_window;
    }

    template <typename BlockGemm>
    CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor()
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        using WG                    = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp     = config.template at<1>();
        constexpr index_t kMPerStep = MWarp * WG::kM;
        constexpr index_t kNPerStep = WG::kN;
        constexpr index_t kN1       = 8;
        constexpr index_t kN0       = kNPerStep / kN1;

        constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
            ck_tile::make_tuple(number<kN0>{}, number<kMPerStep>{}, number<kN1>{}),
            ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
            number<kN1>{},
            number<1>{});

        constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
            randval_lds_block_desc_0,
            ck_tile::make_tuple(
                make_pass_through_transform(number<kMPerStep>{}),
                make_merge_transform(ck_tile::make_tuple(number<kN0>{}, number<kN1>{}))),
            ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}),
            ck_tile::make_tuple(sequence<0>{}, sequence<1>{}));

        return randval_lds_block_desc;
    }

    template <typename BlockGemm>
    CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        constexpr index_t MWarp        = config.template at<1>();
        constexpr index_t NWarp        = config.template at<2>();
        constexpr index_t MIterPerWarp = 1;
        constexpr index_t NIterPerWarp = 1;

        constexpr auto randval_block_outer_part_dstr_encoding =
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<MIterPerWarp, MWarp>,
                                             sequence<NIterPerWarp, NWarp>>,
                                       tuple<sequence<1, 2>>,
                                       tuple<sequence<1, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        // Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
        constexpr auto randval_block_inner_part_dstr_encoding = []() {
            if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
                         std::is_same_v<typename BlockGemm::BDataType, half_t> &&
                         std::is_same_v<typename BlockGemm::CDataType, float>)
            {
                return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
            }
            else
            {
                return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
            }
        }();

        constexpr auto randval_block_part_dstr_encode =
            detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
                                                          randval_block_inner_part_dstr_encoding);

        return make_static_tile_distribution(randval_block_part_dstr_encode);
    }

    template <typename BlockGemm>
    CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution()
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        using WG                       = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp        = config.template at<1>();
        constexpr index_t NWarp        = config.template at<2>();
        constexpr index_t MIterPerWarp = 1;
        constexpr index_t NIterPerWarp = 1;

        constexpr auto randval_block_outer_part_dstr_encoding =
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<MIterPerWarp, MWarp>,
                                             sequence<NIterPerWarp, NWarp>>,
                                       tuple<sequence<1, 2>>,
                                       tuple<sequence<1, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto randval_block_part_dstr_encode =
            detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
                                                          typename WG::CWarpDstrEncoding{});

        return make_static_tile_distribution(randval_block_part_dstr_encode);
    }

    template <typename BlockGemm,
              typename PComputeDataType,
              typename RandValOutputDataType,
              typename PComputeWindow,
              typename RandValDramWindow>
    CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
                                 const index_t start_n0_idx,
                                 PComputeWindow& p_compute,
                                 RandValDramWindow& randval_dram_window) const
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        using WG                = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();
        using BlockGemmShape    = remove_cvref_t<typename BlockGemm::BlockGemmShape>;

        constexpr index_t kMPerBlock = BlockGemmShape::kM;
        constexpr index_t kNPerBlock = BlockGemmShape::kN;
        constexpr index_t kMPerStep  = MWarp * WG::kM;
        constexpr index_t kNPerStep  = NWarp * WG::kN;

        // randval tile in LDS
        auto randval_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());

        auto randval_lds_window = make_tile_window(
            randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});

        // register distribute
        auto randval_dist_generated =
            make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
        static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);

        auto randval_lds_read_window =
            make_tile_window(randval_lds_window.get_bottom_tensor_view(),
                             randval_lds_window.get_window_lengths(),
                             randval_lds_window.get_window_origin(),
                             MakeRandValLdsShuffleTileDistribution<BlockGemm>());

        const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{});

        static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
            static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
                int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
                int block_col_start = (start_n0_idx / WG::kN) + i_n0;
                uint2 rowcol        = make_uint2(block_row_start, block_col_start);

                // generate random number
                uint8_t random_uint8_t[16];
                ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));

                constexpr auto randval_dist_generated_spans =
                    decltype(randval_dist_generated)::get_distributed_spans();
                int i_random_idx = 0;
                sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
                        constexpr auto i_j_idx          = ck_tile::make_tuple(idx0, idx1);
                        randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
                    });
                });

                // save to LDS
                store_tile(randval_lds_window, randval_dist_generated);
                block_sync_lds();

                // read from LDS to register
                auto randval = load_tile(randval_lds_read_window);

                constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
                sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
                        constexpr auto p_idx0 = tile_distributed_index<i_m0>{};
                        constexpr auto p_idx1 =
                            tile_distributed_index<i_n0, idx1.impl_.at(1), idx1.impl_.at(2)>{};
                        constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
                        constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
                        p_compute(p_idx)     = randval[r_idx] <= p_undrop_in_uint8_t
                                                   ? p_compute[p_idx] * rp_undrop
                                                   : PComputeDataType(0);
                    });
                });

                // save to Global
                if(is_store_randval)
                {
                    const auto randval_store = cast_tile<RandValOutputDataType>(randval);
                    store_tile(randval_dram_window, randval_store);
                    move_tile_window(randval_dram_window, {0, kNPerStep});
                }
            });
            if(is_store_randval)
            {
                move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
            }
        });
        if(is_store_randval)
        {
            move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
        }
    }

    template <typename BlockGemm,
              typename RandValOutputDataType,
              typename PComputeWindow,
              typename RandValDramWindow>
    CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
                                 PComputeWindow& p_compute,
                                 RandValDramWindow& randval_dram_window) const
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        using WG                = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();
        using BlockGemmShape    = remove_cvref_t<typename BlockGemm::BlockGemmShape>;

        constexpr index_t kMPerBlock = BlockGemmShape::kM;
        constexpr index_t kNPerBlock = BlockGemmShape::kN;
        constexpr index_t kMPerStep  = MWarp * WG::kM;
        constexpr index_t kNPerStep  = NWarp * WG::kN;

        // register distribute
        auto randval =
            make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
        static_assert(randval.kThreadElementSpaceSize == 16);

        const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{});

        static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
            static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
                int block_row_start = (start_m0_idx / WG::kM) + i_m0;
                int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
                uint2 rowcol        = make_uint2(block_row_start, block_col_start);

                // generate random number
                uint8_t random_uint8_t[16];
                ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));

                constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
                int i_random_idx             = 0;
                sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
                        constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
                        randval(r_idx)       = random_uint8_t[i_random_idx++];
                        constexpr auto p_idx0 =
                            tile_distributed_index<i_m0, idx0.impl_.at(1), idx0.impl_.at(2)>{};
                        constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
                        constexpr auto p_idx  = ck_tile::make_tuple(p_idx0, p_idx1);
                        p_compute(p_idx)      = randval[r_idx] <= p_undrop_in_uint8_t
                                                    ? p_compute[p_idx]
                                                    : -p_compute[p_idx];
                    });
                });

                // save to Global
                if(is_store_randval)
                {
                    const auto randval_store = cast_tile<RandValOutputDataType>(randval);
                    store_tile(randval_dram_window, randval_store);
                    move_tile_window(randval_dram_window, {kMPerStep, 0});
                }
            });
            if(is_store_randval)
            {
                move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
            }
        });
        if(is_store_randval)
        {
            move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
        }
    }

    ck_tile::philox ph;
    const float rp_undrop;
    const uint8_t p_undrop_in_uint8_t;
    const bool is_store_randval;
};

} // namespace ck_tile
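A minimal construction sketch, using the dropout karg names that appear later in this commit; the batch/head indices and the exact construction site inside the pipelines are assumed, not shown in this excerpt:

ck_tile::BlockDropout dropout(i_batch,                   // batch index (assumed, e.g. from blockIdx)
                              i_head,                    // head index (assumed)
                              kargs.num_head_q,          // nheads
                              kargs.drop_seed,
                              kargs.drop_offset,
                              kargs.rp_undrop,           // 1 / (1 - p_drop)
                              kargs.p_undrop_in_uint8_t,
                              kargs.is_store_randval);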
include/ck_tile/ops/fmha/block/block_masking.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -141,6 +141,36 @@ struct GenericAttentionMask
         }
     }

+    // to get the loop length along Y axis, return index:[start, end), end-start=length
+    // use this if need loop over Y axis tile by tile (like q-seqlen loopover)
+    // TODO: y_end still could be negative, so end-start could be negative(need check)
+    template <index_t YTile, index_t XTile>
+    CK_TILE_HOST_DEVICE constexpr auto
+    GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
+    {
+        if constexpr(!IsMasking)
+        {
+            return ck_tile::make_tuple(0, y_total);
+        }
+        else
+        {
+            // get the tile start/end range assum we loop over along Y tile by tile
+            index_t y_start = [&]() {
+                index_t tmp = max(-x + i_x + 1, 0);
+                return (tmp / YTile) * YTile; // round to tile aligned
+            }();
+
+            // TODO: end could be negative, we ignore clamp here, and let caller to check
+            //       ... in which case end-start is negative
+            index_t y_end = [&]() {
+                index_t tmp = min(i_x + XTile - 1 + y, y_total);
+                return ((tmp + YTile - 1) / YTile) * YTile;
+            }();
+
+            return ck_tile::make_tuple(y_start, y_end);
+        }
+    }
+
     // per-pixel check if out-of-bound, if true, need mask a value(like -INF)
     CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
     {

@@ -160,14 +190,14 @@ struct GenericAttentionMask
         }
         else
         {
-            return i_x >= x_end;
+            return i_x >= x_end || i_y >= y_total;
         }
     }

     // if current tile is at the edge, means need per-pixel mask check.
     // otherwise no need to check per-pixel
-    // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX()
+    // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
     // can be used as a fast-path to decide if do per-pixel check or not
     template <index_t TileHeight, index_t TileWidth>
     CK_TILE_HOST_DEVICE constexpr auto

@@ -269,6 +299,36 @@ struct SimplifiedGenericAttentionMask
         }
     }

+    // to get the loop length along Y axis, return index:[start, end), end-start=length
+    // use this if need loop over Y axis tile by tile (like q-seqlen loopover)
+    // TODO: y_end still could be negative, so end-start could be negative(need check)
+    template <index_t YTile, index_t XTile>
+    CK_TILE_HOST_DEVICE constexpr auto
+    GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
+    {
+        if constexpr(!IsMasking)
+        {
+            return ck_tile::make_tuple(0, y_total);
+        }
+        else
+        {
+            // get the tile start/end range assum we loop over along Y tile by tile
+            index_t y_start = [&]() {
+                index_t tmp = max(-x + i_x + 1, 0);
+                return (tmp / YTile) * YTile; // round to tile aligned
+            }();
+
+            // TODO: end could be negative, we ignore clamp here, and let caller to check
+            //       ... in which case end-start is negative
+            index_t y_end = [&]() {
+                index_t tmp = min(i_x + XTile - 1 + y, y_total);
+                return ((tmp + YTile - 1) / YTile) * YTile;
+            }();
+
+            return ck_tile::make_tuple(y_start, y_end);
+        }
+    }
+
     // per-pixel check if out-of-bound, if true, need mask a value(like -INF)
     CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
     {

@@ -283,13 +343,13 @@ struct SimplifiedGenericAttentionMask
         index_t x_start = -y + i_y + 1;          // this could be negative, but it's fine
         index_t x_end   = min(i_y + x, x_total); // need min in case x is padded

-        return i_x < x_start || i_x >= x_end;
+        return i_x < x_start || i_x >= x_end || i_y >= y_total;
     }
 }

     // if current tile is at the edge, means need per-pixel mask check.
     // otherwise no need to check per-pixel
-    // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX()
+    // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
     // can be used as a fast-path to decide if do per-pixel check or not
     template <index_t TileHeight, index_t TileWidth>
     CK_TILE_HOST_DEVICE constexpr auto

@@ -361,6 +421,6 @@ make_generic_attention_mask_from_lr_window(index_t left_size,
 {
     auto r = make_generic_attention_mask_coordinates_from_lr_window(
         left_size, right_size, y_total, x_total, is_top_left);

-    return MaskType{r.at(ck_tile::number<0>{}), r.at(ck_tile::number<1>{}), y_total, x_total};
+    return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
 }
 } // namespace ck_tile
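To make the Y-tile rounding concrete, a small worked example under assumed mask values (a top-left causal mask with x = 1 and y = y_total = 256, tile sizes YTile = XTile = 64, current X-tile start i_x = 128):

y_start = (max(-1 + 128 + 1, 0) / 64) * 64              = (128 / 64) * 64 = 128
y_end   = ((min(128 + 63 + 256, 256) + 63) / 64) * 64   = (319 / 64) * 64 = 256

so the backward q-seqlen loop for this k-tile only visits the Y tiles in [128, 256).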
include/ck_tile/ops/fmha/block/block_position_encoding.hpp

@@ -23,13 +23,13 @@ VERTICAL:
     [0] 1  2  3  4  5
     [0] 1  2  3  4  5

-TOP_LEFT:
+TOP_LEFT (but negative):
     [0] 1  2  3  4  5
      1 [0] 1  2  3  4
      2  1 [0] 1  2  3
      3  2  1 [0] 1  2

-FROM_BOTTOM_RIGHT:
+FROM_BOTTOM_RIGHT (but negative):
      2  1 [0] 1  2  3
      3  2  1 [0] 1  2
      4  3  2  1 [0] 1

@@ -54,7 +54,7 @@ struct Alibi
           index_t x_total_,
           AlibiMode mode_ = AlibiMode::VERTICAL)
     {
-        slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope;
+        slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_;

         shift_left_up = [&]() {
             if(RowMajor)
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"

#include <string>
#include <type_traits>

// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
// dV[seqlen_k, hdim_v] = P^T[seqlen_k, seqlen_q] @ dO^T[hdim_v, seqlen_q]
// dP[seqlen_q, seqlen_k] = dO[seqlen_q, hdim_v] @ V[seqlen_k, hdim_v]
// D[seqlen_q] = rowsum(dO[seqlen_q, hdim_v] * O[seqlen_q, hdim_v])
// dS''[seqlen_q, seqlen_k] = P[seqlen_q, seqlen_k] * (dP[seqlen_q, seqlen_k] - D[seqlen_q])
// dBias[seqlen_q, seqlen_k] = dS'[seqlen_q, seqlen_k] = dS''[seqlen_q, seqlen_k]
// dK[seqlen_k, hdim_q] = dS'^T[seqlen_k, seqlen_q] @ Q^T[hdim_q, seqlen_q] * Scale[1]
// dQ[seqlen_q, hdim_q] = dS'[seqlen_q, seqlen_k] @ K^T[hdim_q, seqlen_k] * Scale[1]
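One consistent LaTeX reading of the comment block above (a transcription for readability, not an addition to the source; the comment treats transposes somewhat loosely):

S = \mathrm{scale}\, Q K^{T}, \quad S'' = S + \mathrm{Bias}, \quad P = \mathrm{softmax}(S'')
dV = P^{T} dO, \qquad dP = dO\, V^{T}, \qquad D = \mathrm{rowsum}(dO \odot O)
dS'' = P \odot (dP - D), \qquad d\mathrm{Bias} = dS''
dK = \mathrm{scale}\, dS''^{T} Q, \qquad dQ = \mathrm{scale}\, dS'' K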
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
KGradEpiloguePipeline_
,
typename
VGradEpiloguePipeline_
>
struct
FmhaBwdDQDKDVKernel
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaPipeline
=
ck_tile
::
remove_cvref_t
<
FmhaPipeline_
>
;
using
KGradEpiloguePipeline
=
ck_tile
::
remove_cvref_t
<
KGradEpiloguePipeline_
>
;
using
VGradEpiloguePipeline
=
ck_tile
::
remove_cvref_t
<
VGradEpiloguePipeline_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaPipeline
::
kBlockSize
;
static
constexpr
ck_tile
::
index_t
kBlockPerCu
=
FmhaPipeline
::
kBlockPerCu
;
using
QDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
QDataType
>
;
using
KDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
KDataType
>
;
using
VDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
VDataType
>
;
using
BiasDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
BiasDataType
>
;
using
GemmDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
GemmDataType
>
;
using
LSEDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
LSEDataType
>
;
using
AccDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
AccDataType
>
;
using
DDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
DDataType
>
;
using
RandValOutputDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
RandValOutputDataType
>
;
using
OGradDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
OGradDataType
>
;
using
QGradDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
QGradDataType
>
;
using
KGradDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
KGradDataType
>
;
using
VGradDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
VGradDataType
>
;
using
BiasGradDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
BiasGradDataType
>
;
static
constexpr
bool
kIsGroupMode
=
FmhaPipeline
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
FmhaPipeline
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
FmhaPipeline
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
FmhaPipeline
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
FmhaPipeline
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
FmhaPipeline
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
FmhaPipeline
::
kHasBiasGrad
;
static
constexpr
bool
kHasDropout
=
FmhaPipeline
::
kHasDropout
;
using
FmhaMask
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
FmhaMask
>
;
static
constexpr
bool
kHasMask
=
FmhaMask
::
IsMasking
;
// clang-format off
template
<
typename
T
>
struct
t2s
;
template
<
>
struct
t2s
<
ck_tile
::
fp16_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
// clang-format on
CK_TILE_HOST
static
std
::
string
GetName
()
{
// sync with generate.py
// clang-format off
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
gbr
=
typename
bfs
::
Gemm0BlockWarps
;
using
gwt
=
typename
bfs
::
Gemm0WarpTile
;
#define _SS_ std::string
#define _TS_ std::to_string
auto
pn
=
[
&
]
()
{
std
::
string
n
;
if
(
kPadSeqLenQ
)
n
+=
"s"
;
if
(
kPadSeqLenK
)
n
+=
"sk"
;
if
(
kPadHeadDimQ
)
n
+=
"d"
;
if
(
kPadHeadDimV
)
n
+=
"dv"
;
return
n
.
empty
()
?
n
:
std
::
string
(
"p"
)
+
n
;
}();
return
_SS_
(
"fmha_bwd_d"
)
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
+
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"x"
+
_TS_
(
bfs
::
kVHeaddim
)
+
"_"
+
"r"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
)
+
_SS_
(
FmhaPipeline
::
name
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
kHasBiasGrad
?
"_dbias"
:
""
)
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
);
#undef _SS_
#undef _TS_
// clang-format on
}
    template <ck_tile::index_t I> // to avoid the duplicated-base-class problem, introduce a
                                  // template arg
    struct FmhaBwdEmptyKargs
    {
    };

    // kargs use aggregate initialization, so no constructor will be provided;
    // use inheritance to minimize karg size.
    // Users need to use the MakeKargs() function to create kargs.
    struct FmhaBwdCommonKargs
    {
        const void* q_ptr;
        const void* k_ptr;
        const void* v_ptr;
        const void* lse_ptr;
        const void* do_ptr;
        const void* d_ptr;
        void* dq_ptr;
        void* dk_ptr;
        void* dv_ptr;

        ck_tile::index_t seqlen_q;
        ck_tile::index_t seqlen_k;
        ck_tile::index_t hdim_q;
        ck_tile::index_t hdim_v;

        // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k;
        // if it is larger than 1, it indicates the MQA/GQA case.
        ck_tile::index_t num_head_q;
        ck_tile::index_t nhead_ratio_qk;

        float raw_scale;
#if CK_TILE_FMHA_FWD_FAST_EXP2
        float scale;
#endif

        ck_tile::index_t stride_q;
        ck_tile::index_t stride_k;
        ck_tile::index_t stride_v;
        ck_tile::index_t stride_do;
        ck_tile::index_t stride_dk;
        ck_tile::index_t stride_dv;

        ck_tile::index_t nhead_stride_q;
        ck_tile::index_t nhead_stride_k;
        ck_tile::index_t nhead_stride_v;
        ck_tile::index_t nhead_stride_do;
        ck_tile::index_t nhead_stride_lsed;

        ck_tile::index_t batch_stride_lsed;
    };
    struct FmhaBwdCommonBiasKargs
    {
        const void* bias_ptr               = nullptr;
        ck_tile::index_t stride_bias       = 0;
        ck_tile::index_t nhead_stride_bias = 0;
    };

    struct FmhaBwdBatchModeBiasKargs : FmhaBwdCommonBiasKargs
    {
        ck_tile::index_t batch_stride_bias = 0;
    };

    struct FmhaBwdAlibiKargs
    {
        // alibi is batch*nhead*1; it is laid out the same way in batch and group mode
        const void* alibi_slope_ptr;
        ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 if all batches share the same slope
    };

    struct FmhaBwdCommonBiasGradKargs
    {
        void* dbias_ptr                     = nullptr;
        ck_tile::index_t stride_dbias       = 0;
        ck_tile::index_t nhead_stride_dbias = 0;
    };

    struct FmhaBwdBatchModeBiasGradKargs : FmhaBwdCommonBiasGradKargs
    {
        ck_tile::index_t batch_stride_dbias = 0;
    };

    struct FmhaBwdMaskKargs
    {
        ck_tile::index_t window_size_left, window_size_right;
        ck_tile::GenericAttentionMaskEnum mask_type;
    };
    struct FmhaBwdCommonDropoutKargs
    {
        void init_dropout(const float p_drop,
                          const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
                          const float raw_scale)
        {
            float p_undrop = 1.0 - p_drop;
            p_undrop_in_uint8_t =
                uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
            rp_undrop       = 1.0 / p_undrop;
            scale_rp_undrop = rp_undrop * raw_scale;

            drop_seed   = std::get<0>(drop_seed_offset);
            drop_offset = std::get<1>(drop_seed_offset);
        }

        float rp_undrop             = 1;
        float scale_rp_undrop       = 1;
        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
        bool is_store_randval       = false;
        uint64_t drop_seed          = 1;
        uint64_t drop_offset        = 0;
        void* rand_val_ptr          = nullptr;

        ck_tile::index_t stride_randval       = 0;
        ck_tile::index_t nhead_stride_randval = 0;
    };

    struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
    {
        ck_tile::index_t batch_stride_randval = 0;
    };
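The dropout bookkeeping above reduces to a little host-side arithmetic: the keep-probability p_undrop = 1 - p_drop, an 8-bit threshold compared against the generated random bytes, the reciprocal 1/p_undrop used to rescale surviving elements, and that reciprocal folded into the softmax scale. A minimal standalone sketch of the same computation (the sample p_drop, seed/offset values, and the main() driver are illustrative, not part of the kernel):

// Standalone sketch of FmhaBwdCommonDropoutKargs::init_dropout's arithmetic.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <tuple>

int main()
{
    const float p_drop    = 0.1f;
    const auto seed_off   = std::make_tuple(uint64_t{1234}, uint64_t{0});
    const float raw_scale = 0.125f; // e.g. 1/sqrt(hdim) for hdim = 64

    const float p_undrop = 1.0f - p_drop;
    // quantized keep-probability compared against 8-bit random values in the kernel
    const uint8_t p_undrop_u8 =
        uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
    const float rp_undrop       = 1.0f / p_undrop;       // rescales kept elements
    const float scale_rp_undrop = rp_undrop * raw_scale; // folded into the softmax scale

    std::printf("p_undrop=%f u8=%u rp=%f scale_rp=%f seed=%llu offset=%llu\n",
                p_undrop,
                unsigned(p_undrop_u8),
                rp_undrop,
                scale_rp_undrop,
                (unsigned long long)std::get<0>(seed_off),
                (unsigned long long)std::get<1>(seed_off));
    return 0;
}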
    struct FmhaBwdBatchModeKargs
        : FmhaBwdCommonKargs,
          std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
                             FmhaBwdBatchModeBiasKargs,
                             std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
                                                FmhaBwdAlibiKargs,
                                                FmhaBwdEmptyKargs<0>>>,
          std::conditional_t<kHasBiasGrad, FmhaBwdBatchModeBiasGradKargs, FmhaBwdEmptyKargs<1>>,
          std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
          std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>
    {
        ck_tile::index_t batch_stride_q;
        ck_tile::index_t batch_stride_k;
        ck_tile::index_t batch_stride_v;
        ck_tile::index_t batch_stride_do;
        ck_tile::index_t batch_stride_dk;
        ck_tile::index_t batch_stride_dv;
    };

    struct FmhaBwdGroupModeKargs
        : FmhaBwdCommonKargs,
          std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
                             FmhaBwdCommonBiasKargs,
                             std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
                                                FmhaBwdAlibiKargs,
                                                FmhaBwdEmptyKargs<0>>>,
          std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>,
          std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
          std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>
    {
        const int32_t* seqstart_q_ptr;
        const int32_t* seqstart_k_ptr;
        const int32_t* seqlen_k_ptr;
    };
    using Kargs = std::conditional_t<kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs>;
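The std::conditional_t chains above exist purely to keep the kernel-argument struct small: every disabled feature selects an empty base (the distinct FmhaBwdEmptyKargs<I> instantiations avoid the duplicated-base-class problem), so the empty-base optimization removes it from the object layout. A reduced, self-contained illustration of the idea; the Feature* names and field choices below are made up for this sketch:

// Sketch of the "empty kargs" trick: disabled features select an empty,
// uniquely-tagged base so they typically add nothing to sizeof(Kargs).
#include <cstdio>
#include <type_traits>

template <int I>
struct EmptyKargs
{
};

struct CommonKargs
{
    const void* q_ptr;
    int seqlen;
};
struct BiasKargs
{
    const void* bias_ptr;
    int stride_bias;
};
struct DropoutKargs
{
    float rp_undrop;
    unsigned long long seed;
};

template <bool kHasBias, bool kHasDropout>
struct Kargs : CommonKargs,
               std::conditional_t<kHasBias, BiasKargs, EmptyKargs<0>>,
               std::conditional_t<kHasDropout, DropoutKargs, EmptyKargs<1>>
{
};

int main()
{
    std::printf("all features : %zu bytes\n", sizeof(Kargs<true, true>));
    // usually equal to sizeof(CommonKargs) thanks to the empty-base optimization
    std::printf("no features  : %zu bytes\n", sizeof(Kargs<false, false>));
    return 0;
}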
    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              const void* lse_ptr,
              const void* do_ptr,
              const void* d_ptr,
              void* rand_val_ptr,
              void* dq_ptr,
              void* dk_ptr,
              void* dv_ptr,
              void* dbias_ptr,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t seqlen_k,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              float scale,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_randval,
              ck_tile::index_t stride_do,
              ck_tile::index_t stride_dk,
              ck_tile::index_t stride_dv,
              ck_tile::index_t stride_dbias,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_randval,
              ck_tile::index_t nhead_stride_do,
              ck_tile::index_t nhead_stride_lsed,
              ck_tile::index_t nhead_stride_dbias,
              ck_tile::index_t batch_stride_q,
              ck_tile::index_t batch_stride_k,
              ck_tile::index_t batch_stride_v,
              ck_tile::index_t batch_stride_bias,
              ck_tile::index_t batch_stride_randval,
              ck_tile::index_t batch_stride_do,
              ck_tile::index_t batch_stride_lsed,
              ck_tile::index_t batch_stride_dk,
              ck_tile::index_t batch_stride_dv,
              ck_tile::index_t batch_stride_dbias,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
              bool s_randval,
              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
                     v_ptr,
                     lse_ptr,
                     do_ptr,
                     d_ptr,
                     dq_ptr,
                     dk_ptr,
                     dv_ptr,
                     seqlen_q,
                     seqlen_k,
                     hdim_q,
                     hdim_v,
                     num_head_q,
                     nhead_ratio_qk,
                     scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
                     static_cast<float>(scale * ck_tile::log2e_v<>),
#endif
                     stride_q,
                     stride_k,
                     stride_v,
                     stride_do,
                     stride_dk,
                     stride_dv,
                     nhead_stride_q,
                     nhead_stride_k,
                     nhead_stride_v,
                     nhead_stride_do,
                     nhead_stride_lsed,
                     batch_stride_lsed}, // args for common karg
                    {},                  // placeholder for bias
                    {},                  // placeholder for dbias
                    {},                  // placeholder for mask
                    {},                  // placeholder for dropout
                    batch_stride_q,
                    batch_stride_k,
                    batch_stride_v,
                    batch_stride_do,
                    batch_stride_dk,
                    batch_stride_dv};

        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            kargs.bias_ptr          = bias_ptr;
            kargs.stride_bias       = stride_bias;
            kargs.nhead_stride_bias = nhead_stride_bias;
            kargs.batch_stride_bias = batch_stride_bias;
        }
        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
        {
            kargs.alibi_slope_ptr    = bias_ptr;
            kargs.alibi_slope_stride = stride_bias;
        }
        if constexpr(kHasBiasGrad)
        {
            kargs.dbias_ptr          = dbias_ptr;
            kargs.stride_dbias       = stride_dbias;
            kargs.nhead_stride_dbias = nhead_stride_dbias;
            kargs.batch_stride_dbias = batch_stride_dbias;
        }
        if constexpr(kHasMask)
        {
            kargs.window_size_left  = window_size_left;
            kargs.window_size_right = window_size_right;
            kargs.mask_type         = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
        }
        if constexpr(kHasDropout)
        {
            kargs.init_dropout(p_drop, drop_seed_offset, scale);
            kargs.rand_val_ptr         = rand_val_ptr;
            kargs.stride_randval       = stride_randval;
            kargs.nhead_stride_randval = nhead_stride_randval;
            kargs.batch_stride_randval = batch_stride_randval;
            kargs.is_store_randval     = s_randval;
        }

        return kargs;
    }
    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              const void* lse_ptr,
              const void* do_ptr,
              const void* d_ptr,
              void* rand_val_ptr,
              void* dq_ptr,
              void* dk_ptr,
              void* dv_ptr,
              void* dbias_ptr,
              const void* seqstart_q_ptr,
              const void* seqstart_k_ptr,
              const void* seqlen_k_ptr,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              float scale,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_randval,
              ck_tile::index_t stride_do,
              ck_tile::index_t stride_dk,
              ck_tile::index_t stride_dv,
              ck_tile::index_t stride_dbias,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_randval,
              ck_tile::index_t nhead_stride_do,
              ck_tile::index_t nhead_stride_lsed,
              ck_tile::index_t nhead_stride_dbias,
              ck_tile::index_t batch_stride_lsed,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
              bool s_randval,
              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
                     v_ptr,
                     lse_ptr,
                     do_ptr,
                     d_ptr,
                     dq_ptr,
                     dk_ptr,
                     dv_ptr,
                     -1, // seqlen will be updated by another pointer
                     -1, //
                     hdim_q,
                     hdim_v,
                     num_head_q,
                     nhead_ratio_qk,
                     scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
                     static_cast<float>(scale * ck_tile::log2e_v<>),
#endif
                     stride_q,
                     stride_k,
                     stride_v,
                     stride_do,
                     stride_dk,
                     stride_dv,
                     nhead_stride_q,
                     nhead_stride_k,
                     nhead_stride_v,
                     nhead_stride_do,
                     nhead_stride_lsed,
                     batch_stride_lsed}, // args for common karg
                    {},                  // placeholder for bias
                    {},                  // placeholder for dbias
                    {},                  // placeholder for mask
                    {},                  // placeholder for dropout
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                    reinterpret_cast<const int32_t*>(seqstart_k_ptr),
                    reinterpret_cast<const int32_t*>(seqlen_k_ptr)};

        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            kargs.bias_ptr          = bias_ptr;
            kargs.stride_bias       = stride_bias;
            kargs.nhead_stride_bias = nhead_stride_bias;
        }
        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
        {
            kargs.alibi_slope_ptr    = bias_ptr;
            kargs.alibi_slope_stride = stride_bias;
        }
        if constexpr(kHasBiasGrad)
        {
            kargs.dbias_ptr          = dbias_ptr;
            kargs.stride_dbias       = stride_dbias;
            kargs.nhead_stride_dbias = nhead_stride_dbias;
        }
        if constexpr(kHasMask)
        {
            kargs.window_size_left  = window_size_left;
            kargs.window_size_right = window_size_right;
            kargs.mask_type         = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
        }
        if constexpr(kHasDropout)
        {
            kargs.init_dropout(p_drop, drop_seed_offset, scale);
            kargs.rand_val_ptr         = rand_val_ptr;
            kargs.stride_randval       = stride_randval;
            kargs.nhead_stride_randval = nhead_stride_randval;
            kargs.is_store_randval     = s_randval;
        }

        return kargs;
    }
    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_k_)
    {
        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_k_);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return ck_tile::max(FmhaPipeline::GetSmemSize(),
                            KGradEpiloguePipeline::GetSmemSize(),
                            VGradEpiloguePipeline::GetSmemSize());
    }
    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        // divide problem
        const auto [i_tile_n, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_k);

        const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0);

        long_index_t batch_offset_q       = 0;
        long_index_t batch_offset_k       = 0;
        long_index_t batch_offset_v       = 0;
        long_index_t batch_offset_bias    = 0;
        long_index_t batch_offset_randval = 0;
        long_index_t batch_offset_do      = 0;
        long_index_t batch_offset_lsed    = 0;
        long_index_t batch_offset_dk      = 0;
        long_index_t batch_offset_dv      = 0;
        long_index_t batch_offset_dbias   = 0;

        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
            const long_index_t key_start   = kargs.seqstart_k_ptr[i_batch];

            batch_offset_q    = query_start * kargs.stride_q;
            batch_offset_k    = key_start * kargs.stride_k;
            batch_offset_v    = key_start * kargs.stride_v;
            batch_offset_do   = query_start * kargs.stride_do;
            batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
            batch_offset_dk   = key_start * kargs.stride_dk;
            batch_offset_dv   = key_start * kargs.stride_dv;
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                batch_offset_bias = query_start * kargs.stride_bias;
            }
            if constexpr(kHasBiasGrad)
            {
                batch_offset_dbias = query_start * kargs.stride_dbias;
            }
            else
            {
                batch_offset_dbias = key_start;
            }
            if constexpr(kHasDropout)
            {
                batch_offset_randval = query_start * kargs.stride_randval;
            }

            // get real # queries & # keys under group mode
            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
            if(kargs.seqlen_k_ptr != nullptr)
            {
                kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
            }
            else
            {
                const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
                kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
            }

            // # of required blocks is different in each group, terminate unnecessary blocks
            // earlier
            if(kargs.seqlen_k <= i_n0)
            {
                return;
            }
        }
        else
        {
            batch_offset_q    = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
            batch_offset_k    = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
            batch_offset_v    = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
            batch_offset_do   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
            batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
            batch_offset_dk   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
            batch_offset_dv   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
            }
            if constexpr(kHasBiasGrad)
            {
                batch_offset_dbias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dbias;
            }
            if constexpr(kHasDropout)
            {
                batch_offset_randval =
                    static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
            }
        }
        // for simplicity, we handle the batch stride by just offsetting the pointers
        const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
                                 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
                                 batch_offset_q;
        const KDataType* k_ptr =
            reinterpret_cast<const KDataType*>(kargs.k_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
            batch_offset_k;
        const VDataType* v_ptr =
            reinterpret_cast<const VDataType*>(kargs.v_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
            batch_offset_v;
        const LSEDataType* lse_ptr = reinterpret_cast<const LSEDataType*>(kargs.lse_ptr) +
                                     static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
                                     batch_offset_lsed;
        const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
                                 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
                                 batch_offset_lsed;
        const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
                                      static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
                                      batch_offset_do;
        QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
                                batch_offset_q;
        KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_k +
                                batch_offset_dk;
        VGradDataType* dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_v +
                                batch_offset_dv;
        // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
        const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            q_ptr,
            make_tuple(kargs.seqlen_q, kargs.hdim_q),
            make_tuple(kargs.stride_q, 1),
            number<FmhaPipeline::kAlignmentQ>{},
            number<1>{});
        const auto q_dram = [&]() {
            if constexpr(FmhaPipeline::kQLoadOnce)
            {
                return pad_tensor_view(
                    q_dram_naive,
                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
            }
            else
            {
                return pad_tensor_view(
                    q_dram_naive,
                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
            }
        }();
        const auto qt_dram_naive = transform_tensor_view(
            q_dram_naive,
            make_tuple(make_pass_through_transform(kargs.hdim_q),
                       make_pass_through_transform(kargs.seqlen_q)),
            make_tuple(sequence<1>{}, sequence<0>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));
        const auto qt_dram = [&]() {
            if constexpr(FmhaPipeline::kQTLoadOnce)
            {
                return pad_tensor_view(
                    qt_dram_naive,
                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kM0>{}),
                    sequence<kPadHeadDimQ, kPadSeqLenQ>{});
            }
            else
            {
                return pad_tensor_view(
                    qt_dram_naive,
                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK3>{}),
                    sequence<kPadHeadDimQ, kPadSeqLenQ>{});
            }
        }();
        const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            k_ptr,
            make_tuple(kargs.seqlen_k, kargs.hdim_q),
            make_tuple(kargs.stride_k, 1),
            number<FmhaPipeline::kAlignmentK>{},
            number<1>{});
        const auto k_dram = [&]() {
            if constexpr(FmhaPipeline::kKLoadOnce)
            {
                return pad_tensor_view(
                    k_dram_naive,
                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
                    sequence<kPadSeqLenK, kPadHeadDimQ>{});
            }
            else
            {
                return pad_tensor_view(
                    k_dram_naive,
                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
                    sequence<kPadSeqLenK, kPadHeadDimQ>{});
            }
        }();
        const auto kt_dram_naive = transform_tensor_view(
            k_dram_naive,
            make_tuple(make_pass_through_transform(kargs.hdim_q),
                       make_pass_through_transform(kargs.seqlen_k)),
            make_tuple(sequence<1>{}, sequence<0>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));
        const auto kt_dram = [&]() {
            if constexpr(FmhaPipeline::kKTLoadOnce)
            {
                return pad_tensor_view(
                    kt_dram_naive,
                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kN0>{}),
                    sequence<kPadHeadDimQ, kPadSeqLenK>{});
            }
            else
            {
                return pad_tensor_view(
                    kt_dram_naive,
                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK4>{}),
                    sequence<kPadHeadDimQ, kPadSeqLenK>{});
            }
        }();
        const auto v_dram = [&]() {
            const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                v_ptr,
                make_tuple(kargs.seqlen_k, kargs.hdim_v),
                make_tuple(kargs.stride_v, 1),
                number<FmhaPipeline::kAlignmentV>{},
                number<1>{});
            if constexpr(FmhaPipeline::kVLoadOnce)
            {
                return pad_tensor_view(
                    v_dram_naive,
                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
                    sequence<kPadSeqLenK, kPadHeadDimV>{});
            }
            else
            {
                return pad_tensor_view(
                    v_dram_naive,
                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK2>{}),
                    sequence<kPadSeqLenK, kPadHeadDimV>{});
            }
        }();
        const auto lse_dram = [&]() {
            const auto lse_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
                lse_ptr, make_tuple(kargs.seqlen_q), number<1>{});
            return pad_tensor_view(
                lse_dram_naive, make_tuple(number<FmhaPipeline::kM0>{}), sequence<kPadSeqLenQ>{});
        }();
        const auto d_dram = [&]() {
            const auto d_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
                d_ptr, make_tuple(kargs.seqlen_q), number<1>{});
            return pad_tensor_view(
                d_dram_naive, make_tuple(number<FmhaPipeline::kM0>{}), sequence<kPadSeqLenQ>{});
        }();
        const auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            do_ptr,
            make_tuple(kargs.seqlen_q, kargs.hdim_v),
            make_tuple(kargs.stride_do, 1),
            number<FmhaPipeline::kAlignmentOGrad>{},
            number<1>{});
        const auto do_dram = [&]() {
            if constexpr(FmhaPipeline::kOGradLoadOnce)
            {
                return pad_tensor_view(
                    do_dram_naive,
                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
                    sequence<kPadSeqLenQ, kPadHeadDimV>{});
            }
            else
            {
                return pad_tensor_view(
                    do_dram_naive,
                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK2>{}),
                    sequence<kPadSeqLenQ, kPadHeadDimV>{});
            }
        }();
        const auto dot_dram_naive = transform_tensor_view(
            do_dram_naive,
            make_tuple(make_pass_through_transform(kargs.hdim_v),
                       make_pass_through_transform(kargs.seqlen_q)),
            make_tuple(sequence<1>{}, sequence<0>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));
        const auto dot_dram = [&]() {
            if constexpr(FmhaPipeline::kOGradTLoadOnce)
            {
                return pad_tensor_view(
                    dot_dram_naive,
                    make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kM0>{}),
                    sequence<kPadHeadDimV, kPadSeqLenQ>{});
            }
            else
            {
                return pad_tensor_view(
                    dot_dram_naive,
                    make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kK1>{}),
                    sequence<kPadHeadDimV, kPadSeqLenQ>{});
            }
        }();
        auto dq_dram = [&]() {
            const auto dq_dram_naive =
                make_naive_tensor_view<address_space_enum::global,
                                       memory_operation_enum::atomic_add>(
                    dq_ptr,
                    make_tuple(kargs.seqlen_q, kargs.hdim_q),
                    make_tuple(kargs.stride_q, 1),
                    number<FmhaPipeline::kAlignmentQGrad>{},
                    number<1>{});
            return pad_tensor_view(
                dq_dram_naive,
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
                sequence<kPadSeqLenQ, kPadHeadDimQ>{});
        }();
        auto q_dram_window = make_tile_window(
            q_dram,
            [&]() {
                if constexpr(FmhaPipeline::kQLoadOnce)
                    return make_tuple(number<FmhaPipeline::kM0>{},
                                      number<FmhaPipeline::kQKHeaddim>{});
                else
                    return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
            }(),
            {0, 0});

        auto qt_dram_window = make_tile_window(
            qt_dram,
            [&]() {
                if constexpr(FmhaPipeline::kQTLoadOnce)
                    return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
                                      number<FmhaPipeline::kM0>{});
                else
                    return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
                                      number<FmhaPipeline::kK3>{});
            }(),
            {0, 0});

        auto k_dram_window = make_tile_window(
            k_dram,
            [&]() {
                if constexpr(FmhaPipeline::kKLoadOnce)
                    return make_tuple(number<FmhaPipeline::kN0>{},
                                      number<FmhaPipeline::kQKHeaddim>{});
                else
                    return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
            }(),
            {i_n0, 0});

        auto kt_dram_window = make_tile_window(
            kt_dram,
            [&]() {
                if constexpr(FmhaPipeline::kKTLoadOnce)
                    return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
                                      number<FmhaPipeline::kN0>{});
                else
                    return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
                                      number<FmhaPipeline::kK4>{});
            }(),
            {0, i_n0});

        auto v_dram_window = make_tile_window(
            v_dram,
            [&]() {
                if constexpr(FmhaPipeline::kVLoadOnce)
                    return make_tuple(number<FmhaPipeline::kN0>{},
                                      number<FmhaPipeline::kVHeaddim>{});
                else
                    return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK2>{});
            }(),
            {i_n0, 0});

        auto do_dram_window = make_tile_window(
            do_dram,
            [&]() {
                if constexpr(FmhaPipeline::kOGradLoadOnce)
                    return make_tuple(number<FmhaPipeline::kM0>{},
                                      number<FmhaPipeline::kVHeaddim>{});
                else
                    return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK2>{});
            }(),
            {0, 0});

        auto dot_dram_window = make_tile_window(
            dot_dram,
            [&]() {
                if constexpr(FmhaPipeline::kOGradTLoadOnce)
                    return make_tuple(number<FmhaPipeline::kVHeaddim>{},
                                      number<FmhaPipeline::kM0>{});
                else
                    return make_tuple(number<FmhaPipeline::kVHeaddim>{},
                                      number<FmhaPipeline::kK1>{});
            }(),
            {0, 0});

        auto dq_dram_window = make_tile_window(
            dq_dram,
            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
            {0, 0});

        auto lse_dram_window =
            make_tile_window(lse_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});

        auto d_dram_window =
            make_tile_window(d_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});
        /// FIXME: Before C++20, capturing structured binding variables is not supported. Remove
        /// the following copy capture of 'i_nhead' once C++20 is available.
        constexpr auto bias_dram_window_lengths =
            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
        const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                const BiasDataType* bias_ptr =
                    reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
                    batch_offset_bias;

                const auto bias_dram = [&]() {
                    const auto bias_dram_naive =
                        make_naive_tensor_view<address_space_enum::global>(
                            bias_ptr,
                            make_tuple(kargs.seqlen_q, kargs.seqlen_k),
                            make_tuple(kargs.stride_bias, 1),
                            number<FmhaPipeline::kAlignmentBias>{},
                            number<1>{});

                    return pad_tensor_view(bias_dram_naive,
                                           bias_dram_window_lengths,
                                           sequence<kPadSeqLenQ, kPadSeqLenK>{});
                }();

                return make_tile_window(bias_dram, bias_dram_window_lengths, {0, i_n0});
            }
            else
            {
                return make_null_tile_window(bias_dram_window_lengths);
            }
        }();

        auto dbias_dram_window = [&, i_nhead_ = i_nhead]() {
            if constexpr(kHasBiasGrad)
            {
                BiasGradDataType* dbias_ptr =
                    reinterpret_cast<BiasGradDataType*>(kargs.dbias_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dbias +
                    batch_offset_dbias;

                auto dbias_dram = [&]() {
                    const auto dbias_dram_naive =
                        make_naive_tensor_view<address_space_enum::global>(
                            dbias_ptr,
                            make_tuple(kargs.seqlen_q, kargs.seqlen_k),
                            make_tuple(kargs.stride_dbias, 1),
                            number<FmhaPipeline::kAlignmentBias>{},
                            number<1>{});

                    return pad_tensor_view(dbias_dram_naive,
                                           bias_dram_window_lengths,
                                           sequence<kPadSeqLenQ, kPadSeqLenK>{});
                }();

                return make_tile_window(dbias_dram, bias_dram_window_lengths, {0, i_n0});
            }
            else
            {
                return make_null_tile_window(bias_dram_window_lengths);
            }
        }();
        // WA: i_batch/i_nhead are structured bindings, which cannot be captured before C++20
        auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                // data loading, shared by entire wg
                // TODO: how to use s_read?
                AccDataType slope =
                    *(reinterpret_cast<const AccDataType*>(kargs.alibi_slope_ptr) +
                      i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                slope *= ck_tile::log2e_v<>;
#endif
                if constexpr(kHasMask)
                {
                    return make_alibi_from_lr_mask<AccDataType, false>(slope,
                                                                       kargs.window_size_left,
                                                                       kargs.window_size_right,
                                                                       kargs.seqlen_q,
                                                                       kargs.seqlen_k,
                                                                       kargs.mask_type);
                }
                else
                {
                    return Alibi<AccDataType, false>{
                        slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
                }
            }
            else
            {
                return EmptyPositionEncoding<AccDataType>{};
            }
        }();
        // dropout
        float rp_undrop             = 1;
        float scale_rp_undrop       = 1;
        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
        uint64_t drop_seed          = 0;
        uint64_t drop_offset        = 0;
        bool is_store_randval       = false;

        if constexpr(kHasDropout)
        {
            rp_undrop           = kargs.rp_undrop;
            scale_rp_undrop     = kargs.scale_rp_undrop;
            p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
            drop_seed           = kargs.drop_seed;
            drop_offset         = kargs.drop_offset;
            is_store_randval    = kargs.is_store_randval;
        }
        BlockDropout dropout(i_batch,
                             i_nhead,
                             kargs.num_head_q,
                             drop_seed,
                             drop_offset,
                             rp_undrop,
                             p_undrop_in_uint8_t,
                             is_store_randval);

        auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto randval_dram_window_lengths =
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
            if constexpr(kHasDropout)
            {
                RandValOutputDataType* rand_val_ptr =
                    reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
                    batch_offset_randval;

                const auto randval_dram = [&]() {
                    const auto randval_dram_naive =
                        make_naive_tensor_view<address_space_enum::global>(
                            rand_val_ptr,
                            make_tuple(kargs.seqlen_q, kargs.seqlen_k),
                            make_tuple(kargs.stride_randval, 1),
                            number<1>{},
                            number<1>{});

                    return pad_tensor_view(randval_dram_naive,
                                           randval_dram_window_lengths,
                                           sequence<kPadSeqLenQ, kPadSeqLenK>{});
                }();

                return make_tile_window(randval_dram, randval_dram_window_lengths, {0, i_n0});
            }
            else
            {
                return make_null_tile_window(randval_dram_window_lengths);
            }
        }();
        FmhaMask mask = [&]() {
            if constexpr(kHasMask)
                return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
                    kargs.window_size_left,
                    kargs.window_size_right,
                    kargs.seqlen_q,
                    kargs.seqlen_k,
                    kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
            else
                return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
        }();

        auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
                                                         qt_dram_window,
                                                         k_dram_window,
                                                         kt_dram_window,
                                                         v_dram_window,
                                                         bias_dram_window,
                                                         randval_dram_window,
                                                         do_dram_window,
                                                         dot_dram_window,
                                                         lse_dram_window,
                                                         d_dram_window,
                                                         dq_dram_window,
                                                         dbias_dram_window,
                                                         mask,
                                                         position_encoding,
                                                         kargs.raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
                                                         kargs.scale,
#endif
                                                         rp_undrop,
                                                         scale_rp_undrop,
                                                         smem_ptr,
                                                         dropout);
        auto dk_dram = [&]() {
            const auto dk_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                dk_ptr,
                make_tuple(kargs.seqlen_k, kargs.hdim_q),
                make_tuple(kargs.stride_dk, 1),
                number<FmhaPipeline::kAlignmentKGrad>{},
                number<1>{});

            return pad_tensor_view(
                dk_dram_naive,
                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
                sequence<kPadSeqLenK, kPadHeadDimQ>{});
        }();
        auto dv_dram = [&]() {
            const auto dv_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                dv_ptr,
                make_tuple(kargs.seqlen_k, kargs.hdim_v),
                make_tuple(kargs.stride_dv, 1),
                number<FmhaPipeline::kAlignmentVGrad>{},
                number<1>{});

            return pad_tensor_view(
                dv_dram_naive,
                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
                sequence<kPadSeqLenK, kPadHeadDimV>{});
        }();

        auto dk_dram_window = make_tile_window(
            dk_dram,
            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
            {i_n0, 0});

        auto dv_dram_window = make_tile_window(
            dv_dram,
            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
            {i_n0, 0});

        KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
        VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile);
    }
};
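In group (varlen) mode the kernel never receives per-batch sequence lengths directly; it recovers them from the cumulative seqstart arrays, exactly as in operator() above (seqlen_q = seqstart_q[b+1] - seqstart_q[b], and seqlen_k either from the optional seqlen_k pointer or from seqstart_k). A small host-side sketch of that convention, with made-up lengths:

// Host-side sketch of the group-mode seqstart convention used by the kernel:
// seqstart is the exclusive prefix sum of per-batch sequence lengths.
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    const std::vector<int32_t> seqlens_q = {37, 128, 5}; // illustrative only
    std::vector<int32_t> seqstart_q(seqlens_q.size() + 1, 0);
    std::partial_sum(seqlens_q.begin(), seqlens_q.end(), seqstart_q.begin() + 1);

    for(std::size_t b = 0; b < seqlens_q.size(); ++b)
    {
        // this is what each workgroup recovers on the device:
        const int32_t seqlen_q    = seqstart_q[b + 1] - seqstart_q[b];
        // token offset of this batch inside the packed Q/dO/LSE buffers
        const int32_t query_start = seqstart_q[b];
        std::printf("batch %zu: start=%d seqlen_q=%d\n", b, query_start, seqlen_q);
    }
    return 0;
}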
template <typename TilePartitioner_, typename FmhaBwdOGradDotO_>
struct FmhaBwdOGradDotOKernel
{
    using TilePartitioner  = ck_tile::remove_cvref_t<TilePartitioner_>;
    using FmhaBwdOGradDotO = ck_tile::remove_cvref_t<FmhaBwdOGradDotO_>;

    static constexpr ck_tile::index_t kBlockSize  = FmhaBwdOGradDotO::kBlockSize;
    static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu;
    static constexpr ck_tile::index_t kM0         = kBlockSize;
    static constexpr ck_tile::index_t kVHeaddim   = FmhaBwdOGradDotO::kVHeaddim;

    using DDataType     = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::DDataType>;
    using ODataType     = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::ODataType>;
    using OGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::OGradDataType>;

    static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode;
    static constexpr bool kPadSeqLenQ  = FmhaBwdOGradDotO::kPadSeqLenQ;
    static constexpr bool kPadHeadDimV = FmhaBwdOGradDotO::kPadHeadDimV;

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    // clang-format on

    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        #define _SS_ std::string
        #define _TS_ std::to_string
        auto pn = [&] () {
            std::string n;
            if (kPadSeqLenQ) n += "s";
            if (kPadHeadDimV) n += "dv";
            return n.empty() ? n : std::string("p") + n; }();
        return
            _SS_("fmha_bwd_dot_do_o_d") + _TS_(kVHeaddim) + "_" + _SS_(t2s<ODataType>::name) +
            "_" + (kIsGroupMode ? "group" : "batch") + "_" +
            ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "" : "_" + pn);
        #undef _SS_
        #undef _TS_
        // clang-format on
    }
    // kargs use aggregate initialization, so no constructor will be provided;
    // use inheritance to minimize karg size.
    // Users need to use the MakeKargs() function to create kargs.
    struct FmhaBwdOGradDotOCommonKargs
    {
        const void* o_ptr;
        const void* do_ptr;
        void* d_ptr;

        float p_undrop;

        ck_tile::index_t seqlen_q;
        ck_tile::index_t hdim_v;

        ck_tile::index_t stride_do;
        ck_tile::index_t stride_o;

        ck_tile::index_t nhead_stride_do;
        ck_tile::index_t nhead_stride_o;
        ck_tile::index_t nhead_stride_d;

        ck_tile::index_t batch_stride_d;
    };

    struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs
    {
        ck_tile::index_t batch_stride_do;
        ck_tile::index_t batch_stride_o;
    };

    struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
    {
        const int32_t* seqstart_q_ptr;
    };

    using Kargs = std::conditional_t<kIsGroupMode,
                                     FmhaBwdOGradDotOGroupModeKargs,
                                     FmhaBwdOGradDotOBatchModeKargs>;

    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* o_ptr,
              const void* do_ptr,
              void* d_ptr,
              float p_undrop,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t stride_do,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_do,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t nhead_stride_d,
              ck_tile::index_t batch_stride_do,
              ck_tile::index_t batch_stride_o,
              ck_tile::index_t batch_stride_d)
    {
        Kargs kargs{{o_ptr,
                     do_ptr,
                     d_ptr,
                     p_undrop,
                     seqlen_q,
                     hdim_v,
                     stride_do,
                     stride_o,
                     nhead_stride_do,
                     nhead_stride_o,
                     nhead_stride_d,
                     batch_stride_d},
                    batch_stride_do,
                    batch_stride_o};
        return kargs;
    }
    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* o_ptr,
              const void* do_ptr,
              void* d_ptr,
              float p_undrop,
              const void* seqstart_q_ptr,
              ck_tile::index_t hdim_v,
              ck_tile::index_t stride_do,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_do,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t nhead_stride_d,
              ck_tile::index_t batch_stride_d)
    {
        Kargs kargs{{o_ptr,
                     do_ptr,
                     d_ptr,
                     p_undrop,
                     -1, // seqlen will be updated by another pointer
                     hdim_v,
                     stride_do,
                     stride_o,
                     nhead_stride_do,
                     nhead_stride_o,
                     nhead_stride_d,
                     batch_stride_d},
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
        return kargs;
    }

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_)
    {
        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // divide problem
        const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q);

        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);

        long_index_t batch_offset_o  = 0;
        long_index_t batch_offset_do = 0;
        long_index_t batch_offset_d  = 0;

        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];

            batch_offset_o  = query_start * kargs.stride_o;
            batch_offset_do = query_start * kargs.stride_do;
            batch_offset_d  = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;

            // get real # queries & # keys under group mode
            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];

            // # of required blocks is different in each group, terminate unnecessary blocks
            // earlier
            if(kargs.seqlen_q <= i_m0)
            {
                return;
            }
        }
        else
        {
            batch_offset_o  = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
            batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
            batch_offset_d  = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;
        }

        // for simplicity, we handle the batch stride by just offsetting the pointers
        const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr) +
                                 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
                                 batch_offset_o;
        const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
                                      static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
                                      batch_offset_do;
        DDataType* d_ptr = reinterpret_cast<DDataType*>(kargs.d_ptr) +
                           static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_d +
                           batch_offset_d;

        // O/dO/D DRAM and DRAM window
        const auto o_dram = [&]() {
            auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                o_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_v),
                make_tuple(kargs.stride_o, 1),
                number<FmhaBwdOGradDotO::kAlignmentO>{},
                number<1>{});
            return pad_tensor_view(o_dram_naive,
                                   make_tuple(number<kM0>{}, number<kVHeaddim>{}),
                                   sequence<kPadSeqLenQ, kPadHeadDimV>{});
        }();
        const auto do_dram = [&]() {
            auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                do_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_v),
                make_tuple(kargs.stride_do, 1),
                number<FmhaBwdOGradDotO::kAlignmentOGrad>{},
                number<1>{});
            return pad_tensor_view(do_dram_naive,
                                   make_tuple(number<kM0>{}, number<kVHeaddim>{}),
                                   sequence<kPadSeqLenQ, kPadHeadDimV>{});
        }();
        auto d_dram = [&]() {
            const auto d_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
                d_ptr, make_tuple(kargs.seqlen_q), number<1>{});
            return pad_tensor_view(
                d_dram_naive, make_tuple(number<kM0>{}), sequence<kPadSeqLenQ>{});
        }();

        auto o_dram_window =
            make_tile_window(o_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});

        auto do_dram_window =
            make_tile_window(do_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});

        auto d_dram_window = make_tile_window(d_dram, make_tuple(number<kM0>{}), {i_m0});

        FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop);
    }
};
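The pipeline invoked at the end, FmhaBwdOGradDotO, produces the per-row term D that the main backward kernel later reads through d_ptr. The authoritative definition lives in that pipeline; the quantity itself is the row-wise dot product of O and dO, with the p_undrop factor passed along for scaling. A scalar reference sketch under that assumption (shapes, values, and the placement of the p_undrop factor are illustrative only):

// Scalar reference for the dot(dO, O) preprocessing step.
#include <cstdio>
#include <vector>

int main()
{
    const int seqlen_q   = 4, hdim_v = 8;
    const float p_undrop = 0.9f;

    std::vector<float> o(seqlen_q * hdim_v, 0.5f);   // illustrative O data
    std::vector<float> do_(seqlen_q * hdim_v, 0.1f); // illustrative dO data
    std::vector<float> d(seqlen_q, 0.0f);

    for(int m = 0; m < seqlen_q; ++m)
    {
        float acc = 0.0f;
        for(int v = 0; v < hdim_v; ++v)
            acc += do_[m * hdim_v + v] * o[m * hdim_v + v]; // row-wise dot product
        d[m] = acc * p_undrop; // assumed scaling by the keep-probability
    }

    for(int m = 0; m < seqlen_q; ++m)
        std::printf("D[%d] = %f\n", m, d[m]);
    return 0;
}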
} // namespace ck_tile
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp
0 → 100644
View file @
4947639c
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename BlockFmhaShape_>
struct FmhaBwdTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_k_)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }
};
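GridSize above assigns one workgroup per kN0-wide slab of keys, with heads on gridDim.y and batches on gridDim.z; operator() simply reads that mapping back from blockIdx. A host-side sketch of the resulting launch geometry (the concrete batch/nhead/seqlen_k/kN0 values are made up):

// Sketch of the grid geometry FmhaBwdTilePartitioner::GridSize produces.
#include <cstdio>

int main()
{
    const int batch = 2, nhead = 16, seqlen_k = 4096, kN0 = 128;

    const int grid_x = (seqlen_k + kN0 - 1) / kN0; // integer_divide_ceil(seqlen_k, kN0)
    const int grid_y = nhead;
    const int grid_z = batch;

    // each (x, y, z) block handles keys [x * kN0, (x + 1) * kN0) of head y in batch z
    std::printf("grid = (%d, %d, %d), blocks total = %d\n",
                grid_x, grid_y, grid_z, grid_x * grid_y * grid_z);
    return 0;
}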
template <ck_tile::index_t kBlockSize>
struct FmhaBwdOGradDotOTilePartitioner
{
    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }
};

} // namespace ck_tile