Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
dcc3593f
Commit
dcc3593f
authored
Jul 25, 2024
by
danyao12
Browse files
fix hd32 error and boost performance
parent
b2510c05
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
123 additions
and
6 deletions
+123
-6
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+1
-1
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
...a/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
+12
-5
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+110
-0
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
dcc3593f
...
@@ -448,7 +448,7 @@ class FmhaBwdDQDKDVKernel:
...
@@ -448,7 +448,7 @@ class FmhaBwdDQDKDVKernel:
def
get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
def
get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
return
{
return
{
# '32' : [FmhaBwdDQDKDVTileSize(
64, 64
, 32,
64
, 32,
64
, 64, 32, 32, 1,
2
, 1,
2
, 1, 1, 2, 1, 1,
32
, 32, 16,
32, 32
, 16, 1),
# '32' : [FmhaBwdDQDKDVTileSize(
32, 128
, 32,
32
, 32,
32
, 64, 32, 32, 1,
4
, 1,
4
, 1, 1, 2,
2,
1, 1
6
,
16
, 32, 16,
16
, 16, 1),
# "kr_ktr_vr"],
# "kr_ktr_vr"],
'64'
:
[
FmhaBwdDQDKDVTileSize
(
64
,
128
,
64
,
64
,
64
,
64
,
64
,
64
,
64
,
1
,
4
,
1
,
4
,
1
,
1
,
2
,
2
,
1
,
32
,
32
,
16
,
32
,
32
,
16
,
1
),
'64'
:
[
FmhaBwdDQDKDVTileSize
(
64
,
128
,
64
,
64
,
64
,
64
,
64
,
64
,
64
,
1
,
4
,
1
,
4
,
1
,
1
,
2
,
2
,
1
,
32
,
32
,
16
,
32
,
32
,
16
,
1
),
"kr_ktr_vr"
],
"kr_ktr_vr"
],
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
View file @
dcc3593f
...
@@ -660,7 +660,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -660,7 +660,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
}();
}();
// STAGE 3, P^T@OGrad^T Gemm1
// STAGE 3, P^T@OGrad^T Gemm1
pt_reg_tensor
.
get_thread_buffer
()
=
pt_gemm
.
get_thread_buffer
();
Policy
::
template
PTFromGemm0CToGemm1A
<
Problem
,
decltype
(
pt_reg_tensor
),
decltype
(
pt_gemm
)>(
pt_reg_tensor
,
pt_gemm
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
...
@@ -732,7 +734,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -732,7 +734,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// STAGE 6, SGrad^T@Q^T Gemm3
// STAGE 6, SGrad^T@Q^T Gemm3
const
auto
dst_gemm
=
cast_tile
<
GemmDataType
>
(
dst
);
const
auto
dst_gemm
=
cast_tile
<
GemmDataType
>
(
dst
);
dst_reg_tensor
.
get_thread_buffer
()
=
dst_gemm
.
get_thread_buffer
();
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
decltype
(
dst_reg_tensor
),
decltype
(
dst_gemm
)>(
dst_reg_tensor
,
dst_gemm
);
gemm_3
(
dk_acc
,
dst_reg_tensor
,
qt_reg_tensor
);
gemm_3
(
dk_acc
,
dst_reg_tensor
,
qt_reg_tensor
);
...
@@ -908,7 +912,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -908,7 +912,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
}
}
}();
}();
pt_reg_tensor
.
get_thread_buffer
()
=
pt_gemm
.
get_thread_buffer
();
Policy
::
template
PTFromGemm0CToGemm1A
<
Problem
,
decltype
(
pt_reg_tensor
),
decltype
(
pt_gemm
)>(
pt_reg_tensor
,
pt_gemm
);
auto
dot_reg_tensor
=
load_tile
(
dot_lds_read_window
);
auto
dot_reg_tensor
=
load_tile
(
dot_lds_read_window
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
...
@@ -965,7 +970,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -965,7 +970,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// STAGE 6, SGrad^T@Q^T Gemm3
// STAGE 6, SGrad^T@Q^T Gemm3
const
auto
dst_gemm
=
cast_tile
<
GemmDataType
>
(
dst
);
const
auto
dst_gemm
=
cast_tile
<
GemmDataType
>
(
dst
);
dst_reg_tensor
.
get_thread_buffer
()
=
dst_gemm
.
get_thread_buffer
();
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
decltype
(
dst_reg_tensor
),
decltype
(
dst_gemm
)>(
dst_reg_tensor
,
dst_gemm
);
gemm_3
(
dk_acc
,
dst_reg_tensor
,
qt_reg_tensor
);
gemm_3
(
dk_acc
,
dst_reg_tensor
,
qt_reg_tensor
);
store_tile
(
ds_lds_window
,
dst_gemm
);
store_tile
(
ds_lds_window
,
dst_gemm
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
dcc3593f
...
@@ -1508,6 +1508,116 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1508,6 +1508,116 @@ struct BlockFmhaBwdPipelineDefaultPolicy
return
ds_block_dstr
;
return
ds_block_dstr
;
}
}
/// Copies the per-thread data of `pt_in` into `pt_out`, rearranging it warp
/// tile by warp tile when Gemm1 uses a warp tile whose first extent is 16.
///
/// @tparam Problem      Supplies BlockFmhaShape (Gemm1WarpTile, Gemm1BlockWarps,
///                      kN0, kK1) and the GemmDataType / OGradDataType /
///                      AccDataType used to select the warp gemm.
/// @param  pt_out       Destination tensor; written through
///                      set_y_sliced_thread_data() or get_thread_buffer().
/// @param  pt_in        Source tensor; only read.
///
/// Behaviour:
///  - first Gemm1WarpTile extent != 16: the whole thread buffer is assigned
///    from `pt_in` to `pt_out` unchanged.
///  - first Gemm1WarpTile extent == 16: for every (k, m) pair the slice found
///    at leading Y index (k, m) of `pt_in` is stored at leading Y index (m, k)
///    of `pt_out`, i.e. the two outermost iteration indices are swapped.
///
/// NOTE(review): judging by the name, `pt_in` is P^T in Gemm0's C layout and
/// `pt_out` is Gemm1's A operand - confirm against the pipeline call sites.
template <typename Problem, typename PTOutTensor, typename PTInTensor>
CK_TILE_DEVICE static constexpr void PTFromGemm0CToGemm1A(PTOutTensor& pt_out,
                                                          const PTInTensor& pt_in)
{
    using Shape    = typename Problem::BlockFmhaShape;
    using WarpTile = typename Shape::Gemm1WarpTile;

    if constexpr(WarpTile::at(number<0>{}) != 16)
    {
        // No rearrangement needed: the thread buffer is taken over as is.
        pt_out.get_thread_buffer() = pt_in.get_thread_buffer();
    }
    else
    {
        using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
                                                typename Problem::OGradDataType,
                                                typename Problem::AccDataType,
                                                WarpTile::at(number<0>{}),
                                                WarpTile::at(number<1>{}),
                                                WarpTile::at(number<2>{}),
                                                true>;

        using SrcWarpDstr = typename WarpGemm::CWarpDstr; // layout the slices are read in
        using DstWarpDstr = typename WarpGemm::AWarpDstr; // layout the slices are written in

        // Number of warp tiles each warp walks along M and along K.
        constexpr index_t warps_along_m = Shape::Gemm1BlockWarps::at(number<0>{});
        constexpr index_t m_per_block   = Shape::kN0;
        constexpr index_t k_per_block   = Shape::kK1;
        constexpr index_t m_iters       = m_per_block / (warps_along_m * WarpGemm::kM);
        constexpr index_t k_iters       = k_per_block / WarpGemm::kK;

        // Per-warp-tile Y extents and all-zero Y origins for both layouts.
        constexpr auto src_y_lengths =
            to_sequence(SrcWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto dst_y_lengths =
            to_sequence(DstWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto src_y_origin = uniform_sequence_gen_t<SrcWarpDstr::NDimY, 0>{};
        constexpr auto dst_y_origin = uniform_sequence_gen_t<DstWarpDstr::NDimY, 0>{};

        // Staging tensor holding exactly one warp tile between the read and the write.
        auto staged_tile =
            make_static_distributed_tensor<typename Problem::GemmDataType>(SrcWarpDstr{});

        static_for<0, k_iters, 1>{}([&](auto k_iter) {
            static_for<0, m_iters, 1>{}([&](auto m_iter) {
                // Read one warp tile at (k, m) of the source ...
                staged_tile.get_thread_buffer() = pt_in.get_y_sliced_thread_data(
                    merge_sequences(sequence<k_iter, m_iter>{}, src_y_origin),
                    merge_sequences(sequence<1, 1>{}, src_y_lengths));

                // ... and place it at (m, k) of the destination.
                pt_out.set_y_sliced_thread_data(
                    merge_sequences(sequence<m_iter, k_iter>{}, dst_y_origin),
                    merge_sequences(sequence<1, 1>{}, dst_y_lengths),
                    staged_tile.get_thread_buffer());
            });
        });
    }
}
/// Copies the per-thread data of `dst_in` into `dst_out`, rearranging it warp
/// tile by warp tile when Gemm3 uses a warp tile whose first extent is 16.
///
/// @tparam Problem      Supplies BlockFmhaShape (Gemm3WarpTile, Gemm3BlockWarps,
///                      kN0, kK3) and the GemmDataType / QDataType /
///                      AccDataType used to select the warp gemm.
/// @param  dst_out      Destination tensor; written through
///                      set_y_sliced_thread_data() or get_thread_buffer().
/// @param  dst_in       Source tensor; only read.
///
/// Behaviour:
///  - first Gemm3WarpTile extent != 16: the whole thread buffer is assigned
///    from `dst_in` to `dst_out` unchanged.
///  - first Gemm3WarpTile extent == 16: for every (k, m) pair the slice found
///    at leading Y index (k, m) of `dst_in` is stored at leading Y index (m, k)
///    of `dst_out`, i.e. the two outermost iteration indices are swapped.
///
/// NOTE(review): judging by the name, `dst_in` is SGrad^T in Gemm2's C layout
/// and `dst_out` is Gemm3's A operand - confirm against the pipeline call sites.
template <typename Problem, typename SGradTOutTensor, typename SGradTInTensor>
CK_TILE_DEVICE static constexpr void SGradTFromGemm2CToGemm3A(SGradTOutTensor& dst_out,
                                                              const SGradTInTensor& dst_in)
{
    using Shape    = typename Problem::BlockFmhaShape;
    using WarpTile = typename Shape::Gemm3WarpTile;

    if constexpr(WarpTile::at(number<0>{}) != 16)
    {
        // No rearrangement needed: the thread buffer is taken over as is.
        dst_out.get_thread_buffer() = dst_in.get_thread_buffer();
    }
    else
    {
        using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
                                                typename Problem::QDataType,
                                                typename Problem::AccDataType,
                                                WarpTile::at(number<0>{}),
                                                WarpTile::at(number<1>{}),
                                                WarpTile::at(number<2>{}),
                                                true>;

        using SrcWarpDstr = typename WarpGemm::CWarpDstr; // layout the slices are read in
        using DstWarpDstr = typename WarpGemm::AWarpDstr; // layout the slices are written in

        // Number of warp tiles each warp walks along M and along K.
        constexpr index_t warps_along_m = Shape::Gemm3BlockWarps::at(number<0>{});
        constexpr index_t m_per_block   = Shape::kN0;
        constexpr index_t k_per_block   = Shape::kK3;
        constexpr index_t m_iters       = m_per_block / (warps_along_m * WarpGemm::kM);
        constexpr index_t k_iters       = k_per_block / WarpGemm::kK;

        // Per-warp-tile Y extents and all-zero Y origins for both layouts.
        constexpr auto src_y_lengths =
            to_sequence(SrcWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto dst_y_lengths =
            to_sequence(DstWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto src_y_origin = uniform_sequence_gen_t<SrcWarpDstr::NDimY, 0>{};
        constexpr auto dst_y_origin = uniform_sequence_gen_t<DstWarpDstr::NDimY, 0>{};

        // Staging tensor holding exactly one warp tile between the read and the write.
        auto staged_tile =
            make_static_distributed_tensor<typename Problem::GemmDataType>(SrcWarpDstr{});

        static_for<0, k_iters, 1>{}([&](auto k_iter) {
            static_for<0, m_iters, 1>{}([&](auto m_iter) {
                // Read one warp tile at (k, m) of the source ...
                staged_tile.get_thread_buffer() = dst_in.get_y_sliced_thread_data(
                    merge_sequences(sequence<k_iter, m_iter>{}, src_y_origin),
                    merge_sequences(sequence<1, 1>{}, src_y_lengths));

                // ... and place it at (m, k) of the destination.
                dst_out.set_y_sliced_thread_data(
                    merge_sequences(sequence<m_iter, k_iter>{}, dst_y_origin),
                    merge_sequences(sequence<1, 1>{}, dst_y_lengths),
                    staged_tile.get_thread_buffer());
            });
        });
    }
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBiasTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBiasTileDistribution
()
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment