Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b3100b6f
"docs/source/en/optimization/open_vino.md" did not exist on "a5d2ee9d474e35c874fcc2a3b1085012202c6b47"
Commit
b3100b6f
authored
Jul 20, 2024
by
danyao12
Browse files
remove FmhaBwdTilePartitioner
parent
9d78a6c5
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
40 additions
and
72 deletions
+40
-72
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+3
-6
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+0
-1
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+37
-15
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp
...ude/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp
+0
-50
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
b3100b6f
...
...
@@ -104,8 +104,7 @@ using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
false>>;
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdKTilePartitioner<{F_bn0}>,
fmha_bwd_pipeline_{F_idx},
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
fmha_bwd_dk_epilogue_{F_idx},
fmha_bwd_dv_epilogue_{F_idx}>;
...
...
@@ -517,8 +516,7 @@ using fmha_bwd_dot_do_o_{F_idx} =
typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
using fmha_bwd_dot_do_o_kernel_{F_idx} =
ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdQTilePartitioner</* BlockSize = */ 64>,
fmha_bwd_dot_do_o_{F_idx}>;
ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_{F_idx}>;
using dot_do_o_trait_{F_idx} =
fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
...
...
@@ -641,8 +639,7 @@ using fmha_bwd_convert_dq_{F_idx} =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;
using fmha_bwd_convert_dq_kernel_{F_idx} =
ck_tile::FmhaBwdConvertQGradKernel<ck_tile::FmhaBwdQTilePartitioner<{F_bm0}>,
fmha_bwd_convert_dq_{F_idx}>;
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_{F_idx}>;
using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
{F_dtype},
...
...
include/ck_tile/ops/fmha.hpp
View file @
b3100b6f
...
...
@@ -8,7 +8,6 @@
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
...
...
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
View file @
b3100b6f
...
...
@@ -23,13 +23,9 @@
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
KGradEpiloguePipeline_
,
typename
VGradEpiloguePipeline_
>
template
<
typename
FmhaPipeline_
,
typename
KGradEpiloguePipeline_
,
typename
VGradEpiloguePipeline_
>
struct
FmhaBwdDQDKDVKernel
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaPipeline
=
ck_tile
::
remove_cvref_t
<
FmhaPipeline_
>
;
using
KGradEpiloguePipeline
=
ck_tile
::
remove_cvref_t
<
KGradEpiloguePipeline_
>
;
using
VGradEpiloguePipeline
=
ck_tile
::
remove_cvref_t
<
VGradEpiloguePipeline_
>
;
...
...
@@ -536,7 +532,17 @@ struct FmhaBwdDQDKDVKernel
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_k_
)
{
return
TilePartitioner
::
GridSize
(
batch_size_
,
nhead_
,
seqlen_k_
);
return
dim3
(
batch_size_
,
nhead_
,
ck_tile
::
integer_divide_ceil
(
seqlen_k_
,
FmhaPipeline
::
kN0
));
}
/// Reads this block's coordinates from blockIdx.
///
/// Mapping used by this kernel: blockIdx.z -> tile index, blockIdx.y -> head
/// index, blockIdx.x -> batch index.
///
/// @return ck_tile tuple of (tile index, head index, batch index).
CK_TILE_DEVICE static constexpr auto GetTileIndex()
{
    const index_t batch_idx = blockIdx.x;
    const index_t head_idx  = blockIdx.y;
    const index_t tile_idx  = blockIdx.z;

    return ck_tile::make_tuple(tile_idx, head_idx, batch_idx);
}
/// Launch block dimensions for this kernel.
///
/// @return dim3 constructed from kBlockSize.
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
...
...
@@ -554,7 +560,7 @@ struct FmhaBwdDQDKDVKernel
__shared__
char
smem_ptr
[
GetSmemSize
()];
// divide problem
const
auto
[
i_tile_n
,
i_nhead
,
i_batch
]
=
Tile
Partitioner
{}(
kargs
.
seqlen_k
);
const
auto
[
i_tile_n
,
i_nhead
,
i_batch
]
=
Get
Tile
Index
(
);
const
index_t
i_n0
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN0
);
...
...
@@ -1037,10 +1043,9 @@ struct FmhaBwdDQDKDVKernel
}
};
template
<
typename
TilePartitioner_
,
typename
FmhaBwdOGradDotO_
>
template
<
typename
FmhaBwdOGradDotO_
>
struct
FmhaBwdOGradDotOKernel
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaBwdOGradDotO
=
ck_tile
::
remove_cvref_t
<
FmhaBwdOGradDotO_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaBwdOGradDotO
::
kBlockSize
;
static
constexpr
ck_tile
::
index_t
kBlockPerCu
=
FmhaBwdOGradDotO
::
kBlockPerCu
;
...
...
@@ -1189,7 +1194,16 @@ struct FmhaBwdOGradDotOKernel
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
)
{
return
TilePartitioner
::
GridSize
(
batch_size_
,
nhead_
,
seqlen_q_
);
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
kM0
),
nhead_
,
batch_size_
);
}
/// Reads this block's coordinates from blockIdx.
///
/// Mapping used by this kernel: blockIdx.x -> tile index, blockIdx.y -> head
/// index, blockIdx.z -> batch index.
///
/// @return ck_tile tuple of (tile index, head index, batch index).
CK_TILE_DEVICE static constexpr auto GetTileIndex()
{
    const index_t tile_idx  = blockIdx.x;
    const index_t head_idx  = blockIdx.y;
    const index_t batch_idx = blockIdx.z;

    return ck_tile::make_tuple(tile_idx, head_idx, batch_idx);
}
/// Launch block dimensions for this kernel.
///
/// @return dim3 constructed from kBlockSize.
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
...
...
@@ -1199,7 +1213,7 @@ struct FmhaBwdOGradDotOKernel
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
// divide problem
const
auto
[
i_tile_m
,
i_nhead
,
i_batch
]
=
Tile
Partitioner
{}(
kargs
.
seqlen_q
);
const
auto
[
i_tile_m
,
i_nhead
,
i_batch
]
=
Get
Tile
Index
(
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
kM0
);
...
...
@@ -1286,10 +1300,9 @@ struct FmhaBwdOGradDotOKernel
}
};
template
<
typename
TilePartitioner_
,
typename
FmhaBwdConvertQGrad_
>
template
<
typename
FmhaBwdConvertQGrad_
>
struct
FmhaBwdConvertQGradKernel
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaBwdConvertQGrad
=
ck_tile
::
remove_cvref_t
<
FmhaBwdConvertQGrad_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaBwdConvertQGrad
::
kBlockSize
;
static
constexpr
ck_tile
::
index_t
kBlockPerCu
=
FmhaBwdConvertQGrad
::
kBlockPerCu
;
...
...
@@ -1439,7 +1452,16 @@ struct FmhaBwdConvertQGradKernel
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
)
{
return
TilePartitioner
::
GridSize
(
batch_size_
,
nhead_
,
seqlen_q_
);
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
kM0
),
nhead_
,
batch_size_
);
}
/// Reads this block's coordinates from blockIdx.
///
/// Mapping used by this kernel: blockIdx.x -> tile index, blockIdx.y -> head
/// index, blockIdx.z -> batch index.
///
/// @return ck_tile tuple of (tile index, head index, batch index).
CK_TILE_DEVICE static constexpr auto GetTileIndex()
{
    const index_t tile_idx  = blockIdx.x;
    const index_t head_idx  = blockIdx.y;
    const index_t batch_idx = blockIdx.z;

    return ck_tile::make_tuple(tile_idx, head_idx, batch_idx);
}
/// Launch block dimensions for this kernel.
///
/// @return dim3 constructed from kBlockSize.
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
...
...
@@ -1449,7 +1471,7 @@ struct FmhaBwdConvertQGradKernel
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
// divide problem
const
auto
[
i_tile_m
,
i_nhead
,
i_batch
]
=
Tile
Partitioner
{}(
kargs
.
seqlen_q
);
const
auto
[
i_tile_m
,
i_nhead
,
i_batch
]
=
Get
Tile
Index
(
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
kM0
);
...
...
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp
deleted
100644 → 0
View file @
9d78a6c5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
ck_tile
::
index_t
kN0
>
struct
FmhaBwdKTilePartitioner
{
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_k_
)
{
// TODO: this may need tuning
return
dim3
(
batch_size_
,
nhead_
,
ck_tile
::
integer_divide_ceil
(
seqlen_k_
,
kN0
));
}
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_k*/
)
{
const
index_t
i_block
=
blockIdx
.
z
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
x
;
return
ck_tile
::
make_tuple
(
i_block
,
i_nhead
,
i_batch
);
}
};
template
<
ck_tile
::
index_t
kM0
>
struct
FmhaBwdQTilePartitioner
{
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
)
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
kM0
),
nhead_
,
batch_size_
);
}
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_q*/
)
{
const
index_t
i_block
=
blockIdx
.
x
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
return
ck_tile
::
make_tuple
(
i_block
,
i_nhead
,
i_batch
);
}
};
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment