Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
237c93c8
Commit
237c93c8
authored
Jul 15, 2024
by
danyao12
Browse files
bias support
parent
ca4a9f00
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
18 deletions
+18
-18
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
...a/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
+9
-7
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+9
-11
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
View file @
237c93c8
...
@@ -384,7 +384,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -384,7 +384,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
BiasDataType
*
biast_lds_ptr
=
static_cast
<
BiasDataType
*>
(
static_cast
<
void
*>
(
BiasDataType
*
biast_lds_ptr
=
static_cast
<
BiasDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()));
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQT
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGrad
<
Problem
>()
+
Policy
::
template
GetSmemSizeOGradT
<
Problem
>()
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()
+
Policy
::
template
GetSmemSizeLSE
<
Problem
>()
+
Policy
::
template
GetSmemSizeD
<
Problem
>()));
auto
biast_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
biast_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
biast_lds_ptr
,
Policy
::
template
MakeBiasTLdsBlockDescriptor
<
Problem
>());
biast_lds_ptr
,
Policy
::
template
MakeBiasTLdsBlockDescriptor
<
Problem
>());
...
@@ -555,9 +559,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -555,9 +559,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
{
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
block_sync_lds
();
auto
bias_shuffle_tmp
=
make_static_distributed_tensor
<
BiasDataType
>
(
auto
bias_shuffle_tmp
=
make_static_distributed_tensor
<
BiasDataType
>
(
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
shuffle_tile
(
bias_shuffle_tmp
,
bias_tile
);
shuffle_tile
(
bias_shuffle_tmp
,
bias_tile
);
...
@@ -571,6 +573,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -571,6 +573,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
st_acc
,
st_acc
,
biast_tile
);
biast_tile
);
move_tile_window
(
bias_dram_window
,
{
kM0
,
0
});
move_tile_window
(
bias_dram_window
,
{
kM0
,
0
});
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
{
...
@@ -725,6 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -725,6 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
shuffle_tile
(
dbiast_shuffle_tmp
,
dbiast_tile
);
shuffle_tile
(
dbiast_shuffle_tmp
,
dbiast_tile
);
store_tile
(
dbias_dram_window
,
dbiast_shuffle_tmp
);
store_tile
(
dbias_dram_window
,
dbiast_shuffle_tmp
);
move_tile_window
(
dbias_dram_window
,
{
kM0
,
0
});
move_tile_window
(
dbias_dram_window
,
{
kM0
,
0
});
__builtin_amdgcn_sched_barrier
(
0
);
}
}
// STAGE 6, SGrad^T@Q^T Gemm3
// STAGE 6, SGrad^T@Q^T Gemm3
...
@@ -807,9 +811,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -807,9 +811,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
{
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
block_sync_lds
();
auto
bias_shuffle_tmp
=
make_static_distributed_tensor
<
BiasDataType
>
(
auto
bias_shuffle_tmp
=
make_static_distributed_tensor
<
BiasDataType
>
(
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
shuffle_tile
(
bias_shuffle_tmp
,
bias_tile
);
shuffle_tile
(
bias_shuffle_tmp
,
bias_tile
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
237c93c8
...
@@ -331,21 +331,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -331,21 +331,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetTransposedAlignmentBias
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetTransposedAlignmentBias
()
{
{
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kTotalPixels
=
kMPerBlock
*
kNPerBlock
/
kBlockSize
;
constexpr
index_t
kTotalPixels
=
kMPerBlock
*
kNPerBlock
/
kBlockSize
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
BiasDataType
);
// TODO: not correct!
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
BiasDataType
);
if
constexpr
(
kTotalPixels
>
32
)
return
8
;
constexpr
index_t
kVecLoad
=
((
kTotalPixels
/
kMaxVecLoad
)
>=
kMinVecLoad
)
else
?
kMaxVecLoad
return
4
;
:
(
kTotalPixels
/
kMinVecLoad
);
return
kVecLoad
;
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -617,7 +613,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -617,7 +613,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackBias
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackBias
()
{
{
return
GetAlignmentBias
<
Problem
>
();
// TODO: this is for 3d layout
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
return
16
/
sizeof
(
BiasDataType
);
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -1682,7 +1680,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1682,7 +1680,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
smem_size_stage0_1
=
smem_size_v
;
constexpr
index_t
smem_size_stage0_1
=
smem_size_v
;
constexpr
index_t
smem_size_stage1
=
smem_size_qt
+
smem_size_q
+
+
smem_size_dot
+
constexpr
index_t
smem_size_stage1
=
smem_size_qt
+
smem_size_q
+
+
smem_size_dot
+
smem_size_do
+
smem_size_lse
+
smem_size_d
+
smem_size_do
+
smem_size_lse
+
smem_size_d
+
smem_size_ds
;
max
(
smem_size_bias
,
smem_size_ds
)
;
constexpr
index_t
smem_size_stage2
=
smem_size_qt
+
smem_size_bias
;
constexpr
index_t
smem_size_stage2
=
smem_size_qt
+
smem_size_bias
;
constexpr
index_t
smem_size_stage3
=
smem_size_qt
;
constexpr
index_t
smem_size_stage3
=
smem_size_qt
;
constexpr
index_t
smem_size_stage4
=
smem_size_qt
+
smem_size_do
+
smem_size_d
;
constexpr
index_t
smem_size_stage4
=
smem_size_qt
+
smem_size_do
+
smem_size_d
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment