Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ea5be216
Commit
ea5be216
authored
Aug 23, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
e2eb0418
25935b57
Changes
168
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1557 additions
and
684 deletions
+1557
-684
include/ck/host_utility/device_prop.hpp
include/ck/host_utility/device_prop.hpp
+6
-0
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
...ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
+43
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp
...pu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp
+282
-99
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
...operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
+142
-162
include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp
...tensor_operation/gpu/device/impl/device_reduce_common.hpp
+3
-3
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp
...or_operation/gpu/device/impl/device_reduce_multiblock.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp
...or_operation/gpu/device/impl/device_reduce_threadwise.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp
...tion/gpu/device/impl/device_reduce_threadwise_multi_d.hpp
+1
-1
include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp
...operation/gpu/element/combined_element_wise_operation.hpp
+3
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+15
-3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
+18
-10
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+47
-8
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+4
-2
include/ck_tile/core/algorithm/coordinate_transform.hpp
include/ck_tile/core/algorithm/coordinate_transform.hpp
+9
-33
include/ck_tile/core/numeric/vector_type.hpp
include/ck_tile/core/numeric/vector_type.hpp
+9
-0
include/ck_tile/core/tensor/tile_distribution.hpp
include/ck_tile/core/tensor/tile_distribution.hpp
+3
-2
include/ck_tile/core/utility/philox_rand.hpp
include/ck_tile/core/utility/philox_rand.hpp
+33
-0
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+3
-8
include/ck_tile/ops/fmha/block/block_dropout.hpp
include/ck_tile/ops/fmha/block/block_dropout.hpp
+377
-16
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+555
-332
No files found.
include/ck/host_utility/device_prop.hpp
View file @
ea5be216
...
...
@@ -65,6 +65,12 @@ inline bool is_lds_direct_load_supported()
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
;
}
inline
bool
is_bf16_atomic_supported
()
{
return
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
;
}
inline
bool
is_gfx101_supported
()
{
return
ck
::
get_device_name
()
==
"gfx1010"
||
ck
::
get_device_name
()
==
"gfx1011"
||
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
View file @
ea5be216
...
...
@@ -53,6 +53,49 @@ struct DeviceGemmMultipleD : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
// GEMM:
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template
<
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
struct
DeviceGemmMultipleDSplitK
:
public
BaseOperator
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
ck
::
index_t
StrideE
,
ck
::
index_t
KBatch
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp
View file @
ea5be216
...
...
@@ -69,7 +69,7 @@ template <typename ALayout,
typename
ComputeTypeB
=
ComputeTypeA
,
typename
LDSTypeA
=
ComputeTypeA
,
typename
LDSTypeB
=
ComputeTypeB
>
struct
DeviceGemmMultiD_Xdl_CShuffle_V3
:
public
DeviceGemmMultipleD
<
ALayout
,
struct
DeviceGemmMultiD_Xdl_CShuffle_V3
:
public
DeviceGemmMultipleD
SplitK
<
ALayout
,
BLayout
,
DsLayout
,
CLayout
,
...
...
@@ -192,15 +192,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// rotating mem
rotating_mem
.
Next
();
// clear c mem
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
if
(
arg_
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
arg_
.
p_c_grid
,
hipGetErrorString
(
hipMemsetAsync
(
arg_
.
p_c_grid
,
0
,
arg_
.
M
*
arg_
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
}
};
ave_time
=
ck
::
utility
::
launch_and_time_kernel_with_preprocess
<
false
>
(
...
...
@@ -234,9 +230,19 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
||
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
)
{
if
(
arg
.
KBatch
>
1
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
kernel_gemm_xdl_cshuffle_v3
_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
...
...
@@ -246,11 +252,124 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number could be One to Seven
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v2
)
{
if
(
arg
.
KBatch
>
1
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
One
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
One
>
;
Run
(
kernel
);
}
else
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Full
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Full
>
;
Run
(
kernel
);
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
2
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Two
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Two
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
3
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Three
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Three
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
4
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Four
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Four
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
5
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Five
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Five
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
6
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Six
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Six
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
7
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Seven
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Seven
>
;
Run
(
kernel
);
}
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
One
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
kernel_gemm_xdl_cshuffle_v3
_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -261,7 +380,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
TailNumber
::
Full
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
kernel_gemm_xdl_cshuffle_v3
_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -273,8 +392,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Two
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -288,8 +407,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Three
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -303,8 +422,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Four
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -318,8 +437,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Five
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -332,8 +451,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Six
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -347,8 +466,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Seven
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -361,12 +480,35 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number could be Odd or Even
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
{
if
(
arg
.
KBatch
>
1
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -375,8 +517,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -387,11 +529,35 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
}
else
{
if
(
arg
.
KBatch
>
1
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
kernel_gemm_xdl_cshuffle_v3
_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -401,7 +567,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
kernel_gemm_xdl_cshuffle_v3
_multi_d
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -416,9 +582,19 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number always 1
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
{
if
(
arg
.
KBatch
>
1
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
kernel_gemm_xdl_cshuffle_v3
_multi_d
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
...
...
@@ -451,6 +627,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
return
false
;
}
if
(
!
is_bf16_atomic_supported
()
&&
std
::
is_same_v
<
CDataType
,
ck
::
bhalf_t
>
&&
arg
.
KBatch
>
1
)
{
return
false
;
}
if
((
arg
.
K
%
AK1
!=
0
||
arg
.
K
%
BK1
!=
0
)
&&
!
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
||
...
...
@@ -479,6 +660,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
index_t
StrideB
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideC
,
index_t
KBatch
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
...
...
@@ -494,7 +676,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
StrideB
,
StrideDs
,
StrideC
,
1
,
KBatch
,
a_element_op
,
b_element_op
,
c_element_op
};
...
...
@@ -514,6 +696,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
index_t
StrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideC
,
index_t
KBatch
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
override
...
...
@@ -529,7 +712,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
StrideB
,
StrideDs
,
StrideC
,
1
,
KBatch
,
a_element_op
,
b_element_op
,
c_element_op
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
View file @
ea5be216
...
...
@@ -168,15 +168,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// rotating mem
rotating_mem
.
Next
();
// clear c mem
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
if
(
arg_
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
arg_
.
p_c_grid
,
hipGetErrorString
(
hipMemsetAsync
(
arg_
.
p_c_grid
,
0
,
arg_
.
M
*
arg_
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
}
};
ave_time
=
ck
::
utility
::
launch_and_time_kernel_with_preprocess
<
false
>
(
...
...
@@ -189,15 +185,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
arg_
);
}
else
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
if
(
arg
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
arg
.
p_c_grid
,
0
,
arg
.
M
*
arg
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
}
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
arg
);
...
...
@@ -214,8 +207,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
)
{
if
(
arg
.
KBatch
>
1
)
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
...
...
@@ -224,7 +215,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
minimum_occupancy
>
;
Run
(
kernel
);
}
}
else
{
const
auto
kernel
=
...
...
@@ -239,13 +229,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v2
)
{
if
(
arg
.
KBatch
>
1
)
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
One
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
...
...
@@ -255,8 +243,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
else
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Full
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
...
...
@@ -266,8 +254,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
2
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Two
)
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Two
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
...
...
@@ -326,8 +313,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
6
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Six
)
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Six
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
...
...
@@ -354,7 +340,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
}
}
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
One
)
...
...
@@ -472,8 +457,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
{
if
(
arg
.
KBatch
>
1
)
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
...
...
@@ -496,7 +479,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
Run
(
kernel
);
}
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
...
...
@@ -524,13 +506,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
else
{
if
(
arg
.
KBatch
>
1
)
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
...
...
@@ -539,8 +519,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
...
...
@@ -548,7 +528,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
Run
(
kernel
);
}
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
...
...
@@ -579,10 +558,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// Tail number always 1
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
{
if
(
arg
.
KBatch
>
1
)
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
...
...
@@ -591,7 +567,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
minimum_occupancy
>
;
Run
(
kernel
);
}
}
else
{
const
auto
kernel
=
...
...
@@ -628,6 +603,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return
false
;
}
if
(
!
is_bf16_atomic_supported
()
&&
std
::
is_same_v
<
CDataType
,
ck
::
bhalf_t
>
&&
arg
.
KBatch
>
1
)
{
return
false
;
}
if
((
arg
.
K
%
AK1
!=
0
||
arg
.
K
%
BK1
!=
0
)
&&
!
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
||
...
...
include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp
View file @
ea5be216
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -19,7 +19,7 @@ namespace device {
template
<
index_t
Rank
,
int
NumReduceDim
>
std
::
pair
<
long_index_t
,
long_index_t
>
get_2d_lengths
(
const
std
::
vector
<
index_t
>&
inLengths
)
{
static_assert
(
Rank
<=
6
,
"bigger Rank size not supported!"
);
static_assert
(
Rank
<=
12
,
"bigger Rank size not supported!"
);
long_index_t
invariant_total_length
=
1
;
long_index_t
reduce_total_length
=
1
;
...
...
@@ -38,7 +38,7 @@ std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>&
template
<
index_t
Rank
,
int
NumReduceDim
>
std
::
pair
<
long_index_t
,
long_index_t
>
get_2d_lengths
(
const
std
::
array
<
index_t
,
Rank
>&
inLengths
)
{
static_assert
(
Rank
<=
6
,
"bigger Rank size not supported!"
);
static_assert
(
Rank
<=
12
,
"bigger Rank size not supported!"
);
long_index_t
invariant_total_length
=
1
;
long_index_t
reduce_total_length
=
1
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp
View file @
ea5be216
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -51,7 +51,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
PropagateNan
,
OutputIndex
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
Rank
<=
12
,
"Bigger Rank size is not supported!"
);
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
,
"Invalid thread cluster size assignments!"
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp
View file @
ea5be216
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -47,7 +47,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
OutputIndex
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
Rank
<=
12
,
"Bigger Rank size is not supported!"
);
static_assert
(((
InSrcVectorDim
==
0
&&
MThreadSliceSize
%
InSrcVectorSize
==
0
)
||
(
InSrcVectorDim
==
1
&&
KThreadSliceSize
%
InSrcVectorSize
==
0
))
&&
...
...
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp
View file @
ea5be216
...
...
@@ -45,7 +45,7 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduceMultiD<InDataType,
OutElementwiseOperation
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
Rank
<=
12
,
"Bigger Rank size is not supported!"
);
static_assert
(((
InSrcVectorDim
==
0
&&
MThreadSliceSize
%
InSrcVectorSize
==
0
)
||
(
InSrcVectorDim
==
1
&&
KThreadSliceSize
%
InSrcVectorSize
==
0
))
&&
...
...
include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp
View file @
ea5be216
...
...
@@ -3,7 +3,6 @@
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
...
...
@@ -107,6 +106,9 @@ struct TrinaryWithUnaryCombinedOp
UnaryOp2
unary_op2_
{};
};
using
ScaleScalePass
=
UnaryCombinedOp
<
Scale
,
Scale
,
PassThrough
>
;
using
ScaleScaleRelu
=
UnaryCombinedOp
<
Scale
,
Scale
,
Relu
>
;
}
// namespace element_wise
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
ea5be216
...
...
@@ -417,6 +417,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
}();
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
...
...
@@ -454,6 +461,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
// not pad M or N
return c_grid_desc_mraw_nraw;
}
#endif
}
struct
Problem
...
...
@@ -953,7 +961,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
&&
!
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
))
{
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
{
...
...
@@ -970,7 +979,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
&&
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
))
{
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
{
...
...
@@ -1105,7 +1115,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
if
constexpr
(
!
(
is_same
<
remove_cvref_t
<
CDataType
>
,
half_t
>::
value
||
is_same
<
remove_cvref_t
<
CDataType
>
,
float
>::
value
))
is_same
<
remove_cvref_t
<
CDataType
>
,
float
>::
value
||
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
||
is_same
<
remove_cvref_t
<
CDataType
>
,
int32_t
>::
value
))
{
if
(
!
karg
.
IsReduceAdd
())
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
View file @
ea5be216
...
...
@@ -36,10 +36,9 @@ __global__ void
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3
(
typename
GridwiseGemm
::
Argument
karg
)
kernel_gemm_xdl_cshuffle_v3
_multi_d
(
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
);
...
...
@@ -56,7 +55,7 @@ __global__ void
karg
.
c_element_op
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx9
08__) || defined(__gfx90a
__))
#endif // end of if (defined(__gfx9__))
}
template
<
typename
GridwiseGemm
,
...
...
@@ -69,10 +68,9 @@ __global__ void
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_2lds
(
typename
GridwiseGemm
::
Argument
karg
)
kernel_gemm_xdl_cshuffle_v3_
multi_d_
2lds
(
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
__shared__
char
p_shared_0
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -93,7 +91,7 @@ __global__ void
karg
.
c_element_op
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx9
08__) || defined(__gfx90a
__))
#endif // end of if (defined(__gfx9__))
}
template
<
typename
ALayout
,
...
...
@@ -454,6 +452,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
}
}();
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
...
...
@@ -491,6 +496,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
// not pad M or N
return c_grid_desc_mraw_nraw;
}
#endif
}
__host__
__device__
static
auto
MakeDsGridDescriptor_M_N
(
...
...
@@ -1016,7 +1022,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
&&
!
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
))
{
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
{
...
...
@@ -1033,7 +1040,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
&&
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
))
{
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
{
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
ea5be216
...
...
@@ -562,6 +562,34 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_addr_offset
);
}
template
<
typename
T
,
index_t
N
>
__device__
void
amd_global_atomic_add_impl
(
const
typename
vector_type
<
T
,
N
>::
type
src_thread_data
,
T
*
addr
)
{
static_assert
((
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
2
||
N
==
4
||
N
==
8
)),
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
{
vector_type
<
half_t
,
N
>
tmp
{
src_thread_data
};
static_for
<
0
,
N
/
2
,
1
>
{}([
&
](
auto
i
)
{
__builtin_amdgcn_global_atomic_fadd_v2f16
(
bit_cast
<
half2_t
*>
(
addr
)
+
i
,
tmp
.
template
AsType
<
half2_t
>()[
i
]);
});
}
#if defined(__gfx942__)
else
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
{
vector_type
<
bhalf_t
,
N
>
tmp
{
src_thread_data
};
static_for
<
0
,
N
/
2
,
1
>
{}([
&
](
auto
i
)
{
__builtin_amdgcn_global_atomic_fadd_v2bf16
(
bit_cast
<
bhalf2_t
*>
(
addr
)
+
i
,
tmp
.
template
AsType
<
bhalf2_t
>()[
i
]);
});
}
#endif
}
template
<
typename
T
,
index_t
N
>
__device__
void
amd_buffer_atomic_add_impl
(
const
typename
vector_type
<
T
,
N
>::
type
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
...
...
@@ -907,6 +935,16 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
using
scalar_t
=
typename
scalar_type
<
vector_t
>::
type
;
constexpr
index_t
vector_size
=
scalar_type
<
vector_t
>::
vector_size
;
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
{
if
(
dst_thread_element_valid
)
{
amd_global_atomic_add_impl
<
scalar_t
,
vector_size
>
(
src_thread_data
,
p_dst_wave
+
dst_thread_element_offset
);
}
}
else
{
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x80000000
;
...
...
@@ -919,6 +957,7 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
}
#endif
}
}
// buffer_atomic_max requires:
...
...
include/ck/utility/dynamic_buffer.hpp
View file @
ea5be216
...
...
@@ -358,13 +358,15 @@ struct DynamicBuffer
bool
constexpr
use_amd_buffer_addressing
=
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
int32_t
>
||
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
float
>
||
(
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
half_t
>
&&
scalar_per_x_vector
%
2
==
0
);
(
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
half_t
>
&&
scalar_per_x_vector
%
2
==
0
)
||
(
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
bhalf_t
>
&&
scalar_per_x_vector
%
2
==
0
);
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
bool
constexpr
use_amd_buffer_addressing
=
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
int32_t
>
;
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool
constexpr
use_amd_buffer_addressing
=
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
float
>
||
(
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
half_t
>
&&
scalar_per_x_vector
%
2
==
0
);
(
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
half_t
>
&&
scalar_per_x_vector
%
2
==
0
)
||
(
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
bhalf_t
>
&&
scalar_per_x_vector
%
2
==
0
);
#else
bool
constexpr
use_amd_buffer_addressing
=
false
;
#endif
...
...
include/ck_tile/core/algorithm/coordinate_transform.hpp
View file @
ea5be216
...
...
@@ -1341,7 +1341,7 @@ struct modulo : public base_transform<1, 1>
};
// 2D XOR, NOTE: "xor" is a keyword
template
<
typename
LowLengths
,
typename
RightShift
>
template
<
typename
LowLengths
>
struct
xor_t
:
public
base_transform
<
2
,
2
>
{
static
constexpr
auto
type_enum
=
coord_transform_enum
::
xor_t
;
...
...
@@ -1352,15 +1352,10 @@ struct xor_t : public base_transform<2, 2>
using
UpLengths
=
LowLengths
;
UpLengths
up_lengths_
;
RightShift
right_shift_
;
CK_TILE_HOST_DEVICE
constexpr
xor_t
()
:
up_lengths_
{}
,
right_shift_
{}
{}
CK_TILE_HOST_DEVICE
constexpr
xor_t
()
:
up_lengths_
{}
{}
CK_TILE_HOST_DEVICE
constexpr
xor_t
(
const
LowLengths
&
low_lengths
,
const
RightShift
&
right_shift
)
:
up_lengths_
{
low_lengths
},
right_shift_
{
right_shift
}
{
}
CK_TILE_HOST_DEVICE
constexpr
xor_t
(
const
LowLengths
&
low_lengths
)
:
up_lengths_
{
low_lengths
}
{}
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_type_enum
()
{
...
...
@@ -1378,13 +1373,8 @@ struct xor_t : public base_transform<2, 2>
idx_low
(
number
<
0
>
{})
=
idx_up
[
number
<
0
>
{}];
const
auto
idx_low_1_tmp
=
(
idx_up
[
number
<
1
>
{}]
-
idx_up
[
number
<
0
>
{}]
*
right_shift_
)
%
up_lengths_
[
number
<
1
>
{}];
const
auto
idx_low_1
=
(
idx_low_1_tmp
>=
0
)
?
idx_low_1_tmp
:
up_lengths_
[
number
<
1
>
{}]
+
idx_low_1_tmp
;
idx_low
(
number
<
1
>
{})
=
idx_low_1
;
idx_low
(
number
<
1
>
{})
=
idx_up
[
number
<
1
>
{}]
^
(
idx_up
[
number
<
0
>
{}]
%
up_lengths_
[
number
<
1
>
{}]);
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
...
...
@@ -1419,8 +1409,7 @@ struct xor_t : public base_transform<2, 2>
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_known_at_compile_time
()
{
return
ck_tile
::
is_known_at_compile_time
<
UpLengths
>::
value
&&
ck_tile
::
is_known_at_compile_time
<
RightShift
>::
value
;
return
ck_tile
::
is_known_at_compile_time
<
UpLengths
>::
value
;
}
// MUST be static function
...
...
@@ -1432,14 +1421,6 @@ struct xor_t : public base_transform<2, 2>
array
<
index_t
,
2
>
up_vector_lengths
=
low_vector_lengths
;
array
<
index_t
,
2
>
up_vector_strides
=
low_vector_strides
;
if
constexpr
(
ck_tile
::
is_known_at_compile_time
<
RightShift
>::
value
)
{
if
(
low_vector_lengths
[
1
]
!=
-
1
)
{
up_vector_lengths
(
1
)
=
gcd
(
low_vector_lengths
[
1
],
abs
(
right_shift_
));
}
}
return
make_tuple
(
up_vector_lengths
,
up_vector_strides
);
}
...
...
@@ -1452,10 +1433,6 @@ struct xor_t : public base_transform<2, 2>
print
(
up_lengths_
);
printf
(
", "
);
//
printf
(
"right_shift_: "
);
print
(
right_shift_
);
printf
(
"}"
);
}
};
...
...
@@ -1655,11 +1632,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
return
modulo
<
Modulus
,
UpLength
>
{
modulus
,
up_length
};
}
template
<
typename
LowLengths
,
typename
RightShift
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_xor_transform
(
const
LowLengths
&
low_lengths
,
const
RightShift
&
right_shift
)
template
<
typename
LowLengths
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_xor_transform
(
const
LowLengths
&
low_lengths
)
{
return
xor_t
<
LowLengths
,
RightShift
>
{
low_lengths
,
right_shift
};
return
xor_t
<
LowLengths
>
{
low_lengths
};
}
template
<
typename
LowLength
,
typename
OffsetLength
>
...
...
include/ck_tile/core/numeric/vector_type.hpp
View file @
ea5be216
...
...
@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using
int32x32_t
=
int32_t
__attribute__
((
ext_vector_type
(
32
)));
using
int32x64_t
=
int32_t
__attribute__
((
ext_vector_type
(
64
)));
// u32
// using uint32_t = ...
using
uint32x2_t
=
uint32_t
__attribute__
((
ext_vector_type
(
2
)));
using
uint32x4_t
=
uint32_t
__attribute__
((
ext_vector_type
(
4
)));
using
uint32x8_t
=
uint32_t
__attribute__
((
ext_vector_type
(
8
)));
using
uint32x16_t
=
uint32_t
__attribute__
((
ext_vector_type
(
16
)));
using
uint32x32_t
=
uint32_t
__attribute__
((
ext_vector_type
(
32
)));
using
uint32x64_t
=
uint32_t
__attribute__
((
ext_vector_type
(
64
)));
// i16
// using int16_t = ...
using
int16x2_t
=
int16_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
include/ck_tile/core/tensor/tile_distribution.hpp
View file @
ea5be216
...
...
@@ -746,7 +746,8 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
return
make_tuple
(
make_static_tile_distribution
(
tile_distribution_encoding
<
typename
Encoding
::
RsLengths
,
decltype
(
sliced_h_lengths
),
// only need to change the
remove_cvref_t
<
decltype
(
sliced_h_lengths
)
>
,
// only need to
// change the
// h_lengths type
typename
Encoding
::
Ps2RHssMajor
,
typename
Encoding
::
Ps2RHssMinor
,
...
...
include/ck_tile/core/utility/philox_rand.hpp
View file @
ea5be216
...
...
@@ -53,6 +53,39 @@ class philox
out_tmp
[
3
]
=
tmp_ph
.
w
;
}
CK_TILE_HOST_DEVICE
void
get_random_8x8
(
uint8_t
*
out
,
const
unsigned
long
long
subsequence
,
const
index_t
start_idx
)
const
{
uint4
tmp_ph
;
tmp_ph
=
get_philox_4x32
(
subsequence
);
uint32x4_t
tmp
;
tmp
[
0
]
=
tmp_ph
.
x
;
tmp
[
1
]
=
tmp_ph
.
y
;
tmp
[
2
]
=
tmp_ph
.
z
;
tmp
[
3
]
=
tmp_ph
.
w
;
uint32_t
*
out_tmp
=
reinterpret_cast
<
uint32_t
*>
(
&
out
[
0
]);
out_tmp
[
0
]
=
tmp
[
start_idx
];
out_tmp
[
1
]
=
tmp
[
start_idx
+
2
];
}
CK_TILE_HOST_DEVICE
void
get_random_4x8
(
uint8_t
*
out
,
const
unsigned
long
long
subsequence
,
const
index_t
start_idx
)
const
{
uint4
tmp_ph
;
tmp_ph
=
get_philox_4x32
(
subsequence
);
uint32x4_t
tmp
;
tmp
[
0
]
=
tmp_ph
.
x
;
tmp
[
1
]
=
tmp_ph
.
y
;
tmp
[
2
]
=
tmp_ph
.
z
;
tmp
[
3
]
=
tmp_ph
.
w
;
uint32_t
*
out_tmp
=
reinterpret_cast
<
uint32_t
*>
(
&
out
[
0
]);
out_tmp
[
0
]
=
tmp
[
start_idx
];
}
private:
struct
ull2
{
...
...
include/ck_tile/ops/fmha.hpp
View file @
ea5be216
...
...
@@ -8,21 +8,16 @@
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
...
...
include/ck_tile/ops/fmha/block/block_dropout.hpp
View file @
ea5be216
...
...
@@ -286,11 +286,226 @@ struct BlockDropout
});
}
ck_tile
::
philox
ph
;
const
float
rp_undrop
;
const
uint8_t
p_undrop_in_uint8_t
;
const
bool
is_store_randval
;
};
template
<
bool
IsDropout_
,
bool
IsWG32_
,
bool
IsStoreRandval_
>
struct
BlockDropoutBwd
;
template
<
bool
IsWG32_
,
bool
IsStoreRandval_
>
struct
BlockDropoutBwd
<
false
,
IsWG32_
,
IsStoreRandval_
>
{
static
constexpr
bool
IsDropout
=
false
;
static
constexpr
bool
IsStoreRandval
=
IsStoreRandval_
;
template
<
typename
BlockGemm
,
bool
IsFwd
=
true
,
typename
RandValDramBlockWindowTmp
>
__host__
__device__
static
constexpr
auto
MakeRandvalDramWindow
(
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
index_t
seqlen_qk_start
)
{
(
void
)
randval_dram_block_window_tmp
;
(
void
)
seqlen_qk_start
;
return
make_null_tile_window
(
make_tuple
(
number
<
0
>
{},
number
<
0
>
{}));
}
};
template
<
bool
IsWG32_
,
bool
IsStoreRandval_
>
struct
BlockDropoutBwd
<
true
,
IsWG32_
,
IsStoreRandval_
>
{
static
constexpr
bool
IsDropout
=
true
;
// true: 32*32 warp gemm
// false: 16*16 warp gemm
static
constexpr
bool
IsWG32
=
IsWG32_
;
static
constexpr
bool
IsStoreRandval
=
IsStoreRandval_
;
CK_TILE_HOST_DEVICE
BlockDropoutBwd
(
index_t
i_batch
,
index_t
i_head
,
index_t
nheads
,
unsigned
long
long
seed
,
unsigned
long
long
offset
,
float
rp_undrop_
,
uint8_t
p_undrop_in_uint8_t_
)
:
ph
(
seed
,
offset
+
(
i_batch
*
nheads
+
i_head
)
*
get_warp_size
()
+
(
IsWG32
?
get_lane_id
()
:
((
get_lane_id
()
&
47
)
+
((
get_warp_id
()
&
1
)
<<
4
)))),
rp_undrop
(
rp_undrop_
),
p_undrop_in_uint8_t
(
p_undrop_in_uint8_t_
)
{
}
template
<
typename
BlockGemm
,
bool
IsFwd
=
true
,
typename
RandValDramBlockWindowTmp
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeRandvalDramWindow
(
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
index_t
seqlen_qk_start
)
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
using
BlockGemmShape
=
remove_cvref_t
<
typename
BlockGemm
::
BlockGemmShape
>
;
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
bool
MBwdWG16MultiIterCheck
=
(
!
IsFwd
)
&&
(
!
IsWG32
)
&&
(
kMPerBlock
>
16
);
constexpr
index_t
kMPerStep
=
[
&
]()
{
if
constexpr
(
MBwdWG16MultiIterCheck
)
{
return
MWarp
*
WG
::
kM
*
2
;
}
else
{
return
MWarp
*
WG
::
kM
;
}
}();
constexpr
index_t
kNPerStep
=
NWarp
*
WG
::
kN
;
const
auto
block_origin
=
randval_dram_block_window_tmp
.
get_window_origin
();
auto
randval_dram_window
=
[
&
]()
{
if
constexpr
(
IsFwd
)
{
return
make_tile_window
(
randval_dram_block_window_tmp
.
get_bottom_tensor_view
(),
ck_tile
::
make_tuple
(
number
<
kMPerStep
>
{},
number
<
kNPerStep
>
{}),
{
block_origin
.
at
(
number
<
0
>
{}),
seqlen_qk_start
});
// M/N
}
else
{
return
make_tile_window
(
randval_dram_block_window_tmp
.
get_bottom_tensor_view
(),
ck_tile
::
make_tuple
(
number
<
kMPerStep
>
{},
number
<
kNPerStep
>
{}),
{
seqlen_qk_start
,
block_origin
.
at
(
number
<
1
>
{})});
// M/N
}
}();
return
randval_dram_window
;
}
template
<
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeRandValLdsBlockDescriptor
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
kMPerStep
=
MWarp
*
WG
::
kM
;
constexpr
index_t
kNPerStep
=
WG
::
kN
;
constexpr
index_t
kN1
=
8
;
constexpr
index_t
kN0
=
kNPerStep
/
kN1
;
constexpr
auto
randval_lds_block_desc_0
=
make_naive_tensor_descriptor
(
ck_tile
::
make_tuple
(
number
<
kN0
>
{},
number
<
kMPerStep
>
{},
number
<
kN1
>
{}),
ck_tile
::
make_tuple
(
number
<
(
kMPerStep
+
1
)
*
kN1
>
{},
number
<
kN1
>
{},
number
<
1
>
{}),
number
<
kN1
>
{},
number
<
1
>
{});
constexpr
auto
randval_lds_block_desc
=
transform_tensor_descriptor
(
randval_lds_block_desc_0
,
ck_tile
::
make_tuple
(
make_pass_through_transform
(
number
<
kMPerStep
>
{}),
make_merge_transform
(
ck_tile
::
make_tuple
(
number
<
kN0
>
{},
number
<
kN1
>
{}))),
ck_tile
::
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
ck_tile
::
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
randval_lds_block_desc
;
}
template
<
typename
BlockGemm
,
bool
IsFwd
=
true
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeRandValTileDistribution
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
using
BlockGemmShape
=
remove_cvref_t
<
typename
BlockGemm
::
BlockGemmShape
>
;
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
bool
MBwdWG16MultiIterCheck
=
(
!
IsFwd
)
&&
(
!
IsWG32
)
&&
(
kMPerBlock
>
16
);
constexpr
index_t
MIterPerWarp
=
[
&
]()
{
if
constexpr
(
MBwdWG16MultiIterCheck
)
{
return
2
;
}
else
{
return
1
;
}
}();
constexpr
index_t
NIterPerWarp
=
1
;
constexpr
auto
randval_block_outer_part_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
// Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
// except headdim256.
constexpr
auto
randval_block_inner_part_dstr_encoding
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
BlockGemm
::
ADataType
,
half_t
>
&&
std
::
is_same_v
<
typename
BlockGemm
::
BDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
BlockGemm
::
CDataType
,
float
>
)
{
if
constexpr
(
IsWG32
)
return
typename
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
::
CWarpDstrEncoding
{};
else
return
typename
WarpGemmMfmaF16F16F32M16N16K16
::
CWarpDstrEncoding
{};
}
else
{
if
constexpr
(
IsWG32
)
return
typename
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
::
CWarpDstrEncoding
{};
else
return
typename
WarpGemmMfmaBf16Bf16F32M16N16K16
::
CWarpDstrEncoding
{};
}
}();
constexpr
auto
randval_block_part_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
randval_block_outer_part_dstr_encoding
,
randval_block_inner_part_dstr_encoding
);
return
make_static_tile_distribution
(
randval_block_part_dstr_encode
);
}
template
<
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeRandValLdsShuffleTileDistribution
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
1
;
constexpr
index_t
NIterPerWarp
=
1
;
constexpr
auto
randval_block_outer_part_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
randval_block_part_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
randval_block_outer_part_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
return
make_static_tile_distribution
(
randval_block_part_dstr_encode
);
}
template
<
typename
BlockGemm
,
typename
PComputeDataType
,
typename
RandValOutputDataType
,
typename
PComputeWindow
,
typename
RandValDramWindow
>
CK_TILE_HOST_DEVICE
void
Run
(
const
index_t
start_m0_idx
,
CK_TILE_HOST_DEVICE
void
Run
(
void
*
randval_ptr
,
const
index_t
start_m0_idx
,
const
index_t
start_n0_idx
,
PComputeWindow
&
p_compute
,
RandValDramWindow
&
randval_dram_window
)
const
{
...
...
@@ -305,30 +520,177 @@ struct BlockDropout
constexpr
index_t
kMPerStep
=
MWarp
*
WG
::
kM
;
constexpr
index_t
kNPerStep
=
NWarp
*
WG
::
kN
;
// randval tile in LDS
auto
randval_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
uint8_t
*>
(
randval_ptr
),
MakeRandValLdsBlockDescriptor
<
BlockGemm
>
());
auto
randval_lds_window
=
make_tile_window
(
randval_lds
,
MakeRandValLdsBlockDescriptor
<
BlockGemm
>
().
get_lengths
(),
{
0
,
0
});
// register distribute
auto
randval
=
auto
randval
_dist_generated
=
make_static_distributed_tensor
<
uint8_t
>
(
MakeRandValTileDistribution
<
BlockGemm
>
());
static_assert
(
randval_dist_generated
.
kThreadElementSpaceSize
==
16
);
auto
randval_lds_read_window
=
make_tile_window
(
randval_lds_window
.
get_bottom_tensor_view
(),
randval_lds_window
.
get_window_lengths
(),
randval_lds_window
.
get_window_origin
(),
MakeRandValLdsShuffleTileDistribution
<
BlockGemm
>
());
static_for
<
0
,
kMPerBlock
/
kMPerStep
,
1
>
{}([
&
](
auto
i_m0
)
{
static_for
<
0
,
kNPerBlock
/
kNPerStep
,
1
>
{}([
&
](
auto
i_n0
)
{
int
block_row_start
=
(
start_m0_idx
/
WG
::
kM
)
+
(
i_m0
*
MWarp
)
+
get_warp_id
();
int
block_col_start
=
(
start_n0_idx
/
WG
::
kN
)
+
i_n0
;
uint2
rowcol
=
make_uint2
(
block_row_start
,
block_col_start
);
// generate random number
uint8_t
random_uint8_t
[
16
];
ph
.
get_random_16x8
(
random_uint8_t
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
));
constexpr
auto
randval_dist_generated_spans
=
decltype
(
randval_dist_generated
)
::
get_distributed_spans
();
int
i_random_idx
=
0
;
sweep_tile_span
(
randval_dist_generated_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
randval_dist_generated_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
ck_tile
::
make_tuple
(
idx0
,
idx1
);
randval_dist_generated
(
i_j_idx
)
=
random_uint8_t
[
i_random_idx
++
];
});
});
// save to LDS
store_tile
(
randval_lds_window
,
randval_dist_generated
);
block_sync_lds
();
// read from LDS to register
auto
randval
=
load_tile
(
randval_lds_read_window
);
constexpr
auto
randval_spans
=
decltype
(
randval
)
::
get_distributed_spans
();
sweep_tile_span
(
randval_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
randval_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
p_idx0
=
tile_distributed_index
<
i_m0
>
{};
constexpr
auto
p_idx1
=
tile_distributed_index
<
i_n0
,
idx1
.
impl_
.
at
(
1
),
idx1
.
impl_
.
at
(
2
)
>
{};
constexpr
auto
p_idx
=
ck_tile
::
make_tuple
(
p_idx0
,
p_idx1
);
constexpr
auto
r_idx
=
ck_tile
::
make_tuple
(
idx0
,
idx1
);
p_compute
(
p_idx
)
=
randval
[
r_idx
]
<=
p_undrop_in_uint8_t
?
p_compute
[
p_idx
]
*
rp_undrop
:
PComputeDataType
(
0
);
});
});
// save to Global
if
constexpr
(
IsStoreRandval
)
{
const
auto
randval_store
=
cast_tile
<
RandValOutputDataType
>
(
randval
);
store_tile
(
randval_dram_window
,
randval_store
);
move_tile_window
(
randval_dram_window
,
{
0
,
kNPerStep
});
}
});
if
constexpr
(
IsStoreRandval
)
{
move_tile_window
(
randval_dram_window
,
{
kMPerStep
,
-
kNPerBlock
});
}
});
if
constexpr
(
IsStoreRandval
)
{
move_tile_window
(
randval_dram_window
,
{
-
kMPerBlock
,
kNPerBlock
});
}
}
template
<
typename
BlockGemm
,
typename
RandValOutputDataType
,
typename
PComputeWindow
,
typename
RandValDramWindow
>
CK_TILE_HOST_DEVICE
void
Run
(
const
index_t
start_m0_idx
,
const
index_t
start_n0_idx
,
PComputeWindow
&
p_compute
,
RandValDramWindow
&
randval_dram_window
)
const
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
using
BlockGemmShape
=
remove_cvref_t
<
typename
BlockGemm
::
BlockGemmShape
>
;
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
constexpr
bool
MBwdWG16MultiIterCheck
=
(
!
IsWG32
)
&&
(
kMPerBlock
>
16
);
constexpr
bool
MBwdWG16SingleIterCheck
=
(
!
IsWG32
)
&&
(
kMPerBlock
==
16
);
constexpr
index_t
kMPerStep
=
[
&
]()
{
if
constexpr
(
MBwdWG16MultiIterCheck
)
{
return
MWarp
*
WG
::
kM
*
2
;
}
else
{
return
MWarp
*
WG
::
kM
;
}
}();
constexpr
index_t
kNPerStep
=
NWarp
*
WG
::
kN
;
// register distribute
auto
randval
=
make_static_distributed_tensor
<
uint8_t
>
(
MakeRandValTileDistribution
<
BlockGemm
,
false
>
());
if
constexpr
(
IsWG32
)
static_assert
(
randval
.
kThreadElementSpaceSize
==
16
);
else
static_assert
(
randval
.
kThreadElementSpaceSize
==
4
||
randval
.
kThreadElementSpaceSize
==
8
);
const
int
start_n0_idx
=
randval_dram_window
.
get_window_origin
().
at
(
number
<
1
>
{});
static_for
<
0
,
kNPerBlock
/
kNPerStep
,
1
>
{}([
&
](
auto
i_n0
)
{
static_for
<
0
,
kMPerBlock
/
kMPerStep
,
1
>
{}([
&
](
auto
i_m0
)
{
int
block_row_start
=
(
start_m0_idx
/
WG
::
kM
)
+
i_m0
;
int
block_col_start
=
(
start_n0_idx
/
WG
::
kN
)
+
(
i_n0
*
NWarp
)
+
get_warp_id
();
int
block_row_start
,
block_col_start
;
if
constexpr
(
IsWG32
)
{
block_row_start
=
(
start_m0_idx
/
WG
::
kM
)
+
i_m0
;
block_col_start
=
(
start_n0_idx
/
WG
::
kN
)
+
(
i_n0
*
NWarp
)
+
get_warp_id
();
}
else
{
block_row_start
=
start_m0_idx
/
32
+
i_m0
;
block_col_start
=
(
start_n0_idx
/
32
)
+
get_warp_id
()
/
2
+
i_n0
*
2
;
}
uint2
rowcol
=
make_uint2
(
block_row_start
,
block_col_start
);
// generate random number
uint8_t
*
random_uint8_t_
;
if
constexpr
(
MBwdWG16SingleIterCheck
)
{
uint8_t
random_uint8_t
[
4
];
// m0t0 ~m0t15/m0t32~m0t47: 0
// m0t16~m0t31/m0t48~m0t63: 1
// m1t0 ~m1t15/m1t32~m1t47: 2
// m1t16~m1t31/m1t48~m1t63: 3
const
index_t
start_idx
=
((
get_lane_id
()
>>
4
)
&
1
)
+
(((
start_m0_idx
>>
4
)
&
1
)
<<
1
);
ph
.
get_random_4x8
(
random_uint8_t
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
),
start_idx
);
random_uint8_t_
=
random_uint8_t
;
}
else
if
constexpr
(
MBwdWG16MultiIterCheck
)
{
uint8_t
random_uint8_t
[
8
];
// t0 ~t15/t32~t47: 0
// t16~t31/t48~t63: 1
const
index_t
start_idx
=
(
get_lane_id
()
>>
4
)
&
1
;
ph
.
get_random_8x8
(
random_uint8_t
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
),
start_idx
);
random_uint8_t_
=
random_uint8_t
;
}
else
{
uint8_t
random_uint8_t
[
16
];
ph
.
get_random_16x8
(
random_uint8_t
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
));
ph
.
get_random_16x8
(
random_uint8_t
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
));
random_uint8_t_
=
random_uint8_t
;
}
constexpr
auto
randval_spans
=
decltype
(
randval
)
::
get_distributed_spans
();
int
i_random_idx
=
0
;
sweep_tile_span
(
randval_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
randval_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
r_idx
=
ck_tile
::
make_tuple
(
idx0
,
idx1
);
randval
(
r_idx
)
=
random_uint8_t
[
i_random_idx
++
];
constexpr
auto
p_idx0
=
tile_distributed_index
<
i_m0
,
idx0
.
impl_
.
at
(
1
),
idx0
.
impl_
.
at
(
2
)
>
{};
randval
(
r_idx
)
=
random_uint8_t_
[
i_random_idx
++
];
constexpr
auto
p_idx0
=
tile_distributed_index
<
i_m0
+
idx0
.
impl_
.
at
(
0
),
idx0
.
impl_
.
at
(
1
),
idx0
.
impl_
.
at
(
2
)
>
{};
constexpr
auto
p_idx1
=
tile_distributed_index
<
i_n0
>
{};
constexpr
auto
p_idx
=
ck_tile
::
make_tuple
(
p_idx0
,
p_idx1
);
p_compute
(
p_idx
)
=
randval
[
r_idx
]
<=
p_undrop_in_uint8_t
...
...
@@ -337,19 +699,19 @@ struct BlockDropout
});
});
// save to Global
if
(
is_s
tore
_r
andval
)
if
constexpr
(
IsS
tore
R
andval
)
{
const
auto
randval_store
=
cast_tile
<
RandValOutputDataType
>
(
randval
);
store_tile
(
randval_dram_window
,
randval_store
);
move_tile_window
(
randval_dram_window
,
{
kMPerStep
,
0
});
}
});
if
(
is_s
tore
_r
andval
)
if
constexpr
(
IsS
tore
R
andval
)
{
move_tile_window
(
randval_dram_window
,
{
-
kMPerBlock
,
kNPerStep
});
}
});
if
(
is_s
tore
_r
andval
)
if
constexpr
(
IsS
tore
R
andval
)
{
move_tile_window
(
randval_dram_window
,
{
kMPerBlock
,
-
kNPerBlock
});
}
...
...
@@ -358,7 +720,6 @@ struct BlockDropout
ck_tile
::
philox
ph
;
const
float
rp_undrop
;
const
uint8_t
p_undrop_in_uint8_t
;
const
bool
is_store_randval
;
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
View file @
ea5be216
...
...
@@ -23,13 +23,9 @@
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
KGradEpiloguePipeline_
,
typename
VGradEpiloguePipeline_
>
template
<
typename
FmhaPipeline_
,
typename
KGradEpiloguePipeline_
,
typename
VGradEpiloguePipeline_
>
struct
FmhaBwdDQDKDVKernel
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaPipeline
=
ck_tile
::
remove_cvref_t
<
FmhaPipeline_
>
;
using
KGradEpiloguePipeline
=
ck_tile
::
remove_cvref_t
<
KGradEpiloguePipeline_
>
;
using
VGradEpiloguePipeline
=
ck_tile
::
remove_cvref_t
<
VGradEpiloguePipeline_
>
;
...
...
@@ -59,9 +55,12 @@ struct FmhaBwdDQDKDVKernel
static
constexpr
bool
kPadHeadDimV
=
FmhaPipeline
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
FmhaPipeline
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
FmhaPipeline
::
kHasBiasGrad
;
static
constexpr
bool
kHasDropout
=
FmhaPipeline
::
kHasDropout
;
using
FmhaMask
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
FmhaMask
>
;
using
FmhaDropout
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
FmhaDropout
>
;
static
constexpr
bool
kHasMask
=
FmhaMask
::
IsMasking
;
static
constexpr
bool
kHasDropout
=
FmhaDropout
::
IsDropout
;
static
constexpr
bool
kIsStoreRandval
=
FmhaDropout
::
IsStoreRandval
;
static
constexpr
bool
kIsDeterministic
=
FmhaPipeline
::
kIsDeterministic
;
// clang-format off
template
<
typename
T
>
struct
t2s
;
...
...
@@ -74,8 +73,11 @@ struct FmhaBwdDQDKDVKernel
// sync with generate.py
// clang-format off
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
gbr
=
typename
bfs
::
Gemm0BlockWarps
;
using
gwt
=
typename
bfs
::
Gemm0WarpTile
;
using
gbr0
=
typename
bfs
::
Gemm0BlockWarps
;
using
gbr1
=
typename
bfs
::
Gemm1BlockWarps
;
using
gbr4
=
typename
bfs
::
Gemm4BlockWarps
;
using
gwt0
=
typename
bfs
::
Gemm0WarpTile
;
using
gwt1
=
typename
bfs
::
Gemm1WarpTile
;
#define _SS_ std::string
#define _TS_ std::to_string
auto
pn
=
[
&
]
()
{
...
...
@@ -88,13 +90,17 @@ struct FmhaBwdDQDKDVKernel
return
_SS_
(
"fmha_bwd_d"
)
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
+
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"x"
+
_TS_
(
bfs
::
kVHeaddim
)
+
"_"
+
"r"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kK2
)
+
"x"
+
_TS_
(
bfs
::
kK3
)
+
"x"
+
_TS_
(
bfs
::
kK4
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"x"
+
_TS_
(
bfs
::
kVHeaddim
)
+
"_"
+
"r"
+
_TS_
(
gbr0
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr0
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr0
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
gbr1
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr1
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr1
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
gbr4
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr4
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr4
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt0
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt0
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt0
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt1
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt1
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt1
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
)
+
_SS_
(
FmhaPipeline
::
name
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
kHasBiasGrad
?
"_dbias"
:
""
)
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
);
(
kHasBiasGrad
?
"_dbias"
:
""
)
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
)
+
(
kIsStoreRandval
?
"_storerandval"
:
""
)
+
(
kIsDeterministic
?
"_deterministic"
:
""
);
#undef _SS_
#undef _TS_
// clang-format on
...
...
@@ -117,7 +123,7 @@ struct FmhaBwdDQDKDVKernel
const
void
*
lse_ptr
;
const
void
*
do_ptr
;
const
void
*
d_ptr
;
void
*
dq_ptr
;
void
*
dq_
acc_
ptr
;
void
*
dk_ptr
;
void
*
dv_ptr
;
...
...
@@ -131,14 +137,13 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
num_head_q
;
ck_tile
::
index_t
nhead_ratio_qk
;
float
raw_scale
;
#if CK_TILE_FMHA_FWD_FAST_EXP2
float
scale
;
#endif
ck_tile
::
index_t
stride_q
;
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_v
;
ck_tile
::
index_t
stride_do
;
ck_tile
::
index_t
stride_dq_acc
;
ck_tile
::
index_t
stride_dk
;
ck_tile
::
index_t
stride_dv
;
...
...
@@ -147,8 +152,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_v
;
ck_tile
::
index_t
nhead_stride_do
;
ck_tile
::
index_t
nhead_stride_lsed
;
ck_tile
::
index_t
batch_stride_lsed
;
ck_tile
::
index_t
nhead_stride_dq_acc
;
ck_tile
::
index_t
nhead_stride_dk
;
ck_tile
::
index_t
nhead_stride_dv
;
};
struct
FmhaBwdCommonBiasKargs
...
...
@@ -206,7 +212,6 @@ struct FmhaBwdDQDKDVKernel
float
rp_undrop
=
1
;
float
scale_rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
bool
is_store_randval
=
false
;
uint64_t
drop_seed
=
1
;
uint64_t
drop_offset
=
0
;
void
*
rand_val_ptr
=
nullptr
;
...
...
@@ -218,6 +223,10 @@ struct FmhaBwdDQDKDVKernel
{
ck_tile
::
index_t
batch_stride_randval
=
0
;
};
struct
FmhaBwdDeterministicKargs
{
ck_tile
::
index_t
split_stride_dq_acc
=
0
;
};
struct
FmhaBwdBatchModeKargs
:
FmhaBwdCommonKargs
,
...
...
@@ -228,12 +237,15 @@ struct FmhaBwdDQDKDVKernel
FmhaBwdEmptyKargs
<
0
>>>
,
std
::
conditional_t
<
kHasBiasGrad
,
FmhaBwdBatchModeBiasGradKargs
,
FmhaBwdEmptyKargs
<
1
>>
,
std
::
conditional_t
<
kHasMask
,
FmhaBwdMaskKargs
,
FmhaBwdEmptyKargs
<
2
>>
,
std
::
conditional_t
<
kHasDropout
,
FmhaBwdBatchModeDropoutKargs
,
FmhaBwdEmptyKargs
<
3
>>
std
::
conditional_t
<
kHasDropout
,
FmhaBwdBatchModeDropoutKargs
,
FmhaBwdEmptyKargs
<
3
>>
,
std
::
conditional_t
<
kIsDeterministic
,
FmhaBwdDeterministicKargs
,
FmhaBwdEmptyKargs
<
4
>>
{
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_do
;
ck_tile
::
index_t
batch_stride_lsed
;
ck_tile
::
index_t
batch_stride_dq_acc
;
ck_tile
::
index_t
batch_stride_dk
;
ck_tile
::
index_t
batch_stride_dv
;
};
...
...
@@ -247,7 +259,8 @@ struct FmhaBwdDQDKDVKernel
FmhaBwdEmptyKargs
<
0
>>>
,
std
::
conditional_t
<
kHasBiasGrad
,
FmhaBwdCommonBiasGradKargs
,
FmhaBwdEmptyKargs
<
1
>>
,
std
::
conditional_t
<
kHasMask
,
FmhaBwdMaskKargs
,
FmhaBwdEmptyKargs
<
2
>>
,
std
::
conditional_t
<
kHasDropout
,
FmhaBwdCommonDropoutKargs
,
FmhaBwdEmptyKargs
<
3
>>
std
::
conditional_t
<
kHasDropout
,
FmhaBwdCommonDropoutKargs
,
FmhaBwdEmptyKargs
<
3
>>
,
std
::
conditional_t
<
kIsDeterministic
,
FmhaBwdDeterministicKargs
,
FmhaBwdEmptyKargs
<
4
>>
{
const
int32_t
*
seqstart_q_ptr
;
const
int32_t
*
seqstart_k_ptr
;
...
...
@@ -266,10 +279,10 @@ struct FmhaBwdDQDKDVKernel
const
void
*
do_ptr
,
const
void
*
d_ptr
,
void
*
rand_val_ptr
,
void
*
dq_ptr
,
void
*
dk_ptr
,
void
*
dv_ptr
,
void
*
dbias_ptr
,
void
*
dq_acc_ptr
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_k
,
ck_tile
::
index_t
hdim_q
,
...
...
@@ -283,6 +296,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_do
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
stride_dk
,
ck_tile
::
index_t
stride_dv
,
ck_tile
::
index_t
stride_dbias
,
...
...
@@ -293,6 +307,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dk
,
ck_tile
::
index_t
nhead_stride_dv
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
batch_stride_q
,
ck_tile
::
index_t
batch_stride_k
,
...
...
@@ -301,14 +318,15 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
batch_stride_randval
,
ck_tile
::
index_t
batch_stride_do
,
ck_tile
::
index_t
batch_stride_lsed
,
ck_tile
::
index_t
batch_stride_dq_acc
,
ck_tile
::
index_t
batch_stride_dk
,
ck_tile
::
index_t
batch_stride_dv
,
ck_tile
::
index_t
batch_stride_dbias
,
ck_tile
::
index_t
split_stride_dq_acc
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
Kargs
kargs
{{
q_ptr
,
...
...
@@ -317,7 +335,7 @@ struct FmhaBwdDQDKDVKernel
lse_ptr
,
do_ptr
,
d_ptr
,
dq_ptr
,
dq_
acc_
ptr
,
dk_ptr
,
dv_ptr
,
seqlen_q
,
...
...
@@ -327,13 +345,12 @@ struct FmhaBwdDQDKDVKernel
num_head_q
,
nhead_ratio_qk
,
scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast
<
float
>
(
scale
*
ck_tile
::
log2e_v
<>
),
#endif
stride_q
,
stride_k
,
stride_v
,
stride_do
,
stride_dq_acc
,
stride_dk
,
stride_dv
,
nhead_stride_q
,
...
...
@@ -341,15 +358,20 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_v
,
nhead_stride_do
,
nhead_stride_lsed
,
batch_stride_lsed
},
// args for common karg
nhead_stride_dq_acc
,
nhead_stride_dk
,
nhead_stride_dv
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for dbias
{},
// placeholder for mask
{},
// placeholder for dropout
{},
// placeholder for deterministic
batch_stride_q
,
batch_stride_k
,
batch_stride_v
,
batch_stride_do
,
batch_stride_lsed
,
batch_stride_dq_acc
,
batch_stride_dk
,
batch_stride_dv
};
...
...
@@ -384,11 +406,18 @@ struct FmhaBwdDQDKDVKernel
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
,
scale
);
if
constexpr
(
kIsStoreRandval
)
{
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
batch_stride_randval
=
batch_stride_randval
;
kargs
.
is_store_randval
=
s_randval
;
}
}
if
constexpr
(
kIsDeterministic
)
{
kargs
.
split_stride_dq_acc
=
split_stride_dq_acc
;
}
return
kargs
;
...
...
@@ -404,10 +433,10 @@ struct FmhaBwdDQDKDVKernel
const
void
*
do_ptr
,
const
void
*
d_ptr
,
void
*
rand_val_ptr
,
void
*
dq_ptr
,
void
*
dk_ptr
,
void
*
dv_ptr
,
void
*
dbias_ptr
,
void
*
dq_acc_ptr
,
const
void
*
seqstart_q_ptr
,
const
void
*
seqstart_k_ptr
,
const
void
*
seqlen_k_ptr
,
...
...
@@ -422,6 +451,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_do
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
stride_dk
,
ck_tile
::
index_t
stride_dv
,
ck_tile
::
index_t
stride_dbias
,
...
...
@@ -432,13 +462,15 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dk
,
ck_tile
::
index_t
nhead_stride_dv
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
batch
_stride_
lsed
,
ck_tile
::
index_t
split
_stride_
dq_acc
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
Kargs
kargs
{{
q_ptr
,
...
...
@@ -447,7 +479,7 @@ struct FmhaBwdDQDKDVKernel
lse_ptr
,
do_ptr
,
d_ptr
,
dq_ptr
,
dq_
acc_
ptr
,
dk_ptr
,
dv_ptr
,
-
1
,
// seqlen will be updated by another pointer
...
...
@@ -457,13 +489,12 @@ struct FmhaBwdDQDKDVKernel
num_head_q
,
nhead_ratio_qk
,
scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast
<
float
>
(
scale
*
ck_tile
::
log2e_v
<>
),
#endif
stride_q
,
stride_k
,
stride_v
,
stride_do
,
stride_dq_acc
,
stride_dk
,
stride_dv
,
nhead_stride_q
,
...
...
@@ -471,11 +502,14 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_v
,
nhead_stride_do
,
nhead_stride_lsed
,
batch_stride_lsed
},
// args for common karg
nhead_stride_dq_acc
,
nhead_stride_dk
,
nhead_stride_dv
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for dbias
{},
// placeholder for mask
{},
// placeholder for dropout
{},
// placeholder for deterministic
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqstart_k_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
)};
...
...
@@ -506,10 +540,16 @@ struct FmhaBwdDQDKDVKernel
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
,
scale
);
if
constexpr
(
kIsStoreRandval
)
{
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
is_store_randval
=
s_randval
;
}
}
if
constexpr
(
kIsDeterministic
)
{
kargs
.
split_stride_dq_acc
=
split_stride_dq_acc
;
}
return
kargs
;
...
...
@@ -518,7 +558,17 @@ struct FmhaBwdDQDKDVKernel
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_k_
)
{
return
TilePartitioner
::
GridSize
(
batch_size_
,
nhead_
,
seqlen_k_
);
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_k_
,
FmhaPipeline
::
kN0
),
nhead_
,
batch_size_
);
}
CK_TILE_DEVICE
static
constexpr
auto
GetTileIndex
()
{
const
index_t
i_block
=
blockIdx
.
x
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
return
ck_tile
::
make_tuple
(
i_block
,
i_nhead
,
i_batch
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
...
@@ -536,7 +586,7 @@ struct FmhaBwdDQDKDVKernel
__shared__
char
smem_ptr
[
GetSmemSize
()];
// divide problem
const
auto
[
i_tile_n
,
i_nhead
,
i_batch
]
=
Tile
Partitioner
{}(
kargs
.
seqlen_k
);
const
auto
[
i_tile_n
,
i_nhead
,
i_batch
]
=
Get
Tile
Index
(
);
const
index_t
i_n0
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN0
);
...
...
@@ -547,6 +597,7 @@ struct FmhaBwdDQDKDVKernel
long_index_t
batch_offset_randval
=
0
;
long_index_t
batch_offset_do
=
0
;
long_index_t
batch_offset_lsed
=
0
;
long_index_t
batch_offset_dq_acc
=
0
;
long_index_t
batch_offset_dk
=
0
;
long_index_t
batch_offset_dv
=
0
;
long_index_t
batch_offset_dbias
=
0
;
...
...
@@ -561,7 +612,8 @@ struct FmhaBwdDQDKDVKernel
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
batch_offset_do
=
query_start
*
kargs
.
stride_do
;
batch_offset_lsed
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lsed
;
batch_offset_lsed
=
query_start
;
batch_offset_dq_acc
=
query_start
*
kargs
.
stride_dq_acc
;
batch_offset_dk
=
key_start
*
kargs
.
stride_dk
;
batch_offset_dv
=
key_start
*
kargs
.
stride_dv
;
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
...
...
@@ -576,7 +628,7 @@ struct FmhaBwdDQDKDVKernel
{
batch_offset_dbias
=
key_start
;
}
if
constexpr
(
k
HasDropout
)
if
constexpr
(
k
IsStoreRandval
)
{
batch_offset_randval
=
query_start
*
kargs
.
stride_randval
;
}
...
...
@@ -608,6 +660,7 @@ struct FmhaBwdDQDKDVKernel
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_do
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_do
;
batch_offset_lsed
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lsed
;
batch_offset_dq_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dq_acc
;
batch_offset_dk
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dk
;
batch_offset_dv
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dv
;
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
...
...
@@ -618,7 +671,7 @@ struct FmhaBwdDQDKDVKernel
{
batch_offset_dbias
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dbias
;
}
if
constexpr
(
k
HasDropout
)
if
constexpr
(
k
IsStoreRandval
)
{
batch_offset_randval
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_randval
;
...
...
@@ -646,14 +699,11 @@ struct FmhaBwdDQDKDVKernel
const
OGradDataType
*
do_ptr
=
reinterpret_cast
<
const
OGradDataType
*>
(
kargs
.
do_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_do
+
batch_offset_do
;
QGradDataType
*
dq_ptr
=
reinterpret_cast
<
QGradDataType
*>
(
kargs
.
dq_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_q
+
batch_offset_q
;
KGradDataType
*
dk_ptr
=
reinterpret_cast
<
KGradDataType
*>
(
kargs
.
dk_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_k
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_
d
k
+
batch_offset_dk
;
VGradDataType
*
dv_ptr
=
reinterpret_cast
<
VGradDataType
*>
(
kargs
.
dv_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_v
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_
d
v
+
batch_offset_dv
;
// Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
...
...
@@ -663,45 +713,10 @@ struct FmhaBwdDQDKDVKernel
make_tuple
(
kargs
.
stride_q
,
1
),
number
<
FmhaPipeline
::
kAlignmentQ
>
{},
number
<
1
>
{});
const
auto
q_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQLoadOnce
)
{
return
pad_tensor_view
(
const
auto
q_dram
=
pad_tensor_view
(
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}
else
{
return
pad_tensor_view
(
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}
}();
const
auto
qt_dram_naive
=
transform_tensor_view
(
q_dram_naive
,
make_tuple
(
make_pass_through_transform
(
kargs
.
hdim_q
),
make_pass_through_transform
(
kargs
.
seqlen_q
)),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
const
auto
qt_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQTLoadOnce
)
{
return
pad_tensor_view
(
qt_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kM0
>
{}),
sequence
<
kPadHeadDimQ
,
kPadSeqLenQ
>
{});
}
else
{
return
pad_tensor_view
(
qt_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kK3
>
{}),
sequence
<
kPadHeadDimQ
,
kPadSeqLenQ
>
{});
}
}();
const
auto
k_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
k_ptr
,
...
...
@@ -709,45 +724,10 @@ struct FmhaBwdDQDKDVKernel
make_tuple
(
kargs
.
stride_k
,
1
),
number
<
FmhaPipeline
::
kAlignmentK
>
{},
number
<
1
>
{});
const
auto
k_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kKLoadOnce
)
{
return
pad_tensor_view
(
const
auto
k_dram
=
pad_tensor_view
(
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
}
else
{
return
pad_tensor_view
(
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
}
}();
const
auto
kt_dram_naive
=
transform_tensor_view
(
k_dram_naive
,
make_tuple
(
make_pass_through_transform
(
kargs
.
hdim_q
),
make_pass_through_transform
(
kargs
.
seqlen_k
)),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
const
auto
kt_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kKTLoadOnce
)
{
return
pad_tensor_view
(
kt_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kN0
>
{}),
sequence
<
kPadHeadDimQ
,
kPadSeqLenK
>
{});
}
else
{
return
pad_tensor_view
(
kt_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kK4
>
{}),
sequence
<
kPadHeadDimQ
,
kPadSeqLenK
>
{});
}
}();
const
auto
v_dram
=
[
&
]()
{
const
auto
v_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
...
...
@@ -756,20 +736,10 @@ struct FmhaBwdDQDKDVKernel
make_tuple
(
kargs
.
stride_v
,
1
),
number
<
FmhaPipeline
::
kAlignmentV
>
{},
number
<
1
>
{});
if
constexpr
(
FmhaPipeline
::
kVLoadOnce
)
{
return
pad_tensor_view
(
v_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimV
>
{});
}
else
{
return
pad_tensor_view
(
v_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK2
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimV
>
{});
}
}();
const
auto
lse_dram
=
[
&
]()
{
...
...
@@ -792,145 +762,89 @@ struct FmhaBwdDQDKDVKernel
make_tuple
(
kargs
.
stride_do
,
1
),
number
<
FmhaPipeline
::
kAlignmentOGrad
>
{},
number
<
1
>
{});
const
auto
do_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kOGradLoadOnce
)
{
return
pad_tensor_view
(
const
auto
do_dram
=
pad_tensor_view
(
do_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimV
>
{});
}
else
{
return
pad_tensor_view
(
do_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK2
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimV
>
{});
}
}();
const
auto
dot_dram_naive
=
transform_tensor_view
(
do_dram_naive
,
make_tuple
(
make_pass_through_transform
(
kargs
.
hdim_v
),
make_pass_through_transform
(
kargs
.
seqlen_q
)),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
const
auto
dot_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kOGradTLoadOnce
)
{
return
pad_tensor_view
(
dot_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kVHeaddim
>
{},
number
<
FmhaPipeline
::
kM0
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenQ
>
{});
}
else
{
return
pad_tensor_view
(
dot_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kVHeaddim
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenQ
>
{});
}
}();
auto
dq_dram
=
[
&
]()
{
const
auto
dq_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
,
memory_operation_enum
::
atomic_add
>
(
dq_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_q
,
1
),
number
<
FmhaPipeline
::
kAlignmentQGrad
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
dq_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}();
auto
q_dram_window
=
make_tile_window
(
q_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{});
}(),
{
0
,
0
});
auto
qt_dram_window
=
make_tile_window
(
qt_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQTLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kM0
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kK3
>
{});
}(),
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
{
0
,
0
});
auto
k_dram_window
=
make_tile_window
(
k_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kKLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{});
}(),
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
{
i_n0
,
0
});
auto
kt_dram_window
=
make_tile_window
(
kt_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kKTLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kN0
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kK4
>
{});
}(),
{
0
,
i_n0
});
auto
v_dram_window
=
make_tile_window
(
v_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kVLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK2
>
{});
}(),
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{}),
{
i_n0
,
0
});
auto
do_dram_window
=
make_tile_window
(
do_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kOGradLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK2
>
{});
}(),
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{}),
{
0
,
0
});
auto
dot_dram_window
=
make_tile_window
(
dot_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kOGradTLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kVHeaddim
>
{},
number
<
FmhaPipeline
::
kM0
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kVHeaddim
>
{},
number
<
FmhaPipeline
::
kK1
>
{});
}(),
auto
dq_dram_window
=
[
&
,
i_tile_n_
=
i_tile_n
,
i_nhead_
=
i_nhead
]()
{
if
constexpr
(
kIsDeterministic
)
{
AccDataType
*
dq_acc_ptr
=
reinterpret_cast
<
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead_
)
*
kargs
.
nhead_stride_dq_acc
+
static_cast
<
long_index_t
>
(
i_tile_n_
)
*
kargs
.
split_stride_dq_acc
+
batch_offset_dq_acc
;
auto
dq_acc_dram
=
[
&
]()
{
const
auto
dq_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
dq_acc_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_dq_acc
,
1
),
number
<
FmhaPipeline
::
kAlignmentQGrad
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
dq_acc_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}();
return
make_tile_window
(
dq_acc_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
{
0
,
0
});
}
else
{
AccDataType
*
dq_acc_ptr
=
reinterpret_cast
<
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead_
)
*
kargs
.
nhead_stride_dq_acc
+
batch_offset_dq_acc
;
auto
dq_acc_dram
=
[
&
]()
{
const
auto
dq_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
,
memory_operation_enum
::
atomic_add
>
(
dq_acc_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_dq_acc
,
1
),
number
<
FmhaPipeline
::
kAlignmentQGrad
>
{},
number
<
1
>
{});
auto
dq_dram_window
=
make_tile_window
(
dq_dram
,
return
pad_tensor_view
(
dq_acc_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}();
return
make_tile_window
(
dq_acc_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
{
0
,
0
});
}
}();
auto
lse_dram_window
=
make_tile_window
(
lse_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{}),
{
0
});
...
...
@@ -1008,9 +922,7 @@ struct FmhaBwdDQDKDVKernel
// TODO: how to use s_read?
AccDataType
slope
=
*
(
reinterpret_cast
<
const
AccDataType
*>
(
kargs
.
alibi_slope_ptr
)
+
i_batch_
*
kargs
.
alibi_slope_stride
+
i_nhead_
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope
*=
ck_tile
::
log2e_v
<>
;
#endif
if
constexpr
(
kHasMask
)
{
return
make_alibi_from_lr_mask
<
AccDataType
,
false
>
(
slope
,
...
...
@@ -1035,33 +947,32 @@ struct FmhaBwdDQDKDVKernel
// dropout
float
rp_undrop
=
1
;
float
scale_rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
uint64_t
drop_seed
=
0
;
uint64_t
drop_offset
=
0
;
bool
is_store_randval
=
false
;
if
constexpr
(
kHasDropout
)
{
rp_undrop
=
kargs
.
rp_undrop
;
scale_rp_undrop
=
kargs
.
scale_rp_undrop
;
p_undrop_in_uint8_t
=
kargs
.
p_undrop_in_uint8_t
;
drop_seed
=
kargs
.
drop_seed
;
drop_offset
=
kargs
.
drop_offset
;
is_store_randval
=
kargs
.
is_store_randval
;
}
BlockDropout
dropout
(
i_batch
,
i_nhead
,
auto
dropout
=
[
&
,
i_nhead_
=
i_nhead
,
i_batch_
=
i_batch
]()
{
if
constexpr
(
kHasDropout
)
{
return
FmhaDropout
{
i_batch_
,
i_nhead_
,
kargs
.
num_head_q
,
drop_seed
,
drop_offset
,
rp_undrop
,
p_undrop_in_uint8_t
,
is_store_randval
);
kargs
.
drop_seed
,
kargs
.
drop_offset
,
kargs
.
rp_undrop
,
kargs
.
p_undrop_in_uint8_t
};
}
else
{
return
FmhaDropout
{};
};
}();
auto
randval_dram_window
=
[
&
,
i_nhead_
=
i_nhead
]()
{
constexpr
auto
randval_dram_window_lengths
=
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN0
>
{});
if
constexpr
(
k
HasDropout
)
if
constexpr
(
k
IsStoreRandval
)
{
RandValOutputDataType
*
rand_val_ptr
=
reinterpret_cast
<
RandValOutputDataType
*>
(
kargs
.
rand_val_ptr
)
+
...
...
@@ -1103,14 +1014,11 @@ struct FmhaBwdDQDKDVKernel
}();
auto
[
dk_acc_tile
,
dv_acc_tile
]
=
FmhaPipeline
{}(
q_dram_window
,
qt_dram_window
,
k_dram_window
,
kt_dram_window
,
v_dram_window
,
bias_dram_window
,
randval_dram_window
,
do_dram_window
,
dot_dram_window
,
lse_dram_window
,
d_dram_window
,
dq_dram_window
,
...
...
@@ -1118,9 +1026,7 @@ struct FmhaBwdDQDKDVKernel
mask
,
position_encoding
,
kargs
.
raw_scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
kargs
.
scale
,
#endif
rp_undrop
,
scale_rp_undrop
,
smem_ptr
,
...
...
@@ -1169,10 +1075,9 @@ struct FmhaBwdDQDKDVKernel
}
};
template
<
typename
TilePartitioner_
,
typename
FmhaBwdOGradDotO_
>
template
<
typename
FmhaBwdOGradDotO_
>
struct
FmhaBwdOGradDotOKernel
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaBwdOGradDotO
=
ck_tile
::
remove_cvref_t
<
FmhaBwdOGradDotO_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaBwdOGradDotO
::
kBlockSize
;
static
constexpr
ck_tile
::
index_t
kBlockPerCu
=
FmhaBwdOGradDotO
::
kBlockPerCu
;
...
...
@@ -1234,13 +1139,13 @@ struct FmhaBwdOGradDotOKernel
ck_tile
::
index_t
nhead_stride_do
;
ck_tile
::
index_t
nhead_stride_o
;
ck_tile
::
index_t
nhead_stride_d
;
ck_tile
::
index_t
batch_stride_d
;
};
struct
FmhaBwdOGradDotOBatchModeKargs
:
FmhaBwdOGradDotOCommonKargs
{
ck_tile
::
index_t
batch_stride_do
;
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
batch_stride_d
;
};
struct
FmhaBwdOGradDotOGroupModeKargs
:
FmhaBwdOGradDotOCommonKargs
...
...
@@ -1278,10 +1183,10 @@ struct FmhaBwdOGradDotOKernel
stride_o
,
nhead_stride_do
,
nhead_stride_o
,
nhead_stride_d
,
batch_stride_d
},
nhead_stride_d
},
batch_stride_do
,
batch_stride_o
};
batch_stride_o
,
batch_stride_d
};
return
kargs
;
}
...
...
@@ -1298,8 +1203,7 @@ struct FmhaBwdOGradDotOKernel
ck_tile
::
index_t
stride_o
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
nhead_stride_d
,
ck_tile
::
index_t
batch_stride_d
)
ck_tile
::
index_t
nhead_stride_d
)
{
Kargs
kargs
{{
o_ptr
,
do_ptr
,
...
...
@@ -1311,8 +1215,7 @@ struct FmhaBwdOGradDotOKernel
stride_o
,
nhead_stride_do
,
nhead_stride_o
,
nhead_stride_d
,
batch_stride_d
},
nhead_stride_d
},
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
)};
return
kargs
;
...
...
@@ -1321,7 +1224,16 @@ struct FmhaBwdOGradDotOKernel
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
)
{
return
TilePartitioner
::
GridSize
(
batch_size_
,
nhead_
,
seqlen_q_
);
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
kM0
),
nhead_
,
batch_size_
);
}
CK_TILE_DEVICE
static
constexpr
auto
GetTileIndex
()
{
const
index_t
i_block
=
blockIdx
.
x
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
return
ck_tile
::
make_tuple
(
i_block
,
i_nhead
,
i_batch
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
...
@@ -1331,7 +1243,7 @@ struct FmhaBwdOGradDotOKernel
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
// divide problem
const
auto
[
i_tile_m
,
i_nhead
,
i_batch
]
=
Tile
Partitioner
{}(
kargs
.
seqlen_q
);
const
auto
[
i_tile_m
,
i_nhead
,
i_batch
]
=
Get
Tile
Index
(
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
kM0
);
...
...
@@ -1346,7 +1258,7 @@ struct FmhaBwdOGradDotOKernel
batch_offset_o
=
query_start
*
kargs
.
stride_o
;
batch_offset_do
=
query_start
*
kargs
.
stride_do
;
batch_offset_d
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_d
;
batch_offset_d
=
query_start
;
// get real # queries & # keys under group mode
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
...
...
@@ -1418,4 +1330,315 @@ struct FmhaBwdOGradDotOKernel
}
};
template
<
typename
FmhaBwdConvertQGrad_
>
struct
FmhaBwdConvertQGradKernel
{
using
FmhaBwdConvertQGrad
=
ck_tile
::
remove_cvref_t
<
FmhaBwdConvertQGrad_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaBwdConvertQGrad
::
kBlockSize
;
static
constexpr
ck_tile
::
index_t
kBlockPerCu
=
FmhaBwdConvertQGrad
::
kBlockPerCu
;
static
constexpr
ck_tile
::
index_t
kM0
=
FmhaBwdConvertQGrad
::
kM0
;
static
constexpr
ck_tile
::
index_t
kN0
=
FmhaBwdConvertQGrad
::
kN0
;
static
constexpr
ck_tile
::
index_t
kQKHeaddim
=
FmhaBwdConvertQGrad
::
kQKHeaddim
;
using
AccDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaBwdConvertQGrad
::
AccDataType
>
;
using
QGradDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaBwdConvertQGrad
::
QGradDataType
>
;
static
constexpr
bool
kIsGroupMode
=
FmhaBwdConvertQGrad
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
FmhaBwdConvertQGrad
::
kPadSeqLenQ
;
static
constexpr
bool
kPadHeadDimQ
=
FmhaBwdConvertQGrad
::
kPadHeadDimQ
;
static
constexpr
bool
kIsDeterministic
=
FmhaBwdConvertQGrad
::
kIsDeterministic
;
// clang-format off
template
<
typename
T
>
struct
t2s
;
template
<
>
struct
t2s
<
ck_tile
::
fp16_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
// clang-format on
CK_TILE_HOST
static
std
::
string
GetName
()
{
// sync with generate.py
// clang-format off
#define _SS_ std::string
#define _TS_ std::to_string
auto
pn
=
[
&
]
()
{
std
::
string
n
;
if
(
kPadSeqLenQ
)
n
+=
"s"
;
if
(
kPadHeadDimQ
)
n
+=
"d"
;
return
n
.
empty
()
?
n
:
std
::
string
(
"p"
)
+
n
;
}();
return
_SS_
(
"fmha_bwd_convert_dq_d"
)
+
_TS_
(
kQKHeaddim
)
+
"_"
+
_SS_
(
t2s
<
QGradDataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
(
kIsDeterministic
?
"_deterministic"
:
""
)
+
"_"
+
(
"o"
+
_TS_
(
kBlockPerCu
))
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
);
#undef _SS_
#undef _TS_
// clang-format on
}
// to avoid duplicated base class prblem, introduce an template arg
template
<
ck_tile
::
index_t
I
>
struct
FmhaBwdConvertQGradEmptyKargs
{
};
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
struct
FmhaBwdConvertQGradCommonKargs
{
const
void
*
dq_acc_ptr
;
void
*
dq_ptr
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_k
;
ck_tile
::
index_t
hdim_q
;
ck_tile
::
index_t
stride_dq
;
ck_tile
::
index_t
stride_dq_acc
;
ck_tile
::
index_t
nhead_stride_dq
;
ck_tile
::
index_t
nhead_stride_dq_acc
;
};
struct
FmhaBwdConvertQGradDeterministicKargs
{
ck_tile
::
index_t
split_stride_dq_acc
=
0
;
};
struct
FmhaBwdConvertQGradBatchModeKargs
:
FmhaBwdConvertQGradCommonKargs
,
std
::
conditional_t
<
kIsDeterministic
,
FmhaBwdConvertQGradDeterministicKargs
,
FmhaBwdConvertQGradEmptyKargs
<
0
>>
{
ck_tile
::
index_t
batch_stride_dq
;
ck_tile
::
index_t
batch_stride_dq_acc
;
};
struct
FmhaBwdConvertQGradGroupModeKargs
:
FmhaBwdConvertQGradCommonKargs
,
std
::
conditional_t
<
kIsDeterministic
,
FmhaBwdConvertQGradDeterministicKargs
,
FmhaBwdConvertQGradEmptyKargs
<
0
>>
{
const
int32_t
*
seqstart_q_ptr
;
const
int32_t
*
seqstart_k_ptr
;
};
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
FmhaBwdConvertQGradGroupModeKargs
,
FmhaBwdConvertQGradBatchModeKargs
>
;
template
<
bool
Cond
=
!
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
dq_acc_ptr
,
void
*
dq_ptr
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_k
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
stride_dq
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dq
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
batch_stride_dq
,
ck_tile
::
index_t
batch_stride_dq_acc
,
ck_tile
::
index_t
split_stride_dq_acc
)
{
Kargs
kargs
{{
dq_acc_ptr
,
dq_ptr
,
seqlen_q
,
seqlen_k
,
hdim_q
,
stride_dq
,
stride_dq_acc
,
nhead_stride_dq
,
nhead_stride_dq_acc
},
{},
batch_stride_dq
,
batch_stride_dq_acc
};
if
constexpr
(
kIsDeterministic
)
{
kargs
.
split_stride_dq_acc
=
split_stride_dq_acc
;
}
return
kargs
;
}
template
<
bool
Cond
=
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
dq_acc_ptr
,
void
*
dq_ptr
,
const
void
*
seqstart_q_ptr
,
const
void
*
seqstart_k_ptr
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
stride_dq
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dq
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
split_stride_dq_acc
)
{
Kargs
kargs
{{
dq_acc_ptr
,
dq_ptr
,
-
1
,
// seqlen will be updated by another pointer
-
1
,
//
hdim_q
,
stride_dq
,
stride_dq_acc
,
nhead_stride_dq
,
nhead_stride_dq_acc
},
{},
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqstart_k_ptr
)};
if
constexpr
(
kIsDeterministic
)
{
kargs
.
split_stride_dq_acc
=
split_stride_dq_acc
;
}
return
kargs
;
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
)
{
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
kM0
),
nhead_
,
batch_size_
);
}
CK_TILE_DEVICE
static
constexpr
auto
GetTileIndex
()
{
const
index_t
i_block
=
blockIdx
.
x
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
return
ck_tile
::
make_tuple
(
i_block
,
i_nhead
,
i_batch
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
0
;
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
// divide problem
const
auto
[
i_tile_m
,
i_nhead
,
i_batch
]
=
GetTileIndex
();
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
kM0
);
long_index_t
batch_offset_dq
=
0
;
long_index_t
batch_offset_dq_acc
=
0
;
if
constexpr
(
kIsGroupMode
)
{
// get starting offset for each batch
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
batch_offset_dq
=
query_start
*
kargs
.
stride_dq
;
batch_offset_dq_acc
=
query_start
*
kargs
.
stride_dq_acc
;
// get real # queries & # keys under group mode
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
kargs
.
seqlen_q
=
adjusted_seqstart_q_ptr
[
1
]
-
adjusted_seqstart_q_ptr
[
0
];
if
constexpr
(
kIsDeterministic
)
{
const
auto
adjusted_seqstart_k_ptr
=
kargs
.
seqstart_k_ptr
+
i_batch
;
kargs
.
seqlen_k
=
adjusted_seqstart_k_ptr
[
1
]
-
adjusted_seqstart_k_ptr
[
0
];
}
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
if
(
kargs
.
seqlen_q
<=
i_m0
)
{
return
;
}
}
else
{
batch_offset_dq
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dq
;
batch_offset_dq_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dq_acc
;
}
// for simplicity, batch stride we just modify the pointer
QGradDataType
*
dq_ptr
=
reinterpret_cast
<
QGradDataType
*>
(
kargs
.
dq_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_dq
+
batch_offset_dq
;
// dQAcc/dQ DRAM and DRAM window
const
auto
dq_acc_dram
=
[
&
,
i_nhead_
=
i_nhead
]()
{
if
constexpr
(
kIsDeterministic
)
{
const
AccDataType
*
dq_acc_ptr
=
reinterpret_cast
<
const
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead_
)
*
(
kargs
.
nhead_stride_dq_acc
)
+
batch_offset_dq_acc
;
const
index_t
nsplits
=
ck_tile
::
integer_divide_ceil
(
kargs
.
seqlen_k
,
kN0
);
auto
dq_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
dq_acc_ptr
,
make_tuple
(
nsplits
,
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
split_stride_dq_acc
,
kargs
.
stride_dq_acc
,
1
),
number
<
FmhaBwdConvertQGrad
::
kAlignmentQGradAcc
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
dq_acc_dram_naive
,
make_tuple
(
number
<
1
>
{},
number
<
kM0
>
{},
number
<
kQKHeaddim
>
{}),
sequence
<
false
,
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}
else
{
const
AccDataType
*
dq_acc_ptr
=
reinterpret_cast
<
const
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead_
)
*
(
kargs
.
nhead_stride_dq_acc
)
+
batch_offset_dq_acc
;
auto
dq_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
dq_acc_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_dq_acc
,
1
),
number
<
FmhaBwdConvertQGrad
::
kAlignmentQGradAcc
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
dq_acc_dram_naive
,
make_tuple
(
number
<
kM0
>
{},
number
<
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}
}();
auto
dq_dram
=
[
&
]()
{
auto
dq_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
dq_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_dq
,
1
),
number
<
FmhaBwdConvertQGrad
::
kAlignmentQGrad
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
dq_dram_naive
,
make_tuple
(
number
<
kM0
>
{},
number
<
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}();
auto
dq_acc_dram_window
=
[
&
]()
{
if
constexpr
(
kIsDeterministic
)
{
return
make_tile_window
(
dq_acc_dram
,
make_tuple
(
number
<
1
>
{},
number
<
kM0
>
{},
number
<
kQKHeaddim
>
{}),
{
0
,
i_m0
,
0
});
}
else
{
return
make_tile_window
(
dq_acc_dram
,
make_tuple
(
number
<
kM0
>
{},
number
<
kQKHeaddim
>
{}),
{
i_m0
,
0
});
}
}();
auto
dq_dram_window
=
make_tile_window
(
dq_dram
,
make_tuple
(
number
<
kM0
>
{},
number
<
kQKHeaddim
>
{}),
{
i_m0
,
0
});
if
constexpr
(
kIsDeterministic
)
{
const
index_t
nsplits
=
ck_tile
::
integer_divide_ceil
(
kargs
.
seqlen_k
,
kN0
);
FmhaBwdConvertQGrad
{}(
dq_acc_dram_window
,
dq_dram_window
,
nsplits
);
}
else
{
FmhaBwdConvertQGrad
{}(
dq_acc_dram_window
,
dq_dram_window
);
}
}
};
}
// namespace ck_tile
Prev
1
2
3
4
5
6
7
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment