Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
54d2e0a7
Commit
54d2e0a7
authored
Oct 14, 2024
by
Po Yen, Chen
Browse files
Enlarge V tile size
parent
21d1fe01
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
9 deletions
+12
-9
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_2wave_default_policy.hpp
...ine/block_fmha_pipeline_qr_ks_vs_2wave_default_policy.hpp
+12
-9
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_2wave_default_policy.hpp
View file @
54d2e0a7
...
@@ -675,9 +675,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
...
@@ -675,9 +675,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
using
VLayout
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
::
VLayout
>
;
using
VLayout
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
::
VLayout
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
// constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // [POYENC] old tile size
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile size
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; // [POYENC] updated tile size
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
...
@@ -776,9 +777,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
...
@@ -776,9 +777,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
using
VLayout
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
::
VLayout
>
;
using
VLayout
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
::
VLayout
>
;
static_assert
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
static_assert
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
// constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // [POYENC] old tile size
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile size
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; // [POYENC] updated tile size
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
N1
=
GetAlignmentV
<
Problem
>
();
constexpr
index_t
N1
=
GetAlignmentV
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
...
@@ -901,9 +903,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
...
@@ -901,9 +903,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
VDataType
);
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
VDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackV
<
Problem
>
();
constexpr
index_t
kKPack
=
GetSmemKPackV
<
Problem
>
();
static_assert
(
PixelsPerRow
%
kKPack
==
0
);
static_assert
(
PixelsPerRow
%
kKPack
==
0
);
constexpr
index_t
NPerRow
=
PixelsPerRow
/
kKPack
;
constexpr
index_t
NPerRow
=
PixelsPerRow
/
kKPack
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; // [POYENC] updated tile size
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
static_assert
(
kNPerBlock
%
NPerRow
==
0
);
static_assert
(
kNPerBlock
%
NPerRow
==
0
);
static_assert
(
kKPerBlock
%
kKPack
==
0
);
static_assert
(
kKPerBlock
%
kKPack
==
0
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment