gaoqiong / composable_kernel_ROCM · Commits

Commit 3efb8621, authored Sep 18, 2024 by danyao12

    tmp save

parent d4139c8b
Showing 3 changed files with 327 additions and 189 deletions (+327, -189)
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py   +222  -7
example/ck_tile/01_fmha/fmha_bwd.cpp              +104  -181
example/ck_tile/01_fmha/fmha_bwd.hpp              +1    -1
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
...
...
@@ -163,6 +163,146 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
FMHA_BWD_API_FILENAME = "fmha_bwd_api.cpp"
FMHA_BWD_API = """
#include <iostream>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include "hsaco/fmha_hsaco.h"

#define HSA_KERNEL "kernel_func"

#define HIP_CALL(call)                                                                  \\
    do                                                                                  \\
    {{                                                                                  \\
        hipError_t err = call;                                                          \\
        if(err != hipSuccess)                                                           \\
        {{                                                                              \\
            printf("[hiperror](%d) fail to call %s", static_cast<int>(err), #call);     \\
            exit(0);                                                                    \\
        }}                                                                              \\
    }} while(0)
// extern declare the function since hip/hip_ext.h header is broken
extern hipError_t hipExtModuleLaunchKernel(hipFunction_t, // NOLINT
uint32_t,
uint32_t,
uint32_t,
uint32_t,
uint32_t,
uint32_t,
size_t,
hipStream_t,
void**,
void**,
hipEvent_t = nullptr,
hipEvent_t = nullptr,
uint32_t = 0);
struct p3
{{
unsigned int _p0;
unsigned int _p1;
unsigned int _p2;
}};
struct p2
{{
unsigned int _p0;
unsigned int _p1;
}};
struct __attribute__((packed)) fmha_bwd_asm_args
{{
void* ptr_dq;
p2 _p0;
void* ptr_dk;
p2 _p1;
void* ptr_dv;
p2 _p2;
const void* ptr_q;
p2 _p3;
const void* ptr_k;
p2 _p4;
const void* ptr_v;
p2 _p5;
const void* ptr_do;
p2 _p6;
const void* ptr_lse;
p2 _p7;
const void* ptr_d;
p2 _p8;
float scalar;
p3 _p9;
float log2e;
p3 _p10;
unsigned int seq_len;
p3 _p11;
unsigned int Ts;
p3 _p12;
unsigned int Hs;
p3 _p13;
unsigned int BAs;
p3 _p14;
}};
struct fmha_bwd_ext_traits
{{
int b;
int h;
int s;
int d;
int atm_f32;
int mask;
int ts_qo;
int ts_kv;
}};
std::string hip_error(int error) {{ return hipGetErrorString(static_cast<hipError_t>(error)); }}
class fmha_bwd_ext_kernel
{{
public:
fmha_bwd_ext_kernel(const std::string& name, unsigned char buffer[])
{{
// HIP_CALL(hipModuleLoadData(&module, buffer));
auto status = hipModuleLoadData(&module, buffer);
if(status != hipSuccess)
throw std::runtime_error("Failed to load module: " + hip_error(status));
HIP_CALL(hipModuleGetFunction(&kernel_func, module, name.c_str()));
}}
void
launch_kernel(fmha_bwd_ext_traits fmha_ext_traits, fmha_bwd_asm_args args, const ck_tile::stream_config& s) const
{{
size_t arg_size = sizeof(args);
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
&args,
HIP_LAUNCH_PARAM_BUFFER_SIZE,
&arg_size,
HIP_LAUNCH_PARAM_END}};
int bdx = 256;
int gdx = fmha_ext_traits.s / fmha_ext_traits.ts_kv;
int gdy = fmha_ext_traits.h;
int gdz = fmha_ext_traits.b;
if(fmha_ext_traits.mask > 0)
{{
int num_tg = fmha_ext_traits.s / fmha_ext_traits.ts_kv;
gdx = (num_tg % 2) ? (num_tg / 2 + 1) : (num_tg / 2);
}}
HIP_CALL(hipModuleLaunchKernel(kernel_func,
gdx,
gdy,
gdz,
bdx,
1,
1,
0,
s.stream_id_,
NULL,
reinterpret_cast<void**>(&config)));
}}
private:
hipModule_t module;
hipFunction_t kernel_func;
}};
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_>
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
...
...
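For reference, launch_kernel above maps one 256-thread workgroup to each ts_kv-sized tile of the key/value sequence; with a causal mask only about half the tiles carry work, so gdx is halved (rounded up for odd tile counts). A minimal sketch of that arithmetic, using illustrative values (s = 4096, ts_kv = 128 are assumptions, not taken from the commit):

// Illustrative only: grid-size arithmetic used by fmha_bwd_ext_kernel::launch_kernel.
#include <cstdio>

int main()
{
    int s      = 4096; // assumed sequence length
    int ts_kv  = 128;  // k/v tile size handled per workgroup
    int num_tg = s / ts_kv;                           // 32 tiles without a mask
    int gdx_causal = (num_tg % 2) ? num_tg / 2 + 1    // round up for odd tile counts
                                  : num_tg / 2;       // 16 tiles under a causal mask
    std::printf("no mask: %d, causal: %d\n", num_tg, gdx_causal);
    return 0;
}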
@@ -176,8 +316,83 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
);
}}
template <typename dot_do_o_trait_, typename convert_dq_trait_>
float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_ext_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
fmha_bwd_asm_args args;
args.ptr_dq = a.dq_acc_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
args.ptr_q = a.q_ptr;
args.ptr_k = a.k_ptr;
args.ptr_v = a.v_ptr;
args.ptr_do = a.do_ptr;
args.ptr_lse = a.lse_ptr;
args.ptr_d = a.d_ptr;
args.scalar = a.scale;
args.log2e = ck_tile::log2e_v<float>;
args.seq_len = a.seqlen_q;
args.Ts = 128 * a.hdim_q * 2;
args.Hs = a.seqlen_q * a.hdim_q * 2;
args.BAs = a.nhead_q * a.seqlen_q * a.hdim_q * 2;
auto traits = fmha_bwd_ext_traits{{a.batch,
a.nhead_q,
a.seqlen_q,
a.hdim_q,
1,
a.mask_type,
32,
128}};
HIP_CALL(hipSetDevice(0));
fmha_bwd_ext_kernel impl(HSA_KERNEL, bwd_ext_asm);
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
);
}}
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
float r = -1;
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
(a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false)) {{
if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_a32";
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_ext_name);
return r;
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_causal_a32";
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_ext_name);
return r;
}}
}}
else if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_a32";
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32, bwd_ext_name);
return r;
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a32";
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32, bwd_ext_name);
return r;
}}
}}
}}
{F_dispatch}
return r;
}}
...
...
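As a reading aid: the fmha_bwd dispatch above only takes the hand-written assembly path when the problem matches the kernel's fixed shape. A hedged restatement of that gate as a standalone predicate (the helper name is illustrative and not part of the commit; the conditions mirror the if-statement above):

// Sketch of the gate in fmha_bwd above; not part of the commit.
bool can_use_ext_asm_path(bool group_mode, bool has_bias, bool has_dbias, bool has_dropout,
                          bool deterministic, int seqlen_q, int seqlen_k, int hdim_q, int hdim_v)
{
    return !group_mode && !has_bias && !has_dbias && !has_dropout && !deterministic &&
           seqlen_q == seqlen_k &&          // square problems only
           seqlen_k % 128 == 0 &&           // full 128-wide k/v tiles
           hdim_q == 128 && hdim_v == 128;  // head dim fixed at 128
}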
@@ -451,14 +666,14 @@ class FmhaBwdDQDKDVKernel:
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
    if dtype == 'fp16' or dtype == 'bf16':
        return {
            '32'  : [FmhaBwdDQDKDVTileSize( 32, 128,  32, 32,  32, 32, 64,  32,  32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
                     "kr_ktr_vr_iglp", "kr_ktr_vr"],
            '64'  : [FmhaBwdDQDKDVTileSize( 32, 128,  64, 32,  64, 32, 32,  64,  64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
                     "kr_ktr_vr_iglp", "kr_ktr_vr"],
            # '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
            #         "kr_ktr_vr_iglp", "kr_ktr_vr"],
            # '64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
            #         "kr_ktr_vr_iglp", "kr_ktr_vr"],
            '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
                     "kr_ktr_vr_iglp", "kr_ktr_vr"],
            '256' : [FmhaBwdDQDKDVTileSize( 16,  64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
                     "kr_ktr_vr_iglp", "kr_ktr_vr"]
            # '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
            #          "kr_ktr_vr_iglp", "kr_ktr_vr"]
        }
    else:
        return None
...
...
@@ -501,7 +716,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
            continue
        if receipt == 3:
            cond = dtype in ['fp16', 'bf16']
            cond &= bias in ['no', 'alibi']
            cond &= bias in ['no']
            cond &= dpad == dvpad
            cond &= deterministic == "f"
            if not cond:
...
...
example/ck_tile/01_fmha/fmha_bwd.cpp
...
...
@@ -2,7 +2,6 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_bwd.hpp"
#include "fmha_bwd_ext.hpp"
#include "ck_tile/host.hpp"
#include "mask.hpp"
#include "utils.hpp"
...
...
@@ -135,7 +134,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
    hdim_v = hdim_q;
    bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
    bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
    float scale = arg_parser.get_float("scale");
    if(scale == .0f)
...
@@ -211,7 +210,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
    using BiasGradDataType = typename TypeConfig::BiasGradDataType;

    // accumulation numbers for performance evaluation
    // std::size_t flop = 0, num_byte = 0;
    std::size_t flop = 0, num_byte = 0;
    auto max_seqlen_q = std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
    auto max_seqlen_k =
...
...
@@ -232,20 +231,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
            max_seqlen_k = real_seqlen_k;
        }

        // flop += nhead * (static_cast<std::size_t>(3) * static_cast<std::size_t>(2) *
        //                      real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T
        //                  static_cast<std::size_t>(2) * static_cast<std::size_t>(2) *
        //                      real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T
        // num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
        //                      sizeof(KDataType) * real_seqlen_k * hdim_q +
        //                      sizeof(VDataType) * real_seqlen_k * hdim_v +
        //                      sizeof(ODataType) * real_seqlen_q * hdim_v +
        //                      sizeof(OGradDataType) * real_seqlen_q * hdim_v +
        //                      sizeof(QGradDataType) * real_seqlen_q * hdim_q +
        //                      sizeof(KGradDataType) * real_seqlen_k * hdim_q +
        //                      sizeof(VGradDataType) * real_seqlen_k * hdim_v +
        //                      sizeof(LSEDataType) * real_seqlen_q);
        flop += nhead * (static_cast<std::size_t>(3) * static_cast<std::size_t>(2) *
                             real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T
                         static_cast<std::size_t>(2) * static_cast<std::size_t>(2) *
                             real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T
        num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
                             sizeof(KDataType) * real_seqlen_k * hdim_q +
                             sizeof(VDataType) * real_seqlen_k * hdim_v +
                             sizeof(ODataType) * real_seqlen_q * hdim_v +
                             sizeof(OGradDataType) * real_seqlen_q * hdim_v +
                             sizeof(QGradDataType) * real_seqlen_q * hdim_q +
                             sizeof(KGradDataType) * real_seqlen_k * hdim_q +
                             sizeof(VGradDataType) * real_seqlen_k * hdim_v +
                             sizeof(LSEDataType) * real_seqlen_q);
    }
}
...
...
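To make the accounting above concrete, here is the same FLOP formula evaluated once on assumed dimensions (nhead = 16, seqlen = 4096, hdim = 128; these numbers are illustrative, not from the commit):

// Illustrative only: one evaluation of the flop formula in the hunk above.
#include <cstdio>
#include <cstddef>

int main()
{
    std::size_t nhead = 16, s = 4096, d = 128; // assumed shapes, hdim_q == hdim_v
    // 3 GEMMs of 2*S*S*D (Q@K, dS^T@Q^T, dS@K^T) plus 2 GEMMs of 2*S*S*D (dO@V, P^T@dO^T)
    std::size_t flop = nhead * (3 * 2 * s * s * d + 2 * 2 * s * s * d);
    std::printf("%.1f GFLOP per batch element\n", flop / 1.e9); // ~343.6 GFLOP
    return 0;
}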
@@ -460,168 +459,96 @@ bool run(const ck_tile::ArgParser& arg_parser)
        const ck_tile::index_t split_stride_dq_acc = (shape_batch * nhead * shape_seqlen_q * hdim_q);
        return fmha_bwd_args{q_buf.GetDeviceBuffer(),
                             k_buf.GetDeviceBuffer(),
                             v_buf.GetDeviceBuffer(),
                             bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
                                                           : bias_buf.GetDeviceBuffer(),
                             o_buf.GetDeviceBuffer(),
                             lse_buf.GetDeviceBuffer(),
                             do_buf.GetDeviceBuffer(),
                             d_buf.GetDeviceBuffer(), // need the dot_do_o kernel to generate d_buf (corresponds to niels' odo buffer)
                             randval_buf.GetDeviceBuffer(),
                             dq_buf.GetDeviceBuffer(),
                             dk_buf.GetDeviceBuffer(),
                             dv_buf.GetDeviceBuffer(),
                             dbias_buf.GetDeviceBuffer(),
                             dq_acc_buf.GetDeviceBuffer(),
                             seqstart_q.GetDeviceBuffer(),
                             seqstart_k.GetDeviceBuffer(),
                             nullptr,
                             shape_seqlen_q,
                             shape_seqlen_k,
                             batch,
                             max_seqlen_q,
                             max_seqlen_k,
                             hdim_q,
                             hdim_v,
                             nhead,
                             nhead_k,
                             scale,
                             stride_q,
                             stride_k,
                             stride_v,
                             bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
                                                           : stride_bias,
                             stride_o,
                             stride_randval,
                             stride_do,
                             stride_q, // stride_dq_acc
                             stride_q, // stride_dq
                             stride_dk,
                             stride_dv,
                             stride_dbias,
                             nhead_stride_q,
                             nhead_stride_k,
                             nhead_stride_v,
                             nhead_stride_bias,
                             nhead_stride_o,
                             nhead_stride_randval,
                             nhead_stride_do,
                             nhead_stride_lsed,
                             nhead_stride_q, // nhead_stride_dq_acc
                             nhead_stride_q, // nhead_stride_dq
                             nhead_stride_k, // nhead_stride_dk
                             nhead_stride_v, // nhead_stride_dv
                             nhead_stride_dbias,
                             batch_stride_q,
                             batch_stride_k,
                             batch_stride_v,
                             batch_stride_bias,
                             batch_stride_o,
                             batch_stride_randval,
                             batch_stride_do,
                             batch_stride_lsed,
                             batch_stride_q, // batch_stride_dq_acc
                             batch_stride_q, // batch_stride_dq
                             batch_stride_dk,
                             batch_stride_dv,
                             batch_stride_dbias,
                             split_stride_dq_acc,
                             mask.left,
                             mask.right,
                             static_cast<ck_tile::index_t>(mask.type),
                             p_drop,
                             p_undrop,
                             {drop_seed, drop_offset}};
        return fmha_bwd_args{q_buf.GetDeviceBuffer(),
                             k_buf.GetDeviceBuffer(),
                             v_buf.GetDeviceBuffer(),
                             bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
                                                           : bias_buf.GetDeviceBuffer(),
                             o_buf.GetDeviceBuffer(),
                             lse_buf.GetDeviceBuffer(),
                             do_buf.GetDeviceBuffer(),
                             d_buf.GetDeviceBuffer(),
                             randval_buf.GetDeviceBuffer(),
                             dq_buf.GetDeviceBuffer(),
                             dk_buf.GetDeviceBuffer(),
                             dv_buf.GetDeviceBuffer(),
                             dbias_buf.GetDeviceBuffer(),
                             dq_acc_buf.GetDeviceBuffer(),
                             seqstart_q.GetDeviceBuffer(),
                             seqstart_k.GetDeviceBuffer(),
                             nullptr,
                             shape_seqlen_q,
                             shape_seqlen_k,
                             batch,
                             max_seqlen_q,
                             max_seqlen_k,
                             hdim_q,
                             hdim_v,
                             nhead,
                             nhead_k,
                             scale,
                             stride_q,
                             stride_k,
                             stride_v,
                             bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
                                                           : stride_bias,
                             stride_o,
                             stride_randval,
                             stride_do,
                             stride_q, // stride_dq_acc
                             stride_q, // stride_dq
                             stride_dk,
                             stride_dv,
                             stride_dbias,
                             nhead_stride_q,
                             nhead_stride_k,
                             nhead_stride_v,
                             nhead_stride_bias,
                             nhead_stride_o,
                             nhead_stride_randval,
                             nhead_stride_do,
                             nhead_stride_lsed,
                             nhead_stride_q, // nhead_stride_dq_acc
                             nhead_stride_q, // nhead_stride_dq
                             nhead_stride_k, // nhead_stride_dk
                             nhead_stride_v, // nhead_stride_dv
                             nhead_stride_dbias,
                             batch_stride_q,
                             batch_stride_k,
                             batch_stride_v,
                             batch_stride_bias,
                             batch_stride_o,
                             batch_stride_randval,
                             batch_stride_do,
                             batch_stride_lsed,
                             batch_stride_q, // batch_stride_dq_acc
                             batch_stride_q, // batch_stride_dq
                             batch_stride_dk,
                             batch_stride_dv,
                             batch_stride_dbias,
                             split_stride_dq_acc,
                             mask.left,
                             mask.right,
                             static_cast<ck_tile::index_t>(mask.type),
                             p_drop,
                             p_undrop,
                             {drop_seed, drop_offset}};
    }();
    float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config, 0);
    float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
    if(ave_time < 0)
    {
        std::cout << ", not supported yet" << std::flush << std::endl;
        return false;
    }
    int atm_f32     = 1;
    int skip_dq_rd  = 1;
    int mask_asm    = 1;
    int mask_kb_asm = 0;
    int ts_qo       = 32;
    int ts_kv       = 128;
    int dump_result = 0;
    auto fmha_ext_traits = fmha_bwd_ext_traits{batch,
                                               nhead,
                                               seqlen_q,
                                               hdim_q,
                                               atm_f32,
                                               skip_dq_rd,
                                               mask_asm,
                                               mask_kb_asm,
                                               ts_qo,
                                               ts_kv,
                                               dump_result};
    int stride_tg    = ts_kv * hdim_q * 2;
    int stride_head  = seqlen_q * hdim_q * 2;
    int stride_batch = nhead * seqlen_q * hdim_q * 2;
    float k_log2e  = log2f(expf(1));
    float k_scalar = sqrt(hdim_q);
    k_scalar       = static_cast<float>(1.0 / static_cast<double>(k_scalar));
#ifdef ASM_PRINT
    // debug pointer
    float *host_print, *print;
    host_print = (float*)malloc(bdx * 8);
    HIP_CALL(hipMalloc(&print, bdx * 8));
#endif
    fmha_bwd_asm_args args;
    args.ptr_dq = dq_acc_buf.GetDeviceBuffer(); // dev_dq;
    // args.ptr_dq = dq_buf.GetDeviceBuffer(); // dev_dq;
    args.ptr_dk  = dk_buf.GetDeviceBuffer();  // dev_dk;
    args.ptr_dv  = dv_buf.GetDeviceBuffer();  // dev_dv;
    args.ptr_q   = q_buf.GetDeviceBuffer();   // dev_q;
    args.ptr_k   = k_buf.GetDeviceBuffer();   // dev_k;
    args.ptr_v   = v_buf.GetDeviceBuffer();   // dev_v;
    args.ptr_do  = do_buf.GetDeviceBuffer();  // dev_do;
    args.ptr_lse = lse_buf.GetDeviceBuffer(); // dev_lse;
    args.ptr_odo = d_buf.GetDeviceBuffer();   // dev_odo;
    args.scalar  = k_scalar;
    args.log2e   = k_log2e;
    args.seq_len = seqlen_q;
    args.Ts      = stride_tg;
    args.Hs      = stride_head;
    args.BAs     = stride_batch;
#ifdef ASM_PRINT
    args.print = (void*)print;
#endif
    hipStream_t stream_ext = nullptr;
    fmha_bwd_ext(fmha_ext_traits, args, stream_ext);
    fmha_bwd(fmha_traits, fmha_args, stream_config, 1);
    // if((atm_f32 == 1) || ((!skip_dq_rd) && (atm_f32 == 2)))
    //     HIP_CALL(
    //         hipMemcpy(host_fp32_dq, dev_dq, sz_mx_dq * sizeof(float), hipMemcpyDeviceToHost));
    // else
    //     HIP_CALL(
    //         hipMemcpy(host_fp16_dq, dev_dq, sz_mx * sizeof(float) / 2, hipMemcpyDeviceToHost));
    // ;
    // HIP_CALL(hipMemcpy(host_fp16_dk, dev_dk, sz_mx * sizeof(float) / 2, hipMemcpyDeviceToHost));
    // HIP_CALL(hipMemcpy(host_fp16_dv, dev_dv, sz_mx * sizeof(float) / 2, hipMemcpyDeviceToHost));
    // float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
    // float gb_per_sec = num_byte / 1.E6 / ave_time;
    // std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, "
    //           << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
    //           << " GB/s" << std::flush;
    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;
    std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, "
              << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2)
              << gb_per_sec << " GB/s" << std::flush;
    if(!do_validation)
    {
...
...
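The Ts/Hs/BAs arguments in the hunk above are byte strides, hence the trailing * 2: both fp16 and bf16 elements occupy two bytes. A small sketch of that arithmetic, plus the softmax scale, under assumed shapes (seqlen_q = 4096, hdim_q = 128, nhead = 16, ts_kv = 128 are illustrative values, not from the commit):

// Illustrative only: byte-stride and scale arithmetic from the hunk above.
#include <cmath>
#include <cstdio>

int main()
{
    int seqlen_q = 4096, hdim_q = 128, nhead = 16, ts_kv = 128; // assumed shapes
    int elem_bytes = 2;                                         // fp16 / bf16
    int Ts  = ts_kv * hdim_q * elem_bytes;                      // bytes per k/v tile      (32768)
    int Hs  = seqlen_q * hdim_q * elem_bytes;                   // bytes per head          (1048576)
    int BAs = nhead * seqlen_q * hdim_q * elem_bytes;           // bytes per batch element (16777216)
    float k_scalar = 1.0f / sqrtf((float)hdim_q);               // softmax scale, ~0.0884 for hdim 128
    std::printf("Ts=%d Hs=%d BAs=%d scale=%f\n", Ts, Hs, BAs, k_scalar);
    return 0;
}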
@@ -845,12 +772,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
        ck_tile::stream_config stream_config_v{
            nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
        // just for odo buffer
        fmha_bwd(fmha_traits, fmha_args, stream_config_v, 0);
        fmha_bwd_ext(fmha_ext_traits, args, stream_ext);
        fmha_bwd(fmha_traits, fmha_args, stream_config_v, 1);
        fmha_bwd(fmha_traits, fmha_args, stream_config_v);
        dq_buf.FromDevice(dq_host.data());
        dk_buf.FromDevice(dk_host.data());
...
...
@@ -1017,6 +939,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
    }
    std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    return pass;
}
...
...
@@ -1031,10 +954,10 @@ int main(int argc, char* argv[])
    {
        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
    }
    // else if(data_type == "bf16")
    // {
    //     return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
    // }
    else if(data_type == "bf16")
    {
        return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
    }
    return -3;
}
example/ck_tile/01_fmha/fmha_bwd.hpp
...
...
@@ -440,4 +440,4 @@ struct fmha_bwd_traits
    bool is_deterministic;
    // TODO: padding check is inside this api
};

float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&, int flag);
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
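For context, the header diff above shows the old declaration (with a trailing int flag) followed by the updated one without it, matching the call-site change in fmha_bwd.cpp. A minimal caller against the new signature might look like the sketch below; the stream_config initializer mirrors the six-field one used in fmha_bwd.cpp and is an assumption, not part of the commit.

// Sketch only: invoking the updated fmha_bwd declaration.
#include "fmha_bwd.hpp"

float run_bwd_once(fmha_bwd_traits t, fmha_bwd_args a)
{
    // {stream, time_kernel, log_level, cold_niters, nrepeat, gpu_timer}, as in fmha_bwd.cpp
    ck_tile::stream_config s{nullptr, true, 0, 0, 1, false};
    return fmha_bwd(t, a, s); // new signature: no trailing int flag
}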