gaoqiong / composable_kernel_ROCM / Commits

Commit fb26ec5d, authored Jul 16, 2024 by danyao12

hd256 bias support

parent 237c93c8

Showing 2 changed files with 122 additions and 166 deletions (+122 / -166)
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp  (+0 / -2)
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp  (+122 / -164)
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
@@ -824,7 +824,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
                    },
                    st_acc,
                    biast_tile);
                move_tile_window(bias_dram_window, {kM0, 0});
            }
            else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
@@ -963,7 +962,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
                    Policy::template MakeBiasTileDistribution<Problem>());
                shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
                store_tile(dbias_dram_window, dbiast_shuffle_tmp);
                move_tile_window(dbias_dram_window, {kM0, 0});
            }

            // STAGE 6, SGrad^T@Q^T Gemm3
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
@@ -190,11 +190,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
        constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);

-       constexpr index_t kTotalPixels = kMNPerBlock * kKPerBlock / kBlockSize;
+       constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;

-       constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad)
+       constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
                                         ? kMaxVecLoad
-                                        : (kTotalPixels / kMinVecLoad);
+                                        : (total_pixels / kMinVecLoad);

        return kVecLoad;
    }
@@ -209,11 +209,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType);
        constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);

-       constexpr index_t kTotalPixels = kMNPerBlock * kKPerBlock / kBlockSize;
+       constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;

-       constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad)
+       constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
                                         ? kMaxVecLoad
-                                        : (kTotalPixels / kMinVecLoad);
+                                        : (total_pixels / kMinVecLoad);

        return kVecLoad;
    }
@@ -226,9 +226,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
        constexpr index_t kKPerBlock  = Problem::BlockFmhaShape::kK2;
        constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType);

-       constexpr index_t kTotalPixels = kMNPerBlock * kKPerBlock / kBlockSize;
+       constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;

-       return kTotalPixels > kMaxVecLoad ? kMaxVecLoad : kTotalPixels;
+       return total_pixels > kMaxVecLoad ? kMaxVecLoad : total_pixels;
    }

    template <typename Problem>
@@ -248,11 +248,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType);
        constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);

-       constexpr index_t kTotalPixels = kMNPerBlock * kKPerBlock / kBlockSize;
+       constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;

-       constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad)
+       constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
                                         ? kMaxVecLoad
-                                        : (kTotalPixels / kMinVecLoad);
+                                        : (total_pixels / kMinVecLoad);

        return kVecLoad;
    }

+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
+   {
+       using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
+
+       constexpr index_t kBlockSize = Problem::kBlockSize;
+       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+
+       constexpr index_t kMaxVecLoad = 16 / sizeof(BiasDataType);
+       constexpr index_t kMinVecLoad = 4 / sizeof(BiasDataType);
+
+       constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
+
+       constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
+                                        ? kMaxVecLoad
+                                        : (total_pixels / kMinVecLoad);
+
+       return kVecLoad;
+   }
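The new GetAlignmentBias picks the widest vector load that still leaves at least kMinVecLoad elements per access. A standalone sketch of that selection, assuming fp16 bias and a hypothetical kM0 = 64, kN0 = 128, kBlockSize = 256 tile (these numbers are illustrative, not taken from this commit):

    // Sketch only: data type and tile sizes below are assumed for illustration.
    #include <cstdio>

    using index_t = int;

    constexpr index_t kMaxVecLoad  = 16 / 2;          // fp16: 8 elements per 16-byte load
    constexpr index_t kMinVecLoad  = 4 / 2;           // fp16: 2 elements per 4-byte load
    constexpr index_t total_pixels = 64 * 128 / 256;  // 32 bias elements per thread

    // (32 / 8) >= 2, so the widest (8-element) load is selected.
    constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
                                     ? kMaxVecLoad
                                     : (total_pixels / kMinVecLoad);
    static_assert(kVecLoad == 8, "widest load expected for this assumed tile");

    int main() { std::printf("kVecLoad = %d\n", kVecLoad); }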
@@ -335,25 +354,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;

-       constexpr index_t kTotalPixels = kMPerBlock * kNPerBlock / kBlockSize;
-       // TODO: not correct!
-       if constexpr(kTotalPixels > 32)
-           return 8;
-       else
-           return 4;
-   }
-
-   template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
-   {
-       constexpr index_t kBlockSize = Problem::kBlockSize;
-       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-       constexpr index_t kTotalPixels = kMPerBlock * kNPerBlock / kBlockSize;
+       constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;

-       return kTotalPixels / GetTransposedAlignmentBias<Problem>();
+       return total_pixels / GetAlignmentBias<Problem>();
    }

    template <typename Problem>
@@ -489,6 +492,29 @@ struct BlockFmhaBwdPipelineDefaultPolicy
                                       sequence<2>>{});
    }

+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTileDistribution()
+   {
+       constexpr index_t kBlockSize = Problem::kBlockSize;
+       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+
+       constexpr index_t N1 = GetAlignmentBias<Problem>();
+       constexpr index_t N0 = kNPerBlock / N1;
+       constexpr index_t M1 = get_warp_size() / N0;
+       constexpr index_t M0 = kBlockSize / get_warp_size();
+       constexpr index_t M2 = kMPerBlock / (M1 * M0);
+
+       return make_static_tile_distribution(
+           tile_distribution_encoding<sequence<>,
+                                      tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
+                                      tuple<sequence<1>, sequence<1, 2>>,
+                                      tuple<sequence<0>, sequence<1, 0>>,
+                                      sequence<1, 2>,
+                                      sequence<2, 1>>{});
+   }

    template <typename DataType, index_t MPerBlock, index_t KPerBlock>
    CK_TILE_HOST_DEVICE static constexpr auto MakePreXDramTileDistribution()
    {
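The M/N factoring in the new MakeBiasTileDistribution is plain integer arithmetic over the block shape. A standalone check, assuming a 256-thread block, 64-lane wavefronts, and the same hypothetical kM0 = 64, kN0 = 128, N1 = 8 alignment as in the sketch above (all values assumed):

    // Sketch only: block shape, wavefront size and N1 are assumed, not taken from this commit.
    using index_t = int;

    constexpr index_t kBlockSize = 256;
    constexpr index_t warp_size  = 64;   // stand-in for get_warp_size()
    constexpr index_t kMPerBlock = 64;   // kM0
    constexpr index_t kNPerBlock = 128;  // kN0
    constexpr index_t N1         = 8;    // stand-in for GetAlignmentBias<Problem>()

    constexpr index_t N0 = kNPerBlock / N1;        // 16 vectors across N
    constexpr index_t M1 = warp_size / N0;         // 4 rows covered by one wavefront
    constexpr index_t M0 = kBlockSize / warp_size; // 4 wavefronts per block
    constexpr index_t M2 = kMPerBlock / (M1 * M0); // 4 row iterations per lane

    static_assert(M0 * M1 * M2 == kMPerBlock && N0 * N1 == kNPerBlock,
                  "factors must tile the whole kM0 x kN0 bias block");

    int main() { return 0; }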
@@ -613,9 +639,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias()
    {
-       // TODO: this is for 3d layout
-       using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
-       return 16 / sizeof(BiasDataType);
+       return GetAlignmentBias<Problem>();
    }

+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBiasT()
+   {
+       return GetTransposedAlignmentBias<Problem>();
+   }

    template <typename Problem>
@@ -1520,42 +1550,46 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        return ds_block_dstr;
    }

+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution()
+   {
+       constexpr index_t kBlockSize = Problem::kBlockSize;
+       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+
+       constexpr index_t N1 = GetAlignmentBias<Problem>();
+       constexpr index_t N0 = kNPerBlock / N1;
+       constexpr index_t M2 = GetTransposedAlignmentBias<Problem>();
+       constexpr index_t M1 = get_warp_size() / N0;
+       constexpr index_t M0 = kBlockSize / get_warp_size();
+
+       return make_static_tile_distribution(
+           tile_distribution_encoding<sequence<>,
+                                      tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
+                                      tuple<sequence<1>, sequence<1, 2>>,
+                                      tuple<sequence<0>, sequence<1, 0>>,
+                                      sequence<2, 1>,
+                                      sequence<1, 2>>{});
+   }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTLdsBlockDescriptor()
    {
-       using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
-
-       constexpr index_t Banks        = 32; // TODO: need change based on arch
-       constexpr index_t PixelsPerRow = Banks * 4 / sizeof(BiasDataType);
-       constexpr index_t kKPack       = GetSmemKPackBias<Problem>();
-       constexpr index_t kMPerBlock   = Problem::BlockFmhaShape::kM0;
-       constexpr index_t kNPerBlock   = Problem::BlockFmhaShape::kN0;
-
-       static_assert(PixelsPerRow % kKPack == 0);
-       constexpr index_t NPerRow = PixelsPerRow / kKPack;
-       static_assert(kNPerBlock % NPerRow == 0);
-       static_assert(kMPerBlock % kKPack == 0);
-
-       constexpr auto biast_lds_block_desc_0 = make_naive_tensor_descriptor(
-           make_tuple(number<kMPerBlock / kKPack>{},
-                      number<kNPerBlock / NPerRow>{},
-                      number<NPerRow>{},
-                      number<kKPack>{}),
-           make_tuple(number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
-                      number<PixelsPerRow + kKPack>{},
-                      number<kKPack>{},
-                      number<1>{}),
-           number<kKPack>{},
-           number<1>{});
-
-       constexpr auto biast_lds_block_desc = transform_tensor_descriptor(
-           biast_lds_block_desc_0,
-           make_tuple(make_merge_transform(
-                          make_tuple(number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
-                      make_merge_transform(
-                          make_tuple(number<kMPerBlock / kKPack>{}, number<kKPack>{}))),
-           make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
-           make_tuple(sequence<1>{}, sequence<0>{}));
-
-       return biast_lds_block_desc;
+       // Hold full block data
+       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+       constexpr index_t kKPack     = GetSmemKPackBias<Problem>();
+       constexpr index_t kKPackT    = GetSmemKPackBiasT<Problem>();
+
+       return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kMPerBlock, kKPack, kKPackT>();
    }

+   template <typename BlockGemm>
+   CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTTileDistribution()
+   {
+       using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile());
+       return c_block_tensor_type::get_tile_distribution();
+   }

    template <typename Problem>
@@ -1681,20 +1715,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + smem_size_dot +
                                             smem_size_do + smem_size_lse + smem_size_d +
                                             max(smem_size_bias, smem_size_ds);
-       constexpr index_t smem_size_stage2 = smem_size_qt + smem_size_bias;
-       constexpr index_t smem_size_stage3 = smem_size_qt;
-       constexpr index_t smem_size_stage4 = smem_size_qt + smem_size_do + smem_size_d;
-       constexpr index_t smem_size_stage5 = smem_size_qt;
-       constexpr index_t smem_size_stage6 = smem_size_qt + smem_size_ds;

-       return max(smem_size_stage0_0,
-                  smem_size_stage0_1,
-                  smem_size_stage1,
-                  smem_size_stage2,
-                  smem_size_stage3,
-                  smem_size_stage4,
-                  smem_size_stage5,
-                  smem_size_stage6);
+       return max(smem_size_stage0_0, smem_size_stage0_1, smem_size_stage1);
    }

    template <typename Problem_>
@@ -1718,25 +1740,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t MFMA_INST = Gemm0MFMA;

        // Evenly distributed to relieve SQ->TA FIFO pressure
-       constexpr index_t VMEM_READ__MFMA_Rate = MFMA_INST / VMEM_READ_INST;
-       constexpr index_t MFMA_Remainder = MFMA_INST - VMEM_READ__MFMA_Rate * VMEM_READ__MFMA_Rate;
+       constexpr index_t MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST;
+       constexpr index_t MFMA_Remainder     = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST;

        // To hide instruction issue latency
-       constexpr index_t MFMA__LDS_READ_Rate = LDS_READ_INST / MFMA_INST;
+       constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;

        static_for<0, VMEM_READ_INST, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
-           static_for<0, VMEM_READ__MFMA_Rate, 1>{}([&](auto j) {
+           static_for<0, MFMA_PER_VMEM_READ, 1>{}([&](auto j) {
                ignore = j;
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-               __builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS read
+               __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
            });
        });
        static_for<0, MFMA_Remainder, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-           __builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS read
+           __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
        });
    }
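In the Gemm0 scheduler the renamed counters are simple integer ratios, and MFMA_Remainder is now computed against VMEM_READ_INST instead of the rate multiplied by itself, so the leftover MFMAs in the tail loop come out non-negative. A standalone check with assumed instruction counts (not taken from this commit):

    // Sketch only: the instruction counts below are assumed for illustration.
    using index_t = int;

    constexpr index_t MFMA_INST      = 16;
    constexpr index_t VMEM_READ_INST = 3;
    constexpr index_t LDS_READ_INST  = 32;

    constexpr index_t MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST; // 5 MFMAs per VMEM read
    constexpr index_t MFMA_Remainder =
        MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST;               // 1 MFMA left for the tail loop
    constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;   // 2 DS reads issued per MFMA

    // Every MFMA is scheduled exactly once: grouped with a VMEM read or in the tail loop.
    static_assert(MFMA_PER_VMEM_READ * VMEM_READ_INST + MFMA_Remainder == MFMA_INST, "");
    static_assert(MFMA_Remainder >= 0, "");

    int main() { return 0; }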
@@ -1749,12 +1770,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t MFMA_INST = Gemm1MFMA;

        // To hide instruction issue latency
-       constexpr index_t MFMA__LDS_READ_Rate = LDS_READ_INST / MFMA_INST;
+       constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;

        static_for<0, MFMA_INST, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-           __builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS read
+           __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
        });
    }
@@ -1768,12 +1789,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t MFMA_INST = Gemm2MFMA;

        // To hide instruction issue latency
-       constexpr index_t MFMA__LDS_WRITE_Rate = LDS_WRITE_INST / MFMA_INST;
+       constexpr index_t LDS_WRITE_PER_MFMA = LDS_WRITE_INST / MFMA_INST;

        static_for<0, MFMA_INST, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-           __builtin_amdgcn_sched_group_barrier(0x200, MFMA__LDS_WRITE_Rate, 0); // DS write
+           __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS write
        });
    }
@@ -1787,11 +1808,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t MFMA_INST = Gemm3MFMA;

        // To hide instruction issue latency
-       constexpr index_t MFMA__LDS_WRITE_Rate =
+       constexpr index_t LDS_WRITE_PER_MFMA =
            LDS_WRITE_INST / MFMA_INST >= 1 ? LDS_WRITE_INST / MFMA_INST : 1;
-       constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / MFMA__LDS_WRITE_Rate;
+       constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / LDS_WRITE_PER_MFMA;
-       constexpr index_t MFMA__LDS_READ_Rate =
+       constexpr index_t LDS_READ_PER_MFMA =
            (MFMA_INST - MFMA_INST_LDS_WRITE) > 0
                ? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE)
@@ -1801,13 +1822,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-           __builtin_amdgcn_sched_group_barrier(0x200, MFMA__LDS_WRITE_Rate, 0); // DS Write
+           __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write
        });
        static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-           __builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS Read
+           __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
        });
    }
@@ -1820,13 +1841,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t MFMA_INST = Gemm4MFMA;

        // To hide instruction issue latency
-       constexpr index_t MFMA__LDS_READ_Rate =
+       constexpr index_t LDS_READ_PER_MFMA =
            LDS_READ_INST / MFMA_INST > 0 ? LDS_READ_INST / MFMA_INST : 1;

        static_for<0, MFMA_INST, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-           __builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS Read
+           __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
        });
    }
@@ -1902,69 +1923,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        static constexpr index_t D_LDS_WRITE      = 1;
        static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
    };

-   template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTileDistribution()
-   {
-       constexpr index_t kBlockSize = Problem::kBlockSize;
-       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-
-       constexpr index_t N1 = GetTransposedAlignmentBias<Problem>();
-       constexpr index_t N0 = kNPerBlock / N1; // P
-
-       constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
-       static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
-       constexpr index_t M3     = total_pixels / N1;
-       constexpr index_t kKPack = GetSmemKPackBias<Problem>();
-       static_assert(kKPack % M3 == 0);
-       constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave
-       constexpr index_t M1 = get_warp_size() / (M2 * N0);
-       constexpr index_t M0 = kBlockSize / get_warp_size();
-       static_assert(kMPerBlock == M0 * M1 * M2 * M3);
-
-       return make_static_tile_distribution(
-           tile_distribution_encoding<sequence<>,
-                                      tuple<sequence<M0, M1, M2, M3>, sequence<N0, N1>>,
-                                      tuple<sequence<1>, sequence<1, 2, 1>>,
-                                      tuple<sequence<0>, sequence<1, 0, 2>>,
-                                      sequence<1, 2>,
-                                      sequence<3, 1>>{});
-   }
-
-   template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution()
-   {
-       constexpr index_t kBlockSize = Problem::kBlockSize;
-       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-
-       constexpr index_t N1 = GetTransposedAlignmentBias<Problem>();
-       constexpr index_t N0 = kNPerBlock / N1;
-
-       constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
-       static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
-       constexpr index_t M3     = total_pixels / N1;
-       constexpr index_t kKPack = GetSmemKPackBias<Problem>();
-       static_assert(kKPack % M3 == 0);
-       constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave
-       constexpr index_t M1 = get_warp_size() / (M2 * N0);
-       constexpr index_t M0 = kBlockSize / get_warp_size();
-
-       return make_static_tile_distribution(
-           tile_distribution_encoding<sequence<>,
-                                      tuple<sequence<M0, M1, M2, M3>, sequence<N0, N1>>,
-                                      tuple<sequence<1>, sequence<1, 2, 1>>,
-                                      tuple<sequence<0>, sequence<1, 0, 2>>,
-                                      sequence<2, 1>,
-                                      sequence<1, 3>>{});
-   }
-
-   template <typename BlockGemm>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTTileDistribution()
-   {
-       using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile());
-       return c_block_tensor_type::get_tile_distribution();
-   }
};
} // namespace ck_tile