Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
70514fd8
Commit
70514fd8
authored
Aug 14, 2024
by
danyao12
Browse files
bwd rtn
parent
17c97f58
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
120 additions
and
22 deletions
+120
-22
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+3
-2
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
+1
-1
include/ck_tile/core/tensor/tile_elementwise.hpp
include/ck_tile/core/tensor/tile_elementwise.hpp
+26
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
...e/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
+6
-1
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
...a/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
+28
-6
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
...eline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
+56
-12
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
70514fd8
...
@@ -500,8 +500,9 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -500,8 +500,9 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if
not
cond
:
if
not
cond
:
continue
continue
if
receipt
==
3
:
if
receipt
==
3
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
=
dtype
in
[
'bf16'
]
cond
&=
bias
in
[
'no'
,
'alibi'
]
cond
&=
bias
in
[
'no'
]
cond
&=
dropout
in
[
'no'
]
cond
&=
dpad
==
dvpad
cond
&=
dpad
==
dvpad
cond
&=
deterministic
==
"f"
cond
&=
deterministic
==
"f"
if
not
cond
:
if
not
cond
:
...
...
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
View file @
70514fd8
...
@@ -9,7 +9,7 @@ export CK_REPEAT=1
...
@@ -9,7 +9,7 @@ export CK_REPEAT=1
COMMON_ARGS
=
'-v=1'
COMMON_ARGS
=
'-v=1'
set
-x
set
-x
for
prec
in
"fp16"
"bf16"
;
do
for
prec
in
"bf16"
;
do
for
perm
in
0 1
;
do
for
perm
in
0 1
;
do
for
hdim
in
32 64 128 256
;
do
for
hdim
in
32 64 128 256
;
do
for
mode
in
0 1
;
do
for
mode
in
0 1
;
do
...
...
include/ck_tile/core/tensor/tile_elementwise.hpp
View file @
70514fd8
...
@@ -227,6 +227,32 @@ CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
...
@@ -227,6 +227,32 @@ CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
#endif
#endif
}
}
// Cast a distributed tile to OutDataType, element by element.
//
// On gfx908 / gfx90a / gfx94 targets every element of the calling thread's
// buffer is converted with float_to_bf16_raw using the rounding mode whose
// enumerator value is 0 (presumably round-to-nearest, as the "rtn" in the
// function name suggests -- confirm against bf16_rounding_mode).
// On all other targets the generic type_convert path is used instead.
//
// NOTE(review): the arch-specific path always calls float_to_bf16_raw no
// matter what OutDataType is, so it presumably expects OutDataType to be
// bf16_t and the input elements to be fp32 -- confirm with the callers.
//
// in_dstr_tensors: source distributed tensor.
// returns: a new distributed tensor of OutDataType; in the arch-specific path
//          it is created from the input's tile distribution.
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_rtn_bf16_fp32(const InTensor& in_dstr_tensors)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
    // Rounding mode is selected by its raw enumerator value.
    constexpr auto rounding = static_cast<bf16_rounding_mode>(0);

    constexpr auto src_dstr     = InTensor::get_tile_distribution();
    constexpr index_t num_elems = InTensor::get_thread_buffer_size();

    auto converted = make_static_distributed_tensor<OutDataType>(src_dstr);

    // Convert the per-thread buffer one element at a time.
    for(index_t e = 0; e < num_elems; ++e)
    {
        converted.get_thread_buffer().at(e) =
            float_to_bf16_raw<rounding>(in_dstr_tensors.get_thread_buffer()[e]);
    }

    return converted;
#else
    // fallback: generic element-wise conversion
    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
                               in_dstr_tensors);
