Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d99d4d56
Commit
d99d4d56
authored
Feb 18, 2025
by
aska-0096
Browse files
remove xor usage in q, do and ds
parent
545eec16
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
13 deletions
+11
-13
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+11
-13
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
d99d4d56
...
...
@@ -1011,9 +1011,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
KPack
=
GetSmemKPackQ
<
Problem
>
();
constexpr
index_t
k
KPack
=
GetSmemKPackQ
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
KPack
,
false
>
();
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
k
KPack
,
false
>
();
}
template
<
typename
Problem
>
...
...
@@ -1077,7 +1077,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPack
=
Get
SmemKPack
Q
<
Problem
>
();
constexpr
index_t
kKPack
=
Get
Alignment
Q
<
Problem
>
();
constexpr
index_t
kKPackT
=
GetSmemKPackQT
<
Problem
>
();
return
MakeXTLdsBlockDescriptor
<
Problem
,
kNPerBlock
,
kKPerBlock
,
kKPack
,
kKPackT
>
();
...
...
@@ -1218,9 +1218,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
KPack
=
GetSmemKPackOGrad
<
Problem
>
();
constexpr
index_t
k
KPack
=
GetSmemKPackOGrad
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
KPack
,
false
>
();
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
k
KPack
,
false
>
();
}
template
<
typename
Problem
>
...
...
@@ -1284,7 +1284,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPack
=
Get
SmemKPack
OGrad
<
Problem
>
();
constexpr
index_t
kKPack
=
Get
Alignment
OGrad
<
Problem
>
();
constexpr
index_t
kKPackT
=
GetSmemKPackOGradT
<
Problem
>
();
return
MakeXTLdsBlockDescriptor
<
Problem
,
kNPerBlock
,
kKPerBlock
,
kKPack
,
kKPackT
>
();
...
...
@@ -1377,7 +1377,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPack
=
GetSmemKPackSGrad
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
>
();
return
MakeXLdsBlockDescriptor
<
kMPerBlock
,
kKPerBlock
,
kKPack
,
false
>
();
}
template
<
typename
Problem
>
...
...
@@ -1924,20 +1924,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy
kM0
*
kQKHeaddim
/
get_warp_size
()
/
GetTransposedAlignmentQ
<
Problem
>
();
// 16 * 32 / 64 / 8 = 1
static
constexpr
index_t
SGradT_LDS_READ_P1
=
// kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
kM0
*
kK4
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
2
;
kM0
*
kK4
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
GetSmemKPackSGrad
<
Problem
>
();
// 16 * 128 / 64 / 8 = 4
static
constexpr
index_t
Q_LDS_READ
=
kM0
*
kK0
/
(
get_warp_size
()
*
Gemm0MWarp
)
/
Get
Alignment
Q
<
Problem
>
();
kM0
*
kK0
/
(
get_warp_size
()
*
Gemm0MWarp
)
/
Get
SmemKPack
Q
<
Problem
>
();
// 1
static
constexpr
index_t
LSE_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
// 16 * 96 / 64 / 8 = 3
static
constexpr
index_t
SGradT_LDS_READ_P2
=
// kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
kM0
*
(
kN0
-
kK4
)
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
2
;
kM0
*
(
kN0
-
kK4
)
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
GetSmemKPackSGrad
<
Problem
>
();
// 16 * 128 / 64 / 8 = 4
static
constexpr
index_t
OGrad_LDS_READ
=
kM0
*
kK2
/
(
get_warp_size
()
*
Gemm2MWarp
)
/
Get
Alignment
OGrad
<
Problem
>
();
kM0
*
kK2
/
(
get_warp_size
()
*
Gemm2MWarp
)
/
Get
SmemKPack
OGrad
<
Problem
>
();
// 1
static
constexpr
index_t
D_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment