Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d09572e8
Unverified
Commit
d09572e8
authored
Sep 11, 2024
by
Dan Yao
Committed by
GitHub
Sep 10, 2024
Browse files
[CK_TILE] FA bwd repair (#1502)
* fix fa bwd * revert kernelBlockSize in gemm_kernel.hpp
parent
cf08df6b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
28 additions
and
28 deletions
+28
-28
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+15
-15
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+1
-1
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp
...gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp
+1
-1
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...lock_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+6
-6
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp
...gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp
+1
-1
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
...ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
+4
-4
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
d09572e8
...
...
@@ -29,9 +29,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK0
>
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
...
...
@@ -62,9 +62,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK1
>
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
...
...
@@ -94,9 +94,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK2
>
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>>
;
...
...
@@ -127,9 +127,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK3
>
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>>
;
...
...
@@ -159,9 +159,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK4
>
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>>
;
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
d09572e8
...
...
@@ -25,7 +25,7 @@ struct GemmKernel
using
LayoutA
=
remove_cvref_t
<
LayoutA_
>
;
using
LayoutB
=
remove_cvref_t
<
LayoutB_
>
;
using
LayoutC
=
remove_cvref_t
<
LayoutC_
>
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
Kernel
BlockSize
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
k
BlockSize
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
d09572e8
...
...
@@ -19,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
Kernel
BlockSize
=
Problem
::
Kernel
BlockSize
;
static
constexpr
index_t
k
BlockSize
=
Problem
::
k
BlockSize
;
static
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
d09572e8
...
...
@@ -195,7 +195,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
Kernel
BlockSize
=
Problem
::
Kernel
BlockSize
;
constexpr
index_t
k
BlockSize
=
Problem
::
k
BlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
...
...
@@ -204,7 +204,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
#if 1 // coalesce reading for each blocks
constexpr
index_t
M1
=
Kernel
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
k
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
...
...
@@ -217,7 +217,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
#else // coalesce reading for each warps
constexpr
index_t
M0
=
Kernel
BlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
k
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
kMPerBlock
/
(
M2
*
M0
);
return
make_static_tile_distribution
(
...
...
@@ -235,7 +235,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
Kernel
BlockSize
=
Problem
::
Kernel
BlockSize
;
constexpr
index_t
k
BlockSize
=
Problem
::
k
BlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
...
...
@@ -244,7 +244,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
#if 1 // coalesce reading for each blocks
constexpr
index_t
N1
=
Kernel
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
k
BlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
...
...
@@ -257,7 +257,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
#else // coalesce reading for each warps
constexpr
index_t
N0
=
Kernel
BlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
k
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
kNPerBlock
/
(
N2
*
N0
);
return
make_static_tile_distribution
(
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp
View file @
d09572e8
...
...
@@ -19,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV2
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
Kernel
BlockSize
=
Problem
::
Kernel
BlockSize
;
static
constexpr
index_t
k
BlockSize
=
Problem
::
k
BlockSize
;
static
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
View file @
d09572e8
...
...
@@ -23,10 +23,10 @@ struct BlockGemmPipelineProblem
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
index_t
Kernel
BlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadC
=
kPadC_
;
static
constexpr
index_t
k
BlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadC
=
kPadC_
;
static
constexpr
index_t
AlignmentA
=
kPadA
?
VectorLoadSize
/
sizeof
(
ADataType
)
:
1
;
static
constexpr
index_t
AlignmentB
=
kPadB
?
VectorLoadSize
/
sizeof
(
BDataType
)
:
1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment