Commit 3efb8621 authored by danyao12

tmp save

parent d4139c8b
@@ -163,6 +163,146 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
FMHA_BWD_API="""
#include <iostream>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include "hsaco/fmha_hsaco.h"
#define HSA_KERNEL "kernel_func"
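// "kernel_func" is the symbol looked up in the embedded HSACO module below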
#define HIP_CALL(call) \\
do \\
{{ \\
hipError_t err = call; \\
if(err != hipSuccess) \\
{{ \\
printf("[hiperror](%d) fail to call %s", static_cast<int>(err), #call); \\
exit(0); \\
}} \\
}} while(0)
// declare the function extern here since the hip/hip_ext.h header is broken
extern hipError_t hipExtModuleLaunchKernel(hipFunction_t, // NOLINT
uint32_t,
uint32_t,
uint32_t,
uint32_t,
uint32_t,
uint32_t,
size_t,
hipStream_t,
void**,
void**,
hipEvent_t = nullptr,
hipEvent_t = nullptr,
uint32_t = 0);
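// p3/p2 are pure padding blocks: in the packed argument struct below, each
// real argument (8-byte pointer or 4-byte scalar) is padded out to a 16-byte
// slot, which appears to be the layout the asm kernel expects.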
struct p3
{{
unsigned int _p0;
unsigned int _p1;
unsigned int _p2;
}};
struct p2
{{
unsigned int _p0;
unsigned int _p1;
}};
struct __attribute__((packed)) fmha_bwd_asm_args
{{
void* ptr_dq;
p2 _p0;
void* ptr_dk;
p2 _p1;
void* ptr_dv;
p2 _p2;
const void* ptr_q;
p2 _p3;
const void* ptr_k;
p2 _p4;
const void* ptr_v;
p2 _p5;
const void* ptr_do;
p2 _p6;
const void* ptr_lse;
p2 _p7;
const void* ptr_d;
p2 _p8;
float scalar;
p3 _p9;
float log2e;
p3 _p10;
unsigned int seq_len;
p3 _p11;
unsigned int Ts;
p3 _p12;
unsigned int Hs;
p3 _p13;
unsigned int BAs;
p3 _p14;
}};
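// launch-shape parameters for the asm backward kernel: batch, heads, seqlen,
// head dim, an fp32-atomics flag, the mask mode, and the q/o and k/v tile
// sizes (field meanings inferred from how they are used below).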
struct fmha_bwd_ext_traits
{{
int b;
int h;
int s;
int d;
int atm_f32;
int mask;
int ts_qo;
int ts_kv;
}};
std::string hip_error(int error) {{ return hipGetErrorString(static_cast<hipError_t>(error)); }}
class fmha_bwd_ext_kernel
{{
public:
fmha_bwd_ext_kernel(const std::string& name, unsigned char buffer[])
{{
// HIP_CALL(hipModuleLoadData(&module, buffer));
auto status = hipModuleLoadData(&module, buffer);
if(status != hipSuccess)
throw std::runtime_error("Failed to load module: " + hip_error(status));
HIP_CALL(hipModuleGetFunction(&kernel_func, module, name.c_str()));
}}
void
launch_kernel(fmha_bwd_ext_traits fmha_ext_traits, fmha_bwd_asm_args args, const ck_tile::stream_config& s) const
{{
size_t arg_size = sizeof(args);
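        // pass the packed argument struct through HIP's HIP_LAUNCH_PARAM_*
        // extra launch-parameter protocol rather than an individual kernarg list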
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
&args,
HIP_LAUNCH_PARAM_BUFFER_SIZE,
&arg_size,
HIP_LAUNCH_PARAM_END}};
int bdx = 256;
int gdx = fmha_ext_traits.s / fmha_ext_traits.ts_kv;
int gdy = fmha_ext_traits.h;
int gdz = fmha_ext_traits.b;
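        // with a causal mask the grid is halved (rounded up); presumably each
        // workgroup then covers two k/v tiles, since roughly half carry no work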
if(fmha_ext_traits.mask > 0)
{{
int num_tg = fmha_ext_traits.s / fmha_ext_traits.ts_kv;
gdx = (num_tg % 2) ? (num_tg / 2 + 1) : (num_tg / 2);
}}
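        // note: hipExtModuleLaunchKernel is declared above, but the launch
        // itself goes through the plain hipModuleLaunchKernel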
HIP_CALL(hipModuleLaunchKernel(kernel_func,
gdx,
gdy,
gdz,
bdx,
1,
1,
0,
s.stream_id_,
NULL,
reinterpret_cast<void**>(&config)));
}}
private:
hipModule_t module;
hipFunction_t kernel_func;
}};
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_>
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
@@ -176,8 +316,83 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
    );
}}
template <typename dot_do_o_trait_, typename convert_dq_trait_>
float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_ext_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
fmha_bwd_asm_args args;
args.ptr_dq = a.dq_acc_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
args.ptr_q = a.q_ptr;
args.ptr_k = a.k_ptr;
args.ptr_v = a.v_ptr;
args.ptr_do = a.do_ptr;
args.ptr_lse = a.lse_ptr;
args.ptr_d = a.d_ptr;
args.scalar = a.scale;
args.log2e = ck_tile::log2e_v<float>;
args.seq_len = a.seqlen_q;
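    // byte strides assuming 2-byte (fp16/bf16) elements: Ts = one 128-row k/v
    // tile, Hs = one head, BAs = one batch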
args.Ts = 128 * a.hdim_q * 2;
args.Hs = a.seqlen_q * a.hdim_q * 2;
args.BAs = a.nhead_q * a.seqlen_q * a.hdim_q * 2;
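    // atm_f32 = 1 and the 32/128 (q/o and k/v) tile sizes are hardcoded to
    // match the prebuilt HSACO kernels passed in as bwd_ext_asm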
auto traits = fmha_bwd_ext_traits{{a.batch,
a.nhead_q,
a.seqlen_q,
a.hdim_q,
1,
a.mask_type,
32,
128}};
HIP_CALL(hipSetDevice(0));
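    // note: device 0 is forced above and the HSACO module is reloaded on
    // every call; a real integration would likely cache the loaded module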
fmha_bwd_ext_kernel impl(HSA_KERNEL, bwd_ext_asm);
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
);
}}
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
    float r = -1;
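    // fast path: route to the hand-written asm backward kernel when the
    // problem fits its constraints (batch mode, no bias/dbias/dropout,
    // seqlen_q == seqlen_k divisible by 128, hdim 128, non-deterministic)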
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
(a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false)) {{
if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_a32";
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_ext_name);
return r;
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_causal_a32";
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_ext_name);
return r;
}}
}}
else if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_a32";
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32, bwd_ext_name);
return r;
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a32";
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32, bwd_ext_name);
return r;
}}
}}
}}
{F_dispatch}
    return r;
}}
@@ -451,14 +666,14 @@ class FmhaBwdDQDKDVKernel:
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
    if dtype == 'fp16' or dtype == 'bf16':
        return {
-            '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
-                    "kr_ktr_vr_iglp", "kr_ktr_vr"],
-            '64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
-                    "kr_ktr_vr_iglp", "kr_ktr_vr"],
+            # '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
+            #         "kr_ktr_vr_iglp", "kr_ktr_vr"],
+            # '64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
+            #         "kr_ktr_vr_iglp", "kr_ktr_vr"],
             '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
                      "kr_ktr_vr_iglp", "kr_ktr_vr"],
-            '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
-                     "kr_ktr_vr_iglp", "kr_ktr_vr"]
+            # '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
+            #          "kr_ktr_vr_iglp", "kr_ktr_vr"]
        }
    else:
        return None
@@ -501,7 +716,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
            continue
        if receipt == 3:
            cond = dtype in ['fp16', 'bf16']
-            cond &= bias in ['no', 'alibi']
+            cond &= bias in ['no']
            cond &= dpad == dvpad
            cond &= deterministic == "f"
            if not cond:
...
@@ -2,7 +2,6 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_bwd.hpp"
#include "fmha_bwd_ext.hpp"
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "mask.hpp" #include "mask.hpp"
#include "utils.hpp" #include "utils.hpp"
@@ -135,7 +134,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
        hdim_v = hdim_q;
    bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
-    bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead* hdim
+    bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
    float scale = arg_parser.get_float("scale");
    if(scale == .0f)
@@ -211,7 +210,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
    using BiasGradDataType = typename TypeConfig::BiasGradDataType;
    // accumulation numbers for performance evaluation
-    // std::size_t flop = 0, num_byte = 0;
+    std::size_t flop = 0, num_byte = 0;
    auto max_seqlen_q =
        std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
    auto max_seqlen_k =
@@ -232,20 +231,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
            max_seqlen_k = real_seqlen_k;
        }
-            // flop += nhead * (static_cast<std::size_t>(3) * static_cast<std::size_t>(2) *
-            //                      real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T
-            //                  static_cast<std::size_t>(2) * static_cast<std::size_t>(2) *
-            //                      real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T
-            // num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
-            //                      sizeof(KDataType) * real_seqlen_k * hdim_q +
-            //                      sizeof(VDataType) * real_seqlen_k * hdim_v +
-            //                      sizeof(ODataType) * real_seqlen_q * hdim_v +
-            //                      sizeof(OGradDataType) * real_seqlen_q * hdim_v +
-            //                      sizeof(QGradDataType) * real_seqlen_q * hdim_q +
-            //                      sizeof(KGradDataType) * real_seqlen_k * hdim_q +
-            //                      sizeof(VGradDataType) * real_seqlen_k * hdim_v +
-            //                      sizeof(LSEDataType) * real_seqlen_q);
+            flop += nhead * (static_cast<std::size_t>(3) * static_cast<std::size_t>(2) *
+                                 real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T
+                             static_cast<std::size_t>(2) * static_cast<std::size_t>(2) *
+                                 real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T
+            num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
+                                 sizeof(KDataType) * real_seqlen_k * hdim_q +
+                                 sizeof(VDataType) * real_seqlen_k * hdim_v +
+                                 sizeof(ODataType) * real_seqlen_q * hdim_v +
+                                 sizeof(OGradDataType) * real_seqlen_q * hdim_v +
+                                 sizeof(QGradDataType) * real_seqlen_q * hdim_q +
+                                 sizeof(KGradDataType) * real_seqlen_k * hdim_q +
+                                 sizeof(VGradDataType) * real_seqlen_k * hdim_v +
+                                 sizeof(LSEDataType) * real_seqlen_q);
        }
    }
@@ -460,8 +459,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
        const ck_tile::index_t split_stride_dq_acc =
            (shape_batch * nhead * shape_seqlen_q * hdim_q);
-        return fmha_bwd_args{
-            q_buf.GetDeviceBuffer(),
+        return fmha_bwd_args{q_buf.GetDeviceBuffer(),
            k_buf.GetDeviceBuffer(),
            v_buf.GetDeviceBuffer(),
            bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
@@ -469,7 +467,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
            o_buf.GetDeviceBuffer(),
            lse_buf.GetDeviceBuffer(),
            do_buf.GetDeviceBuffer(),
-            d_buf.GetDeviceBuffer(), // need the dot_do_o kernel to generate d_buf (corresponds to niels' odo buffer)
+            d_buf.GetDeviceBuffer(),
            randval_buf.GetDeviceBuffer(),
            dq_buf.GetDeviceBuffer(),
            dk_buf.GetDeviceBuffer(),
@@ -492,7 +490,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
            stride_q,
            stride_k,
            stride_v,
-            bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias,
+            bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
+                                          : stride_bias,
            stride_o,
            stride_randval,
            stride_do,
@@ -536,92 +535,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
            {drop_seed, drop_offset}};
    }();
-    float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config, 0);
+    float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
    if(ave_time < 0)
    {
        std::cout << ", not supported yet" << std::flush << std::endl;
        return false;
    }
-    int atm_f32     = 1;
-    int skip_dq_rd  = 1;
-    int mask_asm    = 1;
-    int mask_kb_asm = 0;
-    int ts_qo       = 32;
-    int ts_kv       = 128;
-    int dump_result = 0;
-    auto fmha_ext_traits = fmha_bwd_ext_traits{batch,
-                                               nhead,
-                                               seqlen_q,
-                                               hdim_q,
-                                               atm_f32,
-                                               skip_dq_rd,
-                                               mask_asm,
-                                               mask_kb_asm,
-                                               ts_qo,
-                                               ts_kv,
-                                               dump_result};
-    int stride_tg    = ts_kv * hdim_q * 2;
-    int stride_head  = seqlen_q * hdim_q * 2;
-    int stride_batch = nhead * seqlen_q * hdim_q * 2;
-    float k_log2e  = log2f(expf(1));
-    float k_scalar = sqrt(hdim_q);
-    k_scalar       = static_cast<float>(1.0 / static_cast<double>(k_scalar));
-#ifdef ASM_PRINT
-    // debug pointer
-    float *host_print, *print;
-    host_print = (float*)malloc(bdx * 8);
-    HIP_CALL(hipMalloc(&print, bdx * 8));
-#endif
-    fmha_bwd_asm_args args;
-    args.ptr_dq = dq_acc_buf.GetDeviceBuffer(); // dev_dq;
-    // args.ptr_dq = dq_buf.GetDeviceBuffer(); // dev_dq;
-    args.ptr_dk = dk_buf.GetDeviceBuffer(); // dev_dk;
-    args.ptr_dv = dv_buf.GetDeviceBuffer(); // dev_dv;
-    args.ptr_q = q_buf.GetDeviceBuffer(); // dev_q;
-    args.ptr_k = k_buf.GetDeviceBuffer(); // dev_k;
-    args.ptr_v = v_buf.GetDeviceBuffer(); // dev_v;
-    args.ptr_do = do_buf.GetDeviceBuffer(); // dev_do;
-    args.ptr_lse = lse_buf.GetDeviceBuffer(); // dev_lse;
-    args.ptr_odo = d_buf.GetDeviceBuffer(); // dev_odo;
-    args.scalar  = k_scalar;
-    args.log2e   = k_log2e;
-    args.seq_len = seqlen_q;
-    args.Ts      = stride_tg;
-    args.Hs      = stride_head;
-    args.BAs     = stride_batch;
-#ifdef ASM_PRINT
-    args.print = (void*)print;
-#endif
-    hipStream_t stream_ext = nullptr;
-    fmha_bwd_ext(fmha_ext_traits, args, stream_ext);
-    fmha_bwd(fmha_traits, fmha_args, stream_config, 1);
-    // if((atm_f32 == 1) || ((!skip_dq_rd) && (atm_f32 == 2)))
-    //     HIP_CALL(
-    //         hipMemcpy(host_fp32_dq, dev_dq, sz_mx_dq * sizeof(float), hipMemcpyDeviceToHost));
-    // else
-    //     HIP_CALL(
-    //         hipMemcpy(host_fp16_dq, dev_dq, sz_mx * sizeof(float) / 2, hipMemcpyDeviceToHost));
-    // ;
-    // HIP_CALL(hipMemcpy(host_fp16_dk, dev_dk, sz_mx * sizeof(float) / 2, hipMemcpyDeviceToHost));
-    // HIP_CALL(hipMemcpy(host_fp16_dv, dev_dv, sz_mx * sizeof(float) / 2, hipMemcpyDeviceToHost));
-    // float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
-    // float gb_per_sec = num_byte / 1.E6 / ave_time;
-    // std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, "
-    //           << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
-    //           << " GB/s" << std::flush;
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+    float gb_per_sec = num_byte / 1.E6 / ave_time;
+    std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, "
+              << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
+              << " GB/s" << std::flush;
    if(!do_validation)
    {
@@ -845,12 +772,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
        ck_tile::stream_config stream_config_v{
            nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
-        // just for the odo buffer
-        fmha_bwd(fmha_traits, fmha_args, stream_config_v, 0);
-        fmha_bwd_ext(fmha_ext_traits, args, stream_ext);
-        fmha_bwd(fmha_traits, fmha_args, stream_config_v, 1);
+        fmha_bwd(fmha_traits, fmha_args, stream_config_v);
        dq_buf.FromDevice(dq_host.data());
        dk_buf.FromDevice(dk_host.data());
@@ -1017,6 +939,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
    }
    std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    return pass;
}
@@ -1031,10 +954,10 @@ int main(int argc, char* argv[])
    {
        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
    }
-    // else if(data_type == "bf16")
-    // {
-    //     return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
-    // }
+    else if(data_type == "bf16")
+    {
+        return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
+    }
    return -3;
}
@@ -440,4 +440,4 @@ struct fmha_bwd_traits
    bool is_deterministic;
    // TODO: padding check is inside this api
};
-float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&, int flag);
+float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);