#endif
}
#if CK_TILE_USE_SUBDWORD_TILE_CAST
#if CK_TILE_USE_SUBDWORD_TILE_CAST
// this function assumes that either the src or dst (or both) data type is under 1 dword
// this function assumes that either the src or dst (or both) data type is under 1 dword
// we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
// we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
View file @
70514fd8
...
@@ -125,6 +125,11 @@ struct BlockFmhaBwdConvertQGrad
...
@@ -125,6 +125,11 @@ struct BlockFmhaBwdConvertQGrad
sweep_tile_span
(
dq_acc_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
sweep_tile_span
(
dq_acc_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
sweep_tile_span
(
dq_acc_spans
[
number
<
2
>
{}],
[
&
](
auto
idx2
)
{
sweep_tile_span
(
dq_acc_spans
[
number
<
2
>
{}],
[
&
](
auto
idx2
)
{
constexpr
auto
n_i_j_idx
=
make_tuple
(
idx0
,
idx1
,
idx2
);
constexpr
auto
n_i_j_idx
=
make_tuple
(
idx0
,
idx1
,
idx2
);
if
constexpr
(
std
::
is_same_v
<
QGradDataType
,
bf16_t
>
)
dq_converted
(
n_i_j_idx
)
=
float_to_bf16_raw
<
static_cast
<
bf16_rounding_mode
>
(
0
)
>
(
dq_acc
[
n_i_j_idx
]);
else
dq_converted
(
n_i_j_idx
)
=
type_convert
<
QGradDataType
>
(
dq_acc
[
n_i_j_idx
]);
dq_converted
(
n_i_j_idx
)
=
type_convert
<
QGradDataType
>
(
dq_acc
[
n_i_j_idx
]);
});
});
});
});
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
View file @
70514fd8
...
@@ -611,18 +611,34 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -611,18 +611,34 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
dropout
.
template
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>(
dropout
.
template
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>(
seqlen_q_step
,
k_origin
.
at
(
number
<
0
>
{}),
pt
,
randval_dram_window
);
seqlen_q_step
,
k_origin
.
at
(
number
<
0
>
{}),
pt
,
randval_dram_window
);
}
}
const
auto
pt_gemm
=
[
&
]()
{
// const auto pt_gemm = [&]() {
// if constexpr(FmhaDropout::IsDropout)
// {
// return tile_elementwise_in(
// [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f);
// }, pt);
// }
// else
// {
// return cast_tile<GemmDataType>(pt);
// }
// }();
const
auto
pt_dropped
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
IsDropout
)
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
{
return
tile_elementwise_in
(
return
tile_elementwise_in
([](
const
auto
&
x
)
{
return
x
>
0.
f
?
x
:
0.
f
;
},
pt
);
[](
const
auto
&
x
)
{
return
type_convert
<
GemmDataType
>
(
x
>
0.
f
?
x
:
0.
f
);
},
pt
);
}
}
else
else
{
{
return
cast_tile
<
GemmDataType
>
(
pt
)
;
return
pt
;
}
}
}();
}();
const
auto
pt_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
GemmDataType
,
bf16_t
>
)
return
impl
::
cast_tile_rtn_bf16_fp32
<
GemmDataType
>
(
pt_dropped
);
else
return
cast_tile
<
GemmDataType
>
(
pt_dropped
);
}();
// STAGE 3, P^T@OGrad^T Gemm1
// STAGE 3, P^T@OGrad^T Gemm1
auto
do_block_tile
=
load_tile
(
do_dram_window
);
auto
do_block_tile
=
load_tile
(
do_dram_window
);
...
@@ -702,7 +718,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -702,7 +718,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
block_sync_lds
();
block_sync_lds
();
const
auto
dst_gemm
=
cast_tile
<
GemmDataType
>
(
dst
);
// const auto dst_gemm = cast_tile<GemmDataType>(dst);
const
auto
dst_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
GemmDataType
,
bf16_t
>
)
return
impl
::
cast_tile_rtn_bf16_fp32
<
GemmDataType
>
(
dst
);
else
return
cast_tile
<
GemmDataType
>
(
dst
);
}();
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
decltype
(
dst_reg_tensor
),
decltype
(
dst_reg_tensor
),
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
View file @
70514fd8
...
@@ -647,18 +647,34 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -647,18 +647,34 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dropout
.
template
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>(
dropout
.
template
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>(
seqlen_q_step
,
k_origin
.
at
(
number
<
0
>
{}),
pt
,
randval_dram_window
);
seqlen_q_step
,
k_origin
.
at
(
number
<
0
>
{}),
pt
,
randval_dram_window
);
}
}
const
auto
pt_gemm
=
[
&
]()
{
// const auto pt_gemm = [&]() {
// if constexpr(FmhaDropout::IsDropout)
// {
// return tile_elementwise_in(
// [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f);
// }, pt);
// }
// else
// {
// return cast_tile<GemmDataType>(pt);
// }
// }();
const
auto
pt_dropped
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
IsDropout
)
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
{
return
tile_elementwise_in
(
return
tile_elementwise_in
([](
const
auto
&
x
)
{
return
x
>
0.
f
?
x
:
0.
f
;
},
pt
);
[](
const
auto
&
x
)
{
return
type_convert
<
GemmDataType
>
(
x
>
0.
f
?
x
:
0.
f
);
},
pt
);
}
}
else
else
{
{
return
cast_tile
<
GemmDataType
>
(
pt
)
;
return
pt
;
}
}
}();
}();
const
auto
pt_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
GemmDataType
,
bf16_t
>
)
return
impl
::
cast_tile_rtn_bf16_fp32
<
GemmDataType
>
(
pt_dropped
);
else
return
cast_tile
<
GemmDataType
>
(
pt_dropped
);
}();
// STAGE 3, P^T@OGrad^T Gemm1
// STAGE 3, P^T@OGrad^T Gemm1
Policy
::
template
PTFromGemm0CToGemm1A
<
Problem
,
Policy
::
template
PTFromGemm0CToGemm1A
<
Problem
,
...
@@ -733,7 +749,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -733,7 +749,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
}
}
// STAGE 6, SGrad^T@Q^T Gemm3
// STAGE 6, SGrad^T@Q^T Gemm3
const
auto
dst_gemm
=
cast_tile
<
GemmDataType
>
(
dst
);
// const auto dst_gemm = cast_tile<GemmDataType>(dst);
const
auto
dst_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
GemmDataType
,
bf16_t
>
)
return
impl
::
cast_tile_rtn_bf16_fp32
<
GemmDataType
>
(
dst
);
else
return
cast_tile
<
GemmDataType
>
(
dst
);
}();
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
decltype
(
dst_reg_tensor
),
decltype
(
dst_reg_tensor
),
...
@@ -900,18 +922,34 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -900,18 +922,34 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
}
}
// STAGE 3, P^T@OGrad^T Gemm1
// STAGE 3, P^T@OGrad^T Gemm1
const
auto
pt_gemm
=
[
&
]()
{
// const auto pt_gemm = [&]() {
// if constexpr(FmhaDropout::IsDropout)
// {
// return tile_elementwise_in(
// [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
// pt);
// }
// else
// {
// return cast_tile<GemmDataType>(pt);
// }
// }();
const
auto
pt_dropped
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
IsDropout
)
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
{
return
tile_elementwise_in
(
return
tile_elementwise_in
([](
const
auto
&
x
)
{
return
x
>
0.
f
?
x
:
0.
f
;
},
pt
);
[](
const
auto
&
x
)
{
return
type_convert
<
GemmDataType
>
(
x
>
0.
f
?
x
:
0.
f
);
},
pt
);
}
}
else
else
{
{
return
cast_tile
<
GemmDataType
>
(
pt
)
;
return
pt
;
}
}
}();
}();
const
auto
pt_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
GemmDataType
,
bf16_t
>
)
return
impl
::
cast_tile_rtn_bf16_fp32
<
GemmDataType
>
(
pt_dropped
);
else
return
cast_tile
<
GemmDataType
>
(
pt_dropped
);
}();
Policy
::
template
PTFromGemm0CToGemm1A
<
Problem
,
decltype
(
pt_reg_tensor
),
decltype
(
pt_gemm
)>(
Policy
::
template
PTFromGemm0CToGemm1A
<
Problem
,
decltype
(
pt_reg_tensor
),
decltype
(
pt_gemm
)>(
pt_reg_tensor
,
pt_gemm
);
pt_reg_tensor
,
pt_gemm
);
...
@@ -969,7 +1007,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -969,7 +1007,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
}
}
// STAGE 6, SGrad^T@Q^T Gemm3
// STAGE 6, SGrad^T@Q^T Gemm3
const
auto
dst_gemm
=
cast_tile
<
GemmDataType
>
(
dst
);
// const auto dst_gemm = cast_tile<GemmDataType>(dst);
const
auto
dst_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
GemmDataType
,
bf16_t
>
)
return
impl
::
cast_tile_rtn_bf16_fp32
<
GemmDataType
>
(
dst
);
else
return
cast_tile
<
GemmDataType
>
(
dst
);
}();
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
decltype
(
dst_reg_tensor
),
decltype
(
dst_reg_tensor
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment