Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
74f1516c
Commit
74f1516c
authored
Jul 10, 2024
by
danyao12
Browse files
tmp save
parent
497ccb87
Changes
43
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2047 additions
and
2551 deletions
+2047
-2551
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp
...k_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp
+0
-20
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp
.../fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp
+0
-821
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp
...block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp
+0
-20
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp
...ipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp
+0
-692
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp
...mha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp
+0
-20
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+1674
-938
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
...k_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
+1
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
...ile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
+32
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
...mha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+12
-8
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
...peline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
+8
-7
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
...ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
+4
-1
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
+9
-7
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+5
-6
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp
...le/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp
+2
-1
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
+16
-0
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
+10
-4
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+3
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
+202
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp
...gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp
+36
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp
...emm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp
+33
-0
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp
deleted
100644 → 0
View file @
497ccb87
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace
ck_tile
{
// This pipeline is v located in regs, k & k^t located in lds.
using
BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy
=
BlockFmhaBwdPipelineDefaultPolicy
<
/* QLoadOnce_ = */
false
,
/* QTLoadOnce_ = */
false
,
/* KLoadOnce_ = */
true
,
/* KTLoadOnce_ = */
true
,
/* VLoadOnce_ = */
true
,
/* OGradLoadOnce_ = */
false
,
/* OGradTLoadOnce_ = */
false
>
;
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp
deleted
100644 → 0
View file @
497ccb87
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp
deleted
100644 → 0
View file @
497ccb87
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace
ck_tile
{
// This pipeline is v located in regs, k located in lds.
using
BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy
=
BlockFmhaBwdPipelineDefaultPolicy
<
/* QLoadOnce_ = */
false
,
/* QTLoadOnce_ = */
false
,
/* KLoadOnce_ = */
true
,
/* KTLoadOnce_ = */
false
,
/* VLoadOnce_ = */
true
,
/* OGradLoadOnce_ = */
false
,
/* OGradTLoadOnce_ = */
false
>
;
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp
deleted
100644 → 0
View file @
497ccb87
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp
deleted
100644 → 0
View file @
497ccb87
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace
ck_tile
{
// This pipeline is v located in regs, q & k & do located in lds.
using
BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy
=
BlockFmhaBwdPipelineDefaultPolicy
<
/* QLoadOnce_ = */
true
,
/* QTLoadOnce_ = */
false
,
/* KLoadOnce_ = */
true
,
/* KTLoadOnce_ = */
false
,
/* VLoadOnce_ = */
true
,
/* OGradLoadOnce_ = */
true
,
/* OGradTLoadOnce_ = */
false
>
;
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
74f1516c
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
View file @
74f1516c
...
@@ -8,9 +8,7 @@ namespace ck_tile {
...
@@ -8,9 +8,7 @@ namespace ck_tile {
// This class is used for codegen pattern matching
// This class is used for codegen pattern matching
enum
class
BlockFmhaBwdPipelineEnum
enum
class
BlockFmhaBwdPipelineEnum
{
{
KSKTSVR
=
0
,
KRKTRVR
=
0
,
QSKSVROGradS
,
KSVR
,
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
View file @
74f1516c
...
@@ -24,7 +24,9 @@ template <typename QDataType_,
...
@@ -24,7 +24,9 @@ template <typename QDataType_,
typename
BiasGradDataType_
,
typename
BiasGradDataType_
,
typename
BlockFmhaShape_
,
typename
BlockFmhaShape_
,
bool
kIsGroupMode_
,
bool
kIsGroupMode_
,
bool
kIsDeterministic_
,
typename
FmhaMask_
,
typename
FmhaMask_
,
typename
FmhaDropout_
,
typename
Traits_
>
typename
Traits_
>
struct
BlockFmhaBwdPipelineProblem
struct
BlockFmhaBwdPipelineProblem
{
{
...
@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem
...
@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem
using
BiasGradDataType
=
remove_cvref_t
<
BiasGradDataType_
>
;
using
BiasGradDataType
=
remove_cvref_t
<
BiasGradDataType_
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
BlockFmhaShape_
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
BlockFmhaShape_
>
;
using
FmhaMask
=
remove_cvref_t
<
FmhaMask_
>
;
using
FmhaMask
=
remove_cvref_t
<
FmhaMask_
>
;
using
FmhaDropout
=
remove_cvref_t
<
FmhaDropout_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kBlockSize
=
BlockFmhaShape
::
NumWarps
*
get_warp_size
();
static
constexpr
index_t
kBlockSize
=
BlockFmhaShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
bool
kIsDeterministic
=
kIsDeterministic_
;
// attributes from traits
// attributes from traits
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
...
@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem
...
@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem
static
constexpr
bool
kPadHeadDimV
=
Traits
::
kPadHeadDimV
;
static
constexpr
bool
kPadHeadDimV
=
Traits
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Traits
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
Traits
::
BiasEnum
;
static
constexpr
bool
kHasBiasGrad
=
Traits
::
kHasBiasGrad
;
static
constexpr
bool
kHasBiasGrad
=
Traits
::
kHasBiasGrad
;
static
constexpr
bool
kHasDropout
=
Traits
::
kHasDropout
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
};
...
@@ -88,4 +91,30 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
...
@@ -88,4 +91,30 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
};
template
<
typename
AccDataType_
,
typename
QGradDataType_
,
typename
Shape_
,
typename
Traits_
,
bool
kIsGroupMode_
,
bool
kIsDeterministic_
>
struct
BlockFmhaBwdConvertQGradPipelineProblem
{
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
QGradDataType
=
remove_cvref_t
<
QGradDataType_
>
;
using
Shape
=
remove_cvref_t
<
Shape_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kBlockSize
=
Shape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
bool
kIsDeterministic
=
kIsDeterministic_
;
static_assert
(
0
<
kBlockSize
&&
kBlockSize
%
get_warp_size
()
==
0
,
"kBlockSize should be divisible by get_warp_size()"
);
// attributes from traits
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
static
constexpr
bool
kPadHeadDimQ
=
Traits
::
kPadHeadDimQ
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
View file @
74f1516c
...
@@ -28,6 +28,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -28,6 +28,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaDropout
=
remove_cvref_t
<
typename
Problem
::
FmhaDropout
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
...
@@ -50,7 +51,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -50,7 +51,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
true
;
// always store LSE (acc)
static
constexpr
bool
kStoreLSE
=
true
;
// always store LSE (acc)
static
constexpr
bool
kHasDropout
=
false
;
// ignore this flag
static
constexpr
bool
kHasUnevenSplits
=
Problem
::
kHasUnevenSplits
;
static
constexpr
bool
kHasUnevenSplits
=
Problem
::
kHasUnevenSplits
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
...
@@ -141,7 +141,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -141,7 +141,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Block
Dropout
&
dropout
)
const
Fmha
Dropout
dropout
)
const
{
{
static_assert
(
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
@@ -249,7 +249,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -249,7 +249,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
MakeRandvalDramWindow
<
decltype
(
gemm_0
)
>
(
auto
randval_dram_window
=
dropout
.
template
MakeRandvalDramWindow
<
decltype
(
gemm_0
)>(
randval_dram_block_window_tmp
,
seqlen_k_start
);
randval_dram_block_window_tmp
,
seqlen_k_start
);
auto
v_dram_window
=
auto
v_dram_window
=
...
@@ -501,10 +501,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -501,10 +501,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
});
});
});
});
if
constexpr
(
kHa
sDropout
)
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
{
dropout
.
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>
(
dropout
.
template
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>(
smem_ptr
,
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
randval_dram_window
);
smem_ptr
,
q_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
randval_dram_window
);
}
}
block_sync_lds
();
block_sync_lds
();
...
@@ -637,7 +641,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -637,7 +641,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Block
Dropout
&
dropout
)
const
Fmha
Dropout
dropout
)
const
{
{
return
operator
()(
q_dram_block_window_tmp
,
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
identity
{},
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
View file @
74f1516c
...
@@ -29,6 +29,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -29,6 +29,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaDropout
=
remove_cvref_t
<
typename
Problem
::
FmhaDropout
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
...
@@ -55,7 +56,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -55,7 +56,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
static
constexpr
bool
kPadHeadDimV
=
true
;
// support multiple of vector(like 8x)
static
constexpr
bool
kPadHeadDimV
=
true
;
// support multiple of vector(like 8x)
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
true
;
// always store LSE (acc)
static
constexpr
bool
kStoreLSE
=
true
;
// always store LSE (acc)
static
constexpr
bool
kHasDropout
=
false
;
// ignore this flag
static
constexpr
bool
kHasUnevenSplits
=
Problem
::
kHasUnevenSplits
;
static
constexpr
bool
kHasUnevenSplits
=
Problem
::
kHasUnevenSplits
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
...
@@ -153,7 +153,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -153,7 +153,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Block
Dropout
&
dropout
)
const
Fmha
Dropout
dropout
)
const
{
{
static_assert
(
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
@@ -301,7 +301,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -301,7 +301,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
MakeRandvalDramWindow
<
decltype
(
gemm_0
)
>
(
auto
randval_dram_window
=
dropout
.
template
MakeRandvalDramWindow
<
decltype
(
gemm_0
)>(
randval_dram_block_window_tmp
,
seqlen_k_start
);
randval_dram_block_window_tmp
,
seqlen_k_start
);
auto
v_dram_window
=
auto
v_dram_window
=
...
@@ -584,12 +584,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -584,12 +584,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
});
});
});
});
if
constexpr
(
kHa
sDropout
)
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
{
auto
randval_ptr
=
auto
randval_ptr
=
reinterpret_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeKV
<
Problem
>();
reinterpret_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeKV
<
Problem
>();
dropout
.
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>
(
dropout
.
template
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>(
randval_ptr
,
randval_ptr
,
q_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
+
i_total_loops
*
kN0
,
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
p_compute
,
randval_dram_window
);
randval_dram_window
);
...
@@ -741,7 +742,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -741,7 +742,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Block
Dropout
&
dropout
)
const
Fmha
Dropout
dropout
)
const
{
{
return
operator
()(
q_dram_block_window_tmp
,
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
identity
{},
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
View file @
74f1516c
...
@@ -21,6 +21,7 @@ template <typename QDataType_,
...
@@ -21,6 +21,7 @@ template <typename QDataType_,
typename
BlockFmhaShape_
,
typename
BlockFmhaShape_
,
bool
kIsGroupMode_
,
bool
kIsGroupMode_
,
typename
FmhaMask_
,
typename
FmhaMask_
,
typename
FmhaDropout_
,
typename
Traits_
>
typename
Traits_
>
struct
BlockFmhaPipelineProblem
struct
BlockFmhaPipelineProblem
{
{
...
@@ -37,6 +38,7 @@ struct BlockFmhaPipelineProblem
...
@@ -37,6 +38,7 @@ struct BlockFmhaPipelineProblem
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
BlockFmhaShape_
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
BlockFmhaShape_
>
;
using
FmhaMask
=
remove_cvref_t
<
FmhaMask_
>
;
using
FmhaMask
=
remove_cvref_t
<
FmhaMask_
>
;
using
FmhaDropout
=
remove_cvref_t
<
FmhaDropout_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kBlockSize
=
BlockFmhaShape
::
NumWarps
*
get_warp_size
();
static
constexpr
index_t
kBlockSize
=
BlockFmhaShape
::
NumWarps
*
get_warp_size
();
...
@@ -49,7 +51,6 @@ struct BlockFmhaPipelineProblem
...
@@ -49,7 +51,6 @@ struct BlockFmhaPipelineProblem
static
constexpr
bool
kPadHeadDimV
=
Traits
::
kPadHeadDimV
;
static
constexpr
bool
kPadHeadDimV
=
Traits
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Traits
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
Traits
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Traits
::
kStoreLSE
;
static
constexpr
bool
kStoreLSE
=
Traits
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
Traits
::
kHasDropout
;
static
constexpr
bool
kDoFp8StaticQuant
=
Traits
::
kDoFp8StaticQuant
;
static
constexpr
bool
kDoFp8StaticQuant
=
Traits
::
kDoFp8StaticQuant
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
};
...
@@ -68,6 +69,7 @@ template <typename QDataType,
...
@@ -68,6 +69,7 @@ template <typename QDataType,
typename
BlockFmhaShape
,
typename
BlockFmhaShape
,
bool
kIsGroupMode
,
bool
kIsGroupMode
,
typename
FmhaMask
,
typename
FmhaMask
,
typename
FmhaDropout
,
typename
Traits
>
typename
Traits
>
struct
BlockFmhaFwdSplitKVPipelineProblem
:
BlockFmhaPipelineProblem
<
QDataType
,
struct
BlockFmhaFwdSplitKVPipelineProblem
:
BlockFmhaPipelineProblem
<
QDataType
,
KDataType
,
KDataType
,
...
@@ -83,6 +85,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
...
@@ -83,6 +85,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
BlockFmhaShape
,
BlockFmhaShape
,
kIsGroupMode
,
kIsGroupMode
,
FmhaMask
,
FmhaMask
,
FmhaDropout
,
Traits
>
Traits
>
{
{
static
constexpr
bool
kHasUnevenSplits
=
kIsGroupMode
||
Traits
::
kHasUnevenSplits
;
static
constexpr
bool
kHasUnevenSplits
=
kIsGroupMode
||
Traits
::
kHasUnevenSplits
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
View file @
74f1516c
...
@@ -29,6 +29,7 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -29,6 +29,7 @@ struct BlockFmhaPipelineQRKSVS
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaDropout
=
remove_cvref_t
<
typename
Problem
::
FmhaDropout
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
...
@@ -51,7 +52,6 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -51,7 +52,6 @@ struct BlockFmhaPipelineQRKSVS
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
// ... together with tensor distribution. tensor dist should able to overwrite this
...
@@ -100,8 +100,6 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -100,8 +100,6 @@ struct BlockFmhaPipelineQRKSVS
static
constexpr
const
char
*
name
=
"qr"
;
static
constexpr
const
char
*
name
=
"qr"
;
using
DropoutType
=
std
::
conditional_t
<
kHasDropout
,
BlockDropout
,
NullBlockDropout
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
...
@@ -141,7 +139,7 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -141,7 +139,7 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Dropout
Type
&
dropout
)
const
Fmha
Dropout
dropout
)
const
{
{
static_assert
(
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
@@ -486,10 +484,14 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -486,10 +484,14 @@ struct BlockFmhaPipelineQRKSVS
});
});
});
});
if
constexpr
(
kHa
sDropout
)
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
{
dropout
.
template
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>(
dropout
.
template
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>(
smem_ptr
,
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
randval_dram_window
);
smem_ptr
,
q_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
randval_dram_window
);
}
}
block_sync_lds
();
block_sync_lds
();
...
@@ -620,7 +622,7 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -620,7 +622,7 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Dropout
Type
&
dropout
)
const
Fmha
Dropout
dropout
)
const
{
{
return
operator
()(
q_dram_block_window_tmp
,
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
identity
{},
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
74f1516c
...
@@ -30,6 +30,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -30,6 +30,7 @@ struct BlockFmhaPipelineQRKSVSAsync
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaDropout
=
remove_cvref_t
<
typename
Problem
::
FmhaDropout
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
...
@@ -56,7 +57,6 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -56,7 +57,6 @@ struct BlockFmhaPipelineQRKSVSAsync
static
constexpr
bool
kPadHeadDimV
=
true
;
// support multiple of vector(like 8x)
static
constexpr
bool
kPadHeadDimV
=
true
;
// support multiple of vector(like 8x)
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
// ... together with tensor distribution. tensor dist should able to overwrite this
...
@@ -112,8 +112,6 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -112,8 +112,6 @@ struct BlockFmhaPipelineQRKSVSAsync
static
constexpr
const
char
*
name
=
"qr_async"
;
static
constexpr
const
char
*
name
=
"qr_async"
;
using
DropoutType
=
std
::
conditional_t
<
kHasDropout
,
BlockDropout
,
NullBlockDropout
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
...
@@ -153,7 +151,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -153,7 +151,7 @@ struct BlockFmhaPipelineQRKSVSAsync
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Dropout
Type
&
dropout
)
const
Fmha
Dropout
dropout
)
const
{
{
static_assert
(
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
@@ -569,12 +567,13 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -569,12 +567,13 @@ struct BlockFmhaPipelineQRKSVSAsync
});
});
});
});
if
constexpr
(
kHa
sDropout
)
if
constexpr
(
FmhaDropout
::
I
sDropout
)
{
{
auto
randval_ptr
=
auto
randval_ptr
=
reinterpret_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeKV
<
Problem
>();
reinterpret_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeKV
<
Problem
>();
dropout
.
template
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>(
dropout
.
template
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>(
randval_ptr
,
randval_ptr
,
q_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
+
i_total_loops
*
kN0
,
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
p_compute
,
randval_dram_window
);
randval_dram_window
);
...
@@ -730,7 +729,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -730,7 +729,7 @@ struct BlockFmhaPipelineQRKSVSAsync
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Dropout
Type
&
dropout
)
const
Fmha
Dropout
dropout
)
const
{
{
return
operator
()(
q_dram_block_window_tmp
,
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
identity
{},
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp
View file @
74f1516c
...
@@ -28,6 +28,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
...
@@ -28,6 +28,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaDropout
=
remove_cvref_t
<
typename
Problem
::
FmhaDropout
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
...
@@ -124,7 +125,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
...
@@ -124,7 +125,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
float
descale_qk
,
float
descale_qk
,
float
descale_sv
,
float
descale_sv
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Block
Dropout
&
/*dropout*/
)
const
// not supported
Fmha
Dropout
&
/*dropout*/
)
const
// not supported
{
{
static_assert
(
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
...
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
View file @
74f1516c
...
@@ -92,4 +92,20 @@ struct TileFmhaBwdShape
...
@@ -92,4 +92,20 @@ struct TileFmhaBwdShape
// that need load V at once
// that need load V at once
};
};
template
<
typename
BlockTile_
,
// sequence<...
typename
BlockWarps_
,
typename
WarpTile_
>
struct
TileFmhaBwdConvertQGradShape
{
using
BlockTile
=
remove_cvref_t
<
BlockTile_
>
;
using
BlockWarps
=
remove_cvref_t
<
BlockWarps_
>
;
using
WarpTile
=
remove_cvref_t
<
WarpTile_
>
;
static
constexpr
index_t
NumWarps
=
reduce_on_sequence
(
BlockWarps
{},
multiplies
{},
number
<
1
>
{});
static
constexpr
index_t
kM0
=
BlockTile
::
at
(
number
<
0
>
{});
// tile size along q seqlen
static
constexpr
index_t
kN0
=
BlockTile
::
at
(
number
<
1
>
{});
// tile size along k seqlen
static
constexpr
index_t
kQKHeaddim
=
BlockTile
::
at
(
number
<
2
>
{});
// Q & K headdim
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
View file @
74f1516c
...
@@ -15,7 +15,6 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
...
@@ -15,7 +15,6 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
BlockAttentionBiasEnum
BiasEnum_
,
BlockAttentionBiasEnum
BiasEnum_
,
bool
kHasBiasGrad_
,
bool
kHasBiasGrad_
,
bool
kStoreLSE_
,
bool
kStoreLSE_
,
bool
kHasDropout_
,
bool
kDoFp8StaticQuant_
,
bool
kDoFp8StaticQuant_
,
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
struct
TileFmhaTraits
struct
TileFmhaTraits
...
@@ -27,7 +26,6 @@ struct TileFmhaTraits
...
@@ -27,7 +26,6 @@ struct TileFmhaTraits
static
constexpr
auto
BiasEnum
=
BiasEnum_
;
static
constexpr
auto
BiasEnum
=
BiasEnum_
;
static
constexpr
bool
kHasBiasGrad
=
kHasBiasGrad_
;
static
constexpr
bool
kHasBiasGrad
=
kHasBiasGrad_
;
static
constexpr
bool
kStoreLSE
=
kStoreLSE_
;
static
constexpr
bool
kStoreLSE
=
kStoreLSE_
;
static
constexpr
bool
kHasDropout
=
kHasDropout_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
};
...
@@ -39,7 +37,6 @@ template <bool kPadSeqLenQ /* padding for seqlen_q */,
...
@@ -39,7 +37,6 @@ template <bool kPadSeqLenQ /* padding for seqlen_q */,
BlockAttentionBiasEnum
BiasEnum
,
BlockAttentionBiasEnum
BiasEnum
,
bool
kHasBiasGrad
,
bool
kHasBiasGrad
,
bool
kStoreLSE
,
bool
kStoreLSE
,
bool
kHasDropout
,
bool
kDoFp8StaticQuant
,
bool
kDoFp8StaticQuant
,
bool
kHasUnevenSplits_
=
true
,
bool
kHasUnevenSplits_
=
true
,
index_t
kBlockPerCu
=
-
1
/* overwrite occupancy if not -1 */
>
index_t
kBlockPerCu
=
-
1
/* overwrite occupancy if not -1 */
>
...
@@ -50,7 +47,6 @@ struct TileFmhaFwdSplitKVTraits : TileFmhaTraits<kPadSeqLenQ,
...
@@ -50,7 +47,6 @@ struct TileFmhaFwdSplitKVTraits : TileFmhaTraits<kPadSeqLenQ,
BiasEnum
,
BiasEnum
,
kHasBiasGrad
,
kHasBiasGrad
,
kStoreLSE
,
kStoreLSE
,
kHasDropout
,
kDoFp8StaticQuant
,
kDoFp8StaticQuant
,
kBlockPerCu
>
kBlockPerCu
>
{
{
...
@@ -86,4 +82,14 @@ struct TileFmhaBwdOGradDotOTraits
...
@@ -86,4 +82,14 @@ struct TileFmhaBwdOGradDotOTraits
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
};
template
<
bool
kPadSeqLenQ_
/* padding for seqlen_q */
,
bool
kPadHeadDimQ_
/* paddding for hdim_q */
,
index_t
kBlockPerCu_
=
2
/* hint to occupancy */
>
struct
TileFmhaBwdConvertQGradTraits
{
static
constexpr
bool
kPadSeqLenQ
=
kPadSeqLenQ_
;
static
constexpr
bool
kPadHeadDimQ
=
kPadHeadDimQ_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm.hpp
View file @
74f1516c
...
@@ -5,6 +5,9 @@
...
@@ -5,6 +5,9 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
0 → 100644
View file @
74f1516c
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp
0 → 100644
View file @
74f1516c
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp
0 → 100644
View file @
74f1516c
This diff is collapsed.
Click to expand it.
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment