Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d4139c8b
Commit
d4139c8b
authored
Sep 05, 2024
by
Fang.Che
Browse files
add fmha asm api: fmha_bwd_ext
parent
933ac7c7
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
411 additions
and
106 deletions
+411
-106
example/ck_tile/01_fmha/CMakeLists.txt
example/ck_tile/01_fmha/CMakeLists.txt
+2
-1
example/ck_tile/01_fmha/fmha_bwd.cpp
example/ck_tile/01_fmha/fmha_bwd.cpp
+181
-104
example/ck_tile/01_fmha/fmha_bwd.hpp
example/ck_tile/01_fmha/fmha_bwd.hpp
+1
-1
example/ck_tile/01_fmha/fmha_bwd_ext.cpp
example/ck_tile/01_fmha/fmha_bwd_ext.cpp
+85
-0
example/ck_tile/01_fmha/fmha_bwd_ext.hpp
example/ck_tile/01_fmha/fmha_bwd_ext.hpp
+142
-0
No files found.
example/ck_tile/01_fmha/CMakeLists.txt
View file @
d4139c8b
...
@@ -58,7 +58,8 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
...
@@ -58,7 +58,8 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# not using add_example_executable() to add this target, since we don't want this to have
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
# to be included in "make all/install/check"
message
(
"adding example
${
EXAMPLE_FMHA_BWD
}
"
)
message
(
"adding example
${
EXAMPLE_FMHA_BWD
}
"
)
add_executable
(
${
EXAMPLE_FMHA_BWD
}
EXCLUDE_FROM_ALL fmha_bwd.cpp
)
add_executable
(
${
EXAMPLE_FMHA_BWD
}
EXCLUDE_FROM_ALL hsaco/bwd_arg.cpp hsaco/bwd_bf16_a32.cpp hsaco/bwd_bf16_causal_a32.cpp hsaco/bwd_bf16_nocoex_a32.cpp hsaco/bwd_bf16_nocoex_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_nocoex_a32.cpp hsaco/bwd_fp16_nocoex_causal_a32.cpp fmha_bwd_ext.cpp fmha_bwd.cpp
)
target_include_directories
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_include_directories
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_sources
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
FMHA_BWD_GEN_BLOBS
}
)
target_sources
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
FMHA_BWD_GEN_BLOBS
}
)
...
...
example/ck_tile/01_fmha/fmha_bwd.cpp
View file @
d4139c8b
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_bwd.hpp"
#include "fmha_bwd.hpp"
#include "fmha_bwd_ext.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "mask.hpp"
#include "mask.hpp"
#include "utils.hpp"
#include "utils.hpp"
...
@@ -134,7 +135,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -134,7 +135,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
hdim_v
=
hdim_q
;
hdim_v
=
hdim_q
;
bool
i_perm
=
arg_parser
.
get_bool
(
"iperm"
);
// if true, will be batch * nhead * seqlen * hdim
bool
i_perm
=
arg_parser
.
get_bool
(
"iperm"
);
// if true, will be batch * nhead * seqlen * hdim
bool
o_perm
=
arg_parser
.
get_bool
(
"operm"
);
// if false, will be batch * seqlen * nhead
* hdim
bool
o_perm
=
arg_parser
.
get_bool
(
"operm"
);
// if false, will be batch * seqlen * nhead* hdim
float
scale
=
arg_parser
.
get_float
(
"scale"
);
float
scale
=
arg_parser
.
get_float
(
"scale"
);
if
(
scale
==
.0
f
)
if
(
scale
==
.0
f
)
...
@@ -210,7 +211,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -210,7 +211,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
BiasGradDataType
=
typename
TypeConfig
::
BiasGradDataType
;
using
BiasGradDataType
=
typename
TypeConfig
::
BiasGradDataType
;
// accumulation numbers for performance evaluation
// accumulation numbers for performance evaluation
std
::
size_t
flop
=
0
,
num_byte
=
0
;
//
std::size_t flop = 0, num_byte = 0;
auto
max_seqlen_q
=
auto
max_seqlen_q
=
std
::
numeric_limits
<
int32_t
>::
min
();
// we will use max seqlen to decide grid size
std
::
numeric_limits
<
int32_t
>::
min
();
// we will use max seqlen to decide grid size
auto
max_seqlen_k
=
auto
max_seqlen_k
=
...
@@ -231,20 +232,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -231,20 +232,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
max_seqlen_k
=
real_seqlen_k
;
max_seqlen_k
=
real_seqlen_k
;
}
}
flop
+=
nhead
*
(
static_cast
<
std
::
size_t
>
(
3
)
*
static_cast
<
std
::
size_t
>
(
2
)
*
//
flop += nhead * (static_cast<std::size_t>(3) * static_cast<std::size_t>(2) *
real_seqlen_q
*
real_seqlen_k
*
hdim_q
+
// Q@K/dS^T@Q^T/dS@K^T
//
real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T
static_cast
<
std
::
size_t
>
(
2
)
*
static_cast
<
std
::
size_t
>
(
2
)
*
//
static_cast<std::size_t>(2) * static_cast<std::size_t>(2) *
real_seqlen_q
*
real_seqlen_k
*
hdim_v
);
// dO@V/P^T@dO^T
//
real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T
num_byte
+=
nhead
*
(
sizeof
(
QDataType
)
*
real_seqlen_q
*
hdim_q
+
//
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
sizeof
(
KDataType
)
*
real_seqlen_k
*
hdim_q
+
//
sizeof(KDataType) * real_seqlen_k * hdim_q +
sizeof
(
VDataType
)
*
real_seqlen_k
*
hdim_v
+
//
sizeof(VDataType) * real_seqlen_k * hdim_v +
sizeof
(
ODataType
)
*
real_seqlen_q
*
hdim_v
+
//
sizeof(ODataType) * real_seqlen_q * hdim_v +
sizeof
(
OGradDataType
)
*
real_seqlen_q
*
hdim_v
+
//
sizeof(OGradDataType) * real_seqlen_q * hdim_v +
sizeof
(
QGradDataType
)
*
real_seqlen_q
*
hdim_q
+
//
sizeof(QGradDataType) * real_seqlen_q * hdim_q +
sizeof
(
KGradDataType
)
*
real_seqlen_k
*
hdim_q
+
//
sizeof(KGradDataType) * real_seqlen_k * hdim_q +
sizeof
(
VGradDataType
)
*
real_seqlen_k
*
hdim_v
+
//
sizeof(VGradDataType) * real_seqlen_k * hdim_v +
sizeof
(
LSEDataType
)
*
real_seqlen_q
);
//
sizeof(LSEDataType) * real_seqlen_q);
}
}
}
}
...
@@ -459,7 +460,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -459,7 +460,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
split_stride_dq_acc
=
const
ck_tile
::
index_t
split_stride_dq_acc
=
(
shape_batch
*
nhead
*
shape_seqlen_q
*
hdim_q
);
(
shape_batch
*
nhead
*
shape_seqlen_q
*
hdim_q
);
return
fmha_bwd_args
{
q_buf
.
GetDeviceBuffer
(),
return
fmha_bwd_args
{
q_buf
.
GetDeviceBuffer
(),
k_buf
.
GetDeviceBuffer
(),
k_buf
.
GetDeviceBuffer
(),
v_buf
.
GetDeviceBuffer
(),
v_buf
.
GetDeviceBuffer
(),
bias
.
type
==
bias_enum
::
alibi
?
alibi_slope_buf
.
GetDeviceBuffer
()
bias
.
type
==
bias_enum
::
alibi
?
alibi_slope_buf
.
GetDeviceBuffer
()
...
@@ -467,7 +469,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -467,7 +469,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
o_buf
.
GetDeviceBuffer
(),
o_buf
.
GetDeviceBuffer
(),
lse_buf
.
GetDeviceBuffer
(),
lse_buf
.
GetDeviceBuffer
(),
do_buf
.
GetDeviceBuffer
(),
do_buf
.
GetDeviceBuffer
(),
d_buf
.
GetDeviceBuffer
(),
d_buf
.
GetDeviceBuffer
(),
// 需要使用dot_do_o kernel 生成d_buf(对应niels 的odo buffer)
randval_buf
.
GetDeviceBuffer
(),
randval_buf
.
GetDeviceBuffer
(),
dq_buf
.
GetDeviceBuffer
(),
dq_buf
.
GetDeviceBuffer
(),
dk_buf
.
GetDeviceBuffer
(),
dk_buf
.
GetDeviceBuffer
(),
...
@@ -490,8 +492,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -490,8 +492,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride_q
,
stride_q
,
stride_k
,
stride_k
,
stride_v
,
stride_v
,
bias
.
type
==
bias_enum
::
alibi
?
(
bias
.
rank_info
==
0
?
0
:
nhead
)
bias
.
type
==
bias_enum
::
alibi
?
(
bias
.
rank_info
==
0
?
0
:
nhead
)
:
stride_bias
,
:
stride_bias
,
stride_o
,
stride_o
,
stride_randval
,
stride_randval
,
stride_do
,
stride_do
,
...
@@ -535,20 +536,92 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -535,20 +536,92 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
drop_seed
,
drop_offset
}};
{
drop_seed
,
drop_offset
}};
}();
}();
float
ave_time
=
fmha_bwd
(
fmha_traits
,
fmha_args
,
stream_config
);
float
ave_time
=
fmha_bwd
(
fmha_traits
,
fmha_args
,
stream_config
,
0
);
if
(
ave_time
<
0
)
if
(
ave_time
<
0
)
{
{
std
::
cout
<<
", not supported yet"
<<
std
::
flush
<<
std
::
endl
;
std
::
cout
<<
", not supported yet"
<<
std
::
flush
<<
std
::
endl
;
return
false
;
return
false
;
}
}
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
int
atm_f32
=
1
;
int
skip_dq_rd
=
1
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
int
mask_asm
=
1
;
int
mask_kb_asm
=
0
;
std
::
cout
<<
std
::
fixed
<<
", "
<<
std
::
setprecision
(
3
)
<<
ave_time
<<
" ms, "
int
ts_qo
=
32
;
<<
std
::
setprecision
(
2
)
<<
tflops
<<
" TFlops, "
<<
std
::
setprecision
(
2
)
<<
gb_per_sec
int
ts_kv
=
128
;
<<
" GB/s"
<<
std
::
flush
;
int
dump_result
=
0
;
auto
fmha_ext_traits
=
fmha_bwd_ext_traits
{
batch
,
nhead
,
seqlen_q
,
hdim_q
,
atm_f32
,
skip_dq_rd
,
mask_asm
,
mask_kb_asm
,
ts_qo
,
ts_kv
,
dump_result
};
int
stride_tg
=
ts_kv
*
hdim_q
*
2
;
int
stride_head
=
seqlen_q
*
hdim_q
*
2
;
int
stride_batch
=
nhead
*
seqlen_q
*
hdim_q
*
2
;
float
k_log2e
=
log2f
(
expf
(
1
));
float
k_scalar
=
sqrt
(
hdim_q
);
k_scalar
=
static_cast
<
float
>
(
1.0
/
static_cast
<
double
>
(
k_scalar
));
#ifdef ASM_PRINT
// debug pointer
float
*
host_print
,
*
print
;
host_print
=
(
float
*
)
malloc
(
bdx
*
8
);
HIP_CALL
(
hipMalloc
(
&
print
,
bdx
*
8
));
#endif
fmha_bwd_asm_args
args
;
args
.
ptr_dq
=
dq_acc_buf
.
GetDeviceBuffer
();
// dev_dq;
// args.ptr_dq = dq_buf.GetDeviceBuffer(); // dev_dq;
args
.
ptr_dk
=
dk_buf
.
GetDeviceBuffer
();
// dev_dk;
args
.
ptr_dv
=
dv_buf
.
GetDeviceBuffer
();
// dev_dv;
args
.
ptr_q
=
q_buf
.
GetDeviceBuffer
();
// dev_q;
args
.
ptr_k
=
k_buf
.
GetDeviceBuffer
();
// dev_k;
args
.
ptr_v
=
v_buf
.
GetDeviceBuffer
();
// dev_v;
args
.
ptr_do
=
do_buf
.
GetDeviceBuffer
();
// dev_do;
args
.
ptr_lse
=
lse_buf
.
GetDeviceBuffer
();
// dev_lse;
args
.
ptr_odo
=
d_buf
.
GetDeviceBuffer
();
// dev_odo;
args
.
scalar
=
k_scalar
;
args
.
log2e
=
k_log2e
;
args
.
seq_len
=
seqlen_q
;
args
.
Ts
=
stride_tg
;
args
.
Hs
=
stride_head
;
args
.
BAs
=
stride_batch
;
#ifdef ASM_PRINT
args
.
print
=
(
void
*
)
print
;
#endif
hipStream_t
stream_ext
=
nullptr
;
fmha_bwd_ext
(
fmha_ext_traits
,
args
,
stream_ext
);
fmha_bwd
(
fmha_traits
,
fmha_args
,
stream_config
,
1
);
// if((atm_f32 == 1) || ((!skip_dq_rd) && (atm_f32 == 2)))
// HIP_CALL(
// hipMemcpy(host_fp32_dq, dev_dq, sz_mx_dq * sizeof(float), hipMemcpyDeviceToHost));
// else
// HIP_CALL(
// hipMemcpy(host_fp16_dq, dev_dq, sz_mx * sizeof(float) / 2, hipMemcpyDeviceToHost));
// ;
// HIP_CALL(hipMemcpy(host_fp16_dk, dev_dk, sz_mx * sizeof(float) / 2, hipMemcpyDeviceToHost));
// HIP_CALL(hipMemcpy(host_fp16_dv, dev_dv, sz_mx * sizeof(float) / 2, hipMemcpyDeviceToHost));
// float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
// float gb_per_sec = num_byte / 1.E6 / ave_time;
// std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, "
// << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) <<
// gb_per_sec
// << " GB/s" << std::flush;
if
(
!
do_validation
)
if
(
!
do_validation
)
{
{
...
@@ -772,7 +845,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -772,7 +845,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
stream_config
stream_config_v
{
ck_tile
::
stream_config
stream_config_v
{
nullptr
,
true
,
0
,
0
,
1
,
arg_parser
.
get_str
(
"timer"
)
==
std
::
string
(
"gpu"
)};
nullptr
,
true
,
0
,
0
,
1
,
arg_parser
.
get_str
(
"timer"
)
==
std
::
string
(
"gpu"
)};
fmha_bwd
(
fmha_traits
,
fmha_args
,
stream_config_v
);
// just fot odo buffer
fmha_bwd
(
fmha_traits
,
fmha_args
,
stream_config_v
,
0
);
fmha_bwd_ext
(
fmha_ext_traits
,
args
,
stream_ext
);
fmha_bwd
(
fmha_traits
,
fmha_args
,
stream_config_v
,
1
);
dq_buf
.
FromDevice
(
dq_host
.
data
());
dq_buf
.
FromDevice
(
dq_host
.
data
());
dk_buf
.
FromDevice
(
dk_host
.
data
());
dk_buf
.
FromDevice
(
dk_host
.
data
());
...
@@ -939,7 +1017,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -939,7 +1017,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
return
pass
;
return
pass
;
}
}
...
@@ -954,10 +1031,10 @@ int main(int argc, char* argv[])
...
@@ -954,10 +1031,10 @@ int main(int argc, char* argv[])
{
{
return
run
<
ck_tile
::
half_t
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
half_t
>
(
arg_parser
)
?
0
:
-
2
;
}
}
else
if
(
data_type
==
"bf16"
)
//
else if(data_type == "bf16")
{
//
{
return
run
<
ck_tile
::
bf16_t
>
(
arg_parser
)
?
0
:
-
2
;
//
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
}
//
}
return
-
3
;
return
-
3
;
}
}
example/ck_tile/01_fmha/fmha_bwd.hpp
View file @
d4139c8b
...
@@ -440,4 +440,4 @@ struct fmha_bwd_traits
...
@@ -440,4 +440,4 @@ struct fmha_bwd_traits
bool
is_deterministic
;
bool
is_deterministic
;
// TODO: padding check is inside this api
// TODO: padding check is inside this api
};
};
float
fmha_bwd
(
fmha_bwd_traits
,
fmha_bwd_args
,
const
ck_tile
::
stream_config
&
);
float
fmha_bwd
(
fmha_bwd_traits
,
fmha_bwd_args
,
const
ck_tile
::
stream_config
&
,
int
flag
);
example/ck_tile/01_fmha/fmha_bwd_ext.cpp
0 → 100644
View file @
d4139c8b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <stdio.h>
#include <vector>
#include <functional>
#include <iostream>
#include <string>
#include "fmha_bwd_ext.hpp"
#include "hsaco/fmha_hsaco.h"
int
fmha_bwd_ext
(
fmha_bwd_ext_traits
fmha_ext_traits
,
fmha_bwd_asm_args
args
,
hipStream_t
stream
)
{
hipEvent_t
evt_00
,
evt_11
;
HIP_CALL
(
hipSetDevice
(
0
));
// fmha_bwd_ext_kernel impl(HSACO, HSA_KERNEL);
fmha_bwd_ext_kernel
impl
(
HSA_KERNEL
,
bwd_fp16_a32
);
int
b
=
fmha_ext_traits
.
b
;
int
h
=
fmha_ext_traits
.
h
;
int
s
=
fmha_ext_traits
.
s
;
int
d
=
fmha_ext_traits
.
d
;
int
atm_f32
=
fmha_ext_traits
.
atm_f32
;
int
skip_dq_rd
=
fmha_ext_traits
.
skip_dq_rd
;
// int ts_qo = fmha_ext_traits.ts_qo;
// int ts_kv = fmha_ext_traits.ts_kv;
int
dump_result
=
fmha_ext_traits
.
dump_result
;
std
::
cout
<<
"b:"
<<
b
<<
std
::
endl
;
std
::
cout
<<
"h:"
<<
h
<<
std
::
endl
;
std
::
cout
<<
"s:"
<<
s
<<
std
::
endl
;
std
::
cout
<<
"d:"
<<
d
<<
std
::
endl
;
std
::
cout
<<
"dump_result:"
<<
dump_result
<<
std
::
endl
;
std
::
cout
<<
"atm_f32:"
<<
atm_f32
<<
std
::
endl
;
std
::
cout
<<
"skip_dq_rd:"
<<
skip_dq_rd
<<
std
::
endl
;
size_t
arg_size
=
sizeof
(
args
);
printf
(
"argsize: %zu
\n
"
,
arg_size
);
#ifdef ASM_PRINT
int
max_i
=
256
;
HIP_CALL
(
hipMemcpy
(
host_print
,
print
,
8
*
max_i
,
hipMemcpyDeviceToHost
));
for
(
int
i
=
0
;
i
<
max_i
;
i
++
)
{
if
(((
uint32_t
*
)
host_print
)[
2
*
i
+
1
]
!=
0x5c005c00
)
printf
(
"Thread%d, PrintVal:0x%x
\n
"
,
((
int
*
)
host_print
)[
2
*
i
],
((
uint32_t
*
)
host_print
)[
2
*
i
+
1
]);
// std::cout<<"Thread"<<((int*) host_print)[2*i]<<",
// PrintVal1:"<<(((float16*)host_print)[4*i+2])<<
//", PrintVal2:"<<( ( (float16*)host_print )[4*i+3] )<<std::endl;
}
#endif
HIP_CALL
(
hipEventCreate
(
&
evt_00
));
HIP_CALL
(
hipEventCreate
(
&
evt_11
));
HIP_CALL
(
hipDeviceSynchronize
());
HIP_CALL
(
hipEventRecord
(
evt_00
,
NULL
));
impl
.
launch_kernel
(
fmha_ext_traits
,
args
,
stream
);
std
::
cout
<<
"we are done"
<<
std
::
endl
;
float
elapsed_ms
;
HIP_CALL
(
hipEventRecord
(
evt_11
,
NULL
));
HIP_CALL
(
hipEventSynchronize
(
evt_11
));
HIP_CALL
(
hipDeviceSynchronize
());
HIP_CALL
(
hipEventElapsedTime
(
&
elapsed_ms
,
evt_00
,
evt_11
));
HIP_CALL
(
hipEventDestroy
(
evt_00
));
HIP_CALL
(
hipEventDestroy
(
evt_11
));
// float time_per_loop = elapsed_ms / total_loop;
// float gflops = static_cast<float>(2.0) * 5 * b * h * d * s * s / time_per_loop / (1e6);
// printf("b:%d,h:%d,s:%d,d:%d, time: %.3f, gflops:%.3f\n", b, h, s, d, time_per_loop, gflops);
printf
(
"b:%d,h:%d,s:%d,d:%d
\n
"
,
b
,
h
,
s
,
d
);
#ifdef ASM_PRINT
free
(
host_print
);
HIP_CALL
(
hipFree
(
print
));
#endif
// printf("CU:%d, TIPS:%.3f(2x:%.3f, 4x:%.3f), cost:%fms per loop\n", num_cu, tips, 2*tips,
// 4*tips, time_per_loop);
return
0
;
}
example/ck_tile/01_fmha/fmha_bwd_ext.hpp
0 → 100644
View file @
d4139c8b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#define HSACO "kernel.co"
#define HSA_KERNEL "kernel_func"
struct
p3
{
unsigned
int
_p0
;
unsigned
int
_p1
;
unsigned
int
_p2
;
};
struct
p2
{
unsigned
int
_p0
;
unsigned
int
_p1
;
};
struct
__attribute__
((
packed
))
fmha_bwd_asm_args
{
void
*
ptr_dq
;
p2
_p0
;
void
*
ptr_dk
;
p2
_p1
;
void
*
ptr_dv
;
p2
_p2
;
void
*
ptr_q
;
p2
_p3
;
void
*
ptr_k
;
p2
_p4
;
void
*
ptr_v
;
p2
_p5
;
void
*
ptr_do
;
p2
_p6
;
void
*
ptr_lse
;
p2
_p7
;
void
*
ptr_odo
;
p2
_p8
;
float
scalar
;
p3
_p9
;
float
log2e
;
p3
_p10
;
unsigned
int
seq_len
;
p3
_p11
;
unsigned
int
Ts
;
p3
_p12
;
unsigned
int
Hs
;
p3
_p13
;
unsigned
int
BAs
;
p3
_p14
;
#ifdef ASM_PRINT
void
*
print
;
#endif
};
#define HIP_CALL(call) \
do \
{ \
hipError_t err = call; \
if(err != hipSuccess) \
{ \
printf("[hiperror](%d) fail to call %s", static_cast<int>(err), #call); \
exit(0); \
} \
} while(0)
struct
fmha_bwd_ext_traits
{
int
b
;
int
h
;
int
s
;
int
d
;
int
atm_f32
;
int
skip_dq_rd
;
int
mask
;
int
mask_kb
;
int
ts_qo
;
int
ts_kv
;
int
dump_result
;
};
class
fmha_bwd_ext_kernel
{
public:
// fmha_bwd_ext_kernel(const char* image, const std::string& name)
// {
// HIP_CALL(hipModuleLoad(&module, image));
// HIP_CALL(hipModuleGetFunction(&kernel_func, module, name.c_str()));
// }
fmha_bwd_ext_kernel
(
const
std
::
string
&
name
,
unsigned
char
buffer
[])
{
HIP_CALL
(
hipModuleLoadData
(
&
module
,
buffer
));
HIP_CALL
(
hipModuleGetFunction
(
&
kernel_func
,
module
,
name
.
c_str
()));
}
void
launch_kernel
(
fmha_bwd_ext_traits
fmha_ext_traits
,
fmha_bwd_asm_args
args
,
hipStream_t
stream
)
{
size_t
arg_size
=
sizeof
(
args
);
void
*
config
[]
=
{
HIP_LAUNCH_PARAM_BUFFER_POINTER
,
&
args
,
HIP_LAUNCH_PARAM_BUFFER_SIZE
,
&
arg_size
,
HIP_LAUNCH_PARAM_END
};
int
bdx
=
256
;
int
gdx
=
fmha_ext_traits
.
s
/
fmha_ext_traits
.
ts_kv
;
int
gdy
=
fmha_ext_traits
.
h
;
int
gdz
=
fmha_ext_traits
.
b
;
if
((
fmha_ext_traits
.
mask
==
1
)
&&
(
fmha_ext_traits
.
mask_kb
==
1
))
{
int
num_tg
=
fmha_ext_traits
.
s
/
fmha_ext_traits
.
ts_kv
;
gdx
=
(
num_tg
%
2
)
?
(
num_tg
/
2
+
1
)
:
(
num_tg
/
2
);
}
HIP_CALL
(
hipModuleLaunchKernel
(
kernel_func
,
gdx
,
gdy
,
gdz
,
bdx
,
1
,
1
,
0
,
stream
,
NULL
,
reinterpret_cast
<
void
**>
(
&
config
)));
}
private:
hipModule_t
module
;
hipFunction_t
kernel_func
;
};
int
fmha_bwd_ext
(
fmha_bwd_ext_traits
fmha_ext_traits
,
fmha_bwd_asm_args
args
,
hipStream_t
stream
);
// int fmha_bwd_ext(fmha_bwd_ext_traits fmha_ext_traits);
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment