Commit d4139c8b authored by Fang.Che

add fmha asm api: fmha_bwd_ext

parent 933ac7c7
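In short, this change embeds a hand-written assembly FMHA backward kernel (shipped as hsaco code objects) behind a new `fmha_bwd_ext(traits, args, stream)` entry point, and gives the existing `fmha_bwd` dispatcher a trailing `int flag` so the example can run it in two phases around the asm kernel. The sketch below only condenses the call order used in `fmha_bwd.cpp` in this diff; the helper name is hypothetical, and the flag meanings are inferred from the example's own comments rather than stated anywhere in the commit.

```cpp
// Minimal sketch of the call sequence this commit sets up (not part of the commit itself).
float run_bwd_with_asm_kernel(fmha_bwd_traits traits, fmha_bwd_args args,
                              fmha_bwd_ext_traits ext_traits, fmha_bwd_asm_args ext_args,
                              const ck_tile::stream_config& cfg, hipStream_t stream)
{
    // flag 0: per the example's comments, this pass is needed to fill d_buf,
    // the dO*O ("odo") buffer consumed by the asm kernel
    fmha_bwd(traits, args, cfg, 0);

    // the asm kernel loaded from the embedded hsaco image computes the gradients
    fmha_bwd_ext(ext_traits, ext_args, stream);

    // flag 1: run whatever remaining part of the tile pipeline the flag selects
    return fmha_bwd(traits, args, cfg, 1);
}
```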
CMakeLists.txt

@@ -58,7 +58,8 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
 # not using add_example_executable() to add this target, since we don't want this to have
 # to be included in "make all/install/check"
 message("adding example ${EXAMPLE_FMHA_BWD}")
-add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp)
+add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_arg.cpp hsaco/bwd_bf16_a32.cpp hsaco/bwd_bf16_causal_a32.cpp hsaco/bwd_bf16_nocoex_a32.cpp hsaco/bwd_bf16_nocoex_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_nocoex_a32.cpp hsaco/bwd_fp16_nocoex_causal_a32.cpp fmha_bwd_ext.cpp fmha_bwd.cpp)
 target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
 target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
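The hsaco/*.cpp files added to the executable above are presumably translation units that embed the precompiled assembly code objects as byte arrays, declared in `hsaco/fmha_hsaco.h` and loaded at run time with `hipModuleLoadData` (the kernel constructor used further down takes `bwd_fp16_a32` as an `unsigned char buffer[]`). A guess at their shape, for orientation only:

```cpp
// hsaco/fmha_hsaco.h -- assumed layout; the actual header is not shown in this diff.
// One embedded code object per kernel variant (fp16/bf16, causal, nocoex, a16/a32).
extern unsigned char bwd_fp16_a32[];

// hsaco/bwd_fp16_a32.cpp -- assumed layout
unsigned char bwd_fp16_a32[] = {
    0x7f, 0x45, 0x4c, 0x46, /* ...remaining bytes of the .hsaco ELF code object... */
};
```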
fmha_bwd.cpp

@@ -2,6 +2,7 @@
 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #include "fmha_bwd.hpp"
+#include "fmha_bwd_ext.hpp"
 #include "ck_tile/host.hpp"
 #include "mask.hpp"
 #include "utils.hpp"
@@ -134,7 +135,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     hdim_v = hdim_q;
     bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
     bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
     float scale = arg_parser.get_float("scale");
     if(scale == .0f)
@@ -210,7 +211,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using BiasGradDataType = typename TypeConfig::BiasGradDataType;

     // accumulation numbers for performance evaluation
-    std::size_t flop = 0, num_byte = 0;
+    // std::size_t flop = 0, num_byte = 0;
     auto max_seqlen_q =
         std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
     auto max_seqlen_k =
@@ -231,20 +232,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
             max_seqlen_k = real_seqlen_k;
         }

-        flop += nhead * (static_cast<std::size_t>(3) * static_cast<std::size_t>(2) *
-                             real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T
-                         static_cast<std::size_t>(2) * static_cast<std::size_t>(2) *
-                             real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T
-
-        num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
-                             sizeof(KDataType) * real_seqlen_k * hdim_q +
-                             sizeof(VDataType) * real_seqlen_k * hdim_v +
-                             sizeof(ODataType) * real_seqlen_q * hdim_v +
-                             sizeof(OGradDataType) * real_seqlen_q * hdim_v +
-                             sizeof(QGradDataType) * real_seqlen_q * hdim_q +
-                             sizeof(KGradDataType) * real_seqlen_k * hdim_q +
-                             sizeof(VGradDataType) * real_seqlen_k * hdim_v +
-                             sizeof(LSEDataType) * real_seqlen_q);
+        // flop += nhead * (static_cast<std::size_t>(3) * static_cast<std::size_t>(2) *
+        //                      real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T
+        //                  static_cast<std::size_t>(2) * static_cast<std::size_t>(2) *
+        //                      real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T
+
+        // num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
+        //                      sizeof(KDataType) * real_seqlen_k * hdim_q +
+        //                      sizeof(VDataType) * real_seqlen_k * hdim_v +
+        //                      sizeof(ODataType) * real_seqlen_q * hdim_v +
+        //                      sizeof(OGradDataType) * real_seqlen_q * hdim_v +
+        //                      sizeof(QGradDataType) * real_seqlen_q * hdim_q +
+        //                      sizeof(KGradDataType) * real_seqlen_k * hdim_q +
+        //                      sizeof(VGradDataType) * real_seqlen_k * hdim_v +
+        //                      sizeof(LSEDataType) * real_seqlen_q);
         }
     }
@@ -459,96 +460,168 @@ bool run(const ck_tile::ArgParser& arg_parser)
         const ck_tile::index_t split_stride_dq_acc =
             (shape_batch * nhead * shape_seqlen_q * hdim_q);

         return fmha_bwd_args{
             q_buf.GetDeviceBuffer(),
             k_buf.GetDeviceBuffer(),
             v_buf.GetDeviceBuffer(),
             bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
                                           : bias_buf.GetDeviceBuffer(),
             o_buf.GetDeviceBuffer(),
             lse_buf.GetDeviceBuffer(),
             do_buf.GetDeviceBuffer(),
-            d_buf.GetDeviceBuffer(),
+            d_buf.GetDeviceBuffer(), // the dot_do_o kernel must be used to generate d_buf (corresponds to niels' odo buffer)
             randval_buf.GetDeviceBuffer(),
             dq_buf.GetDeviceBuffer(),
             dk_buf.GetDeviceBuffer(),
             dv_buf.GetDeviceBuffer(),
             dbias_buf.GetDeviceBuffer(),
             dq_acc_buf.GetDeviceBuffer(),
             seqstart_q.GetDeviceBuffer(),
             seqstart_k.GetDeviceBuffer(),
             nullptr,
             shape_seqlen_q,
             shape_seqlen_k,
             batch,
             max_seqlen_q,
             max_seqlen_k,
             hdim_q,
             hdim_v,
             nhead,
             nhead_k,
             scale,
             stride_q,
             stride_k,
             stride_v,
             bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias,
             stride_o,
             stride_randval,
             stride_do,
             stride_q, // stride_dq_acc
             stride_q, // stride_dq
             stride_dk,
             stride_dv,
             stride_dbias,
             nhead_stride_q,
             nhead_stride_k,
             nhead_stride_v,
             nhead_stride_bias,
             nhead_stride_o,
             nhead_stride_randval,
             nhead_stride_do,
             nhead_stride_lsed,
             nhead_stride_q, // nhead_stride_dq_acc
             nhead_stride_q, // nhead_stride_dq
             nhead_stride_k, // nhead_stride_dk
             nhead_stride_v, // nhead_stride_dv
             nhead_stride_dbias,
             batch_stride_q,
             batch_stride_k,
             batch_stride_v,
             batch_stride_bias,
             batch_stride_o,
             batch_stride_randval,
             batch_stride_do,
             batch_stride_lsed,
             batch_stride_q, // batch_stride_dq_acc
             batch_stride_q, // batch_stride_dq
             batch_stride_dk,
             batch_stride_dv,
             batch_stride_dbias,
             split_stride_dq_acc,
             mask.left,
             mask.right,
             static_cast<ck_tile::index_t>(mask.type),
             p_drop,
             p_undrop,
             {drop_seed, drop_offset}};
     }();

-    float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
+    float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config, 0);

     if(ave_time < 0)
     {
         std::cout << ", not supported yet" << std::flush << std::endl;
         return false;
     }
-    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
-
-    float gb_per_sec = num_byte / 1.E6 / ave_time;
-
-    std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, "
-              << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
-              << " GB/s" << std::flush;
+    int atm_f32     = 1;
+    int skip_dq_rd  = 1;
+    int mask_asm    = 1;
+    int mask_kb_asm = 0;
+
+    int ts_qo       = 32;
+    int ts_kv       = 128;
+    int dump_result = 0;
+    auto fmha_ext_traits = fmha_bwd_ext_traits{batch,
+                                               nhead,
+                                               seqlen_q,
+                                               hdim_q,
+                                               atm_f32,
+                                               skip_dq_rd,
+                                               mask_asm,
+                                               mask_kb_asm,
+                                               ts_qo,
+                                               ts_kv,
+                                               dump_result};
+    int stride_tg    = ts_kv * hdim_q * 2;
+    int stride_head  = seqlen_q * hdim_q * 2;
+    int stride_batch = nhead * seqlen_q * hdim_q * 2;
+    float k_log2e    = log2f(expf(1));
+    float k_scalar   = sqrt(hdim_q);
+    k_scalar         = static_cast<float>(1.0 / static_cast<double>(k_scalar));
+#ifdef ASM_PRINT
+    // debug pointer
+    float *host_print, *print;
+    host_print = (float*)malloc(bdx * 8);
+    HIP_CALL(hipMalloc(&print, bdx * 8));
+#endif
+    fmha_bwd_asm_args args;
+    args.ptr_dq = dq_acc_buf.GetDeviceBuffer(); // dev_dq;
+    // args.ptr_dq = dq_buf.GetDeviceBuffer(); // dev_dq;
+    args.ptr_dk  = dk_buf.GetDeviceBuffer();  // dev_dk;
+    args.ptr_dv  = dv_buf.GetDeviceBuffer();  // dev_dv;
+    args.ptr_q   = q_buf.GetDeviceBuffer();   // dev_q;
+    args.ptr_k   = k_buf.GetDeviceBuffer();   // dev_k;
+    args.ptr_v   = v_buf.GetDeviceBuffer();   // dev_v;
+    args.ptr_do  = do_buf.GetDeviceBuffer();  // dev_do;
+    args.ptr_lse = lse_buf.GetDeviceBuffer(); // dev_lse;
+    args.ptr_odo = d_buf.GetDeviceBuffer();   // dev_odo;
+    args.scalar  = k_scalar;
+    args.log2e   = k_log2e;
+    args.seq_len = seqlen_q;
+    args.Ts      = stride_tg;
+    args.Hs      = stride_head;
+    args.BAs     = stride_batch;
+#ifdef ASM_PRINT
+    args.print = (void*)print;
+#endif
+    hipStream_t stream_ext = nullptr;
+    fmha_bwd_ext(fmha_ext_traits, args, stream_ext);
+    fmha_bwd(fmha_traits, fmha_args, stream_config, 1);
+    // if((atm_f32 == 1) || ((!skip_dq_rd) && (atm_f32 == 2)))
+    //     HIP_CALL(
+    //         hipMemcpy(host_fp32_dq, dev_dq, sz_mx_dq * sizeof(float), hipMemcpyDeviceToHost));
+    // else
+    //     HIP_CALL(
+    //         hipMemcpy(host_fp16_dq, dev_dq, sz_mx * sizeof(float) / 2, hipMemcpyDeviceToHost));
+    // HIP_CALL(hipMemcpy(host_fp16_dk, dev_dk, sz_mx * sizeof(float) / 2, hipMemcpyDeviceToHost));
+    // HIP_CALL(hipMemcpy(host_fp16_dv, dev_dv, sz_mx * sizeof(float) / 2, hipMemcpyDeviceToHost));
+    // float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
+    // float gb_per_sec = num_byte / 1.E6 / ave_time;
+    // std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, "
+    //           << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2)
+    //           << gb_per_sec << " GB/s" << std::flush;

     if(!do_validation)
     {
@@ -772,7 +845,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
         ck_tile::stream_config stream_config_v{
             nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};

-        fmha_bwd(fmha_traits, fmha_args, stream_config_v);
+        // just for odo buffer
+        fmha_bwd(fmha_traits, fmha_args, stream_config_v, 0);
+        fmha_bwd_ext(fmha_ext_traits, args, stream_ext);
+        fmha_bwd(fmha_traits, fmha_args, stream_config_v, 1);

         dq_buf.FromDevice(dq_host.data());
         dk_buf.FromDevice(dk_host.data());
@@ -939,7 +1017,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
         }
         std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
         return pass;
 }
@@ -954,10 +1031,10 @@ int main(int argc, char* argv[])
     {
        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
     }
-    else if(data_type == "bf16")
-    {
-        return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
-    }
+    // else if(data_type == "bf16")
+    // {
+    //     return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
+    // }

     return -3;
 }
fmha_bwd.hpp

@@ -440,4 +440,4 @@ struct fmha_bwd_traits
     bool is_deterministic;
     // TODO: padding check is inside this api
 };
-float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
+float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&, int flag);
fmha_bwd_ext.cpp (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <stdio.h>
#include <vector>
#include <functional>
#include <iostream>
#include <string>
#include "fmha_bwd_ext.hpp"
#include "hsaco/fmha_hsaco.h"
int fmha_bwd_ext(fmha_bwd_ext_traits fmha_ext_traits, fmha_bwd_asm_args args, hipStream_t stream)
{
hipEvent_t evt_00, evt_11;
HIP_CALL(hipSetDevice(0));
// fmha_bwd_ext_kernel impl(HSACO, HSA_KERNEL);
fmha_bwd_ext_kernel impl(HSA_KERNEL, bwd_fp16_a32);
int b = fmha_ext_traits.b;
int h = fmha_ext_traits.h;
int s = fmha_ext_traits.s;
int d = fmha_ext_traits.d;
int atm_f32 = fmha_ext_traits.atm_f32;
int skip_dq_rd = fmha_ext_traits.skip_dq_rd;
// int ts_qo = fmha_ext_traits.ts_qo;
// int ts_kv = fmha_ext_traits.ts_kv;
int dump_result = fmha_ext_traits.dump_result;
std::cout << "b:" << b << std::endl;
std::cout << "h:" << h << std::endl;
std::cout << "s:" << s << std::endl;
std::cout << "d:" << d << std::endl;
std::cout << "dump_result:" << dump_result << std::endl;
std::cout << "atm_f32:" << atm_f32 << std::endl;
std::cout << "skip_dq_rd:" << skip_dq_rd << std::endl;
size_t arg_size = sizeof(args);
printf("argsize: %zu\n", arg_size);
#ifdef ASM_PRINT
int max_i = 256;
HIP_CALL(hipMemcpy(host_print, print, 8 * max_i, hipMemcpyDeviceToHost));
for(int i = 0; i < max_i; i++)
{
if(((uint32_t*)host_print)[2 * i + 1] != 0x5c005c00)
printf("Thread%d, PrintVal:0x%x\n",
((int*)host_print)[2 * i],
((uint32_t*)host_print)[2 * i + 1]);
// std::cout<<"Thread"<<((int*) host_print)[2*i]<<",
// PrintVal1:"<<(((float16*)host_print)[4*i+2])<<
//", PrintVal2:"<<( ( (float16*)host_print )[4*i+3] )<<std::endl;
}
#endif
HIP_CALL(hipEventCreate(&evt_00));
HIP_CALL(hipEventCreate(&evt_11));
HIP_CALL(hipDeviceSynchronize());
HIP_CALL(hipEventRecord(evt_00, NULL));
impl.launch_kernel(fmha_ext_traits, args, stream);
std::cout << "we are done" << std::endl;
float elapsed_ms;
HIP_CALL(hipEventRecord(evt_11, NULL));
HIP_CALL(hipEventSynchronize(evt_11));
HIP_CALL(hipDeviceSynchronize());
HIP_CALL(hipEventElapsedTime(&elapsed_ms, evt_00, evt_11));
HIP_CALL(hipEventDestroy(evt_00));
HIP_CALL(hipEventDestroy(evt_11));
// float time_per_loop = elapsed_ms / total_loop;
// float gflops = static_cast<float>(2.0) * 5 * b * h * d * s * s / time_per_loop / (1e6);
// printf("b:%d,h:%d,s:%d,d:%d, time: %.3f, gflops:%.3f\n", b, h, s, d, time_per_loop, gflops);
printf("b:%d,h:%d,s:%d,d:%d\n", b, h, s, d);
#ifdef ASM_PRINT
free(host_print);
HIP_CALL(hipFree(print));
#endif
// printf("CU:%d, TIPS:%.3f(2x:%.3f, 4x:%.3f), cost:%fms per loop\n", num_cu, tips, 2*tips,
// 4*tips, time_per_loop);
return 0;
}
fmha_bwd_ext.hpp (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#define HSACO "kernel.co"
#define HSA_KERNEL "kernel_func"
struct p3
{
unsigned int _p0;
unsigned int _p1;
unsigned int _p2;
};
struct p2
{
unsigned int _p0;
unsigned int _p1;
};
struct __attribute__((packed)) fmha_bwd_asm_args
{
void* ptr_dq;
p2 _p0;
void* ptr_dk;
p2 _p1;
void* ptr_dv;
p2 _p2;
void* ptr_q;
p2 _p3;
void* ptr_k;
p2 _p4;
void* ptr_v;
p2 _p5;
void* ptr_do;
p2 _p6;
void* ptr_lse;
p2 _p7;
void* ptr_odo;
p2 _p8;
float scalar;
p3 _p9;
float log2e;
p3 _p10;
unsigned int seq_len;
p3 _p11;
unsigned int Ts;
p3 _p12;
unsigned int Hs;
p3 _p13;
unsigned int BAs;
p3 _p14;
#ifdef ASM_PRINT
void* print;
#endif
};
#define HIP_CALL(call) \
do \
{ \
hipError_t err = call; \
if(err != hipSuccess) \
{ \
printf("[hiperror](%d) fail to call %s", static_cast<int>(err), #call); \
exit(0); \
} \
} while(0)
struct fmha_bwd_ext_traits
{
int b;
int h;
int s;
int d;
int atm_f32;
int skip_dq_rd;
int mask;
int mask_kb;
int ts_qo;
int ts_kv;
int dump_result;
};
class fmha_bwd_ext_kernel
{
public:
// fmha_bwd_ext_kernel(const char* image, const std::string& name)
// {
// HIP_CALL(hipModuleLoad(&module, image));
// HIP_CALL(hipModuleGetFunction(&kernel_func, module, name.c_str()));
// }
fmha_bwd_ext_kernel(const std::string& name, unsigned char buffer[])
{
HIP_CALL(hipModuleLoadData(&module, buffer));
HIP_CALL(hipModuleGetFunction(&kernel_func, module, name.c_str()));
}
void
launch_kernel(fmha_bwd_ext_traits fmha_ext_traits, fmha_bwd_asm_args args, hipStream_t stream)
{
size_t arg_size = sizeof(args);
void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER,
&args,
HIP_LAUNCH_PARAM_BUFFER_SIZE,
&arg_size,
HIP_LAUNCH_PARAM_END};
int bdx = 256;
int gdx = fmha_ext_traits.s / fmha_ext_traits.ts_kv;
int gdy = fmha_ext_traits.h;
int gdz = fmha_ext_traits.b;
if((fmha_ext_traits.mask == 1) && (fmha_ext_traits.mask_kb == 1))
{
int num_tg = fmha_ext_traits.s / fmha_ext_traits.ts_kv;
gdx = (num_tg % 2) ? (num_tg / 2 + 1) : (num_tg / 2);
}
HIP_CALL(hipModuleLaunchKernel(kernel_func,
gdx,
gdy,
gdz,
bdx,
1,
1,
0,
stream,
NULL,
reinterpret_cast<void**>(&config)));
}
private:
hipModule_t module;
hipFunction_t kernel_func;
};
int fmha_bwd_ext(fmha_bwd_ext_traits fmha_ext_traits, fmha_bwd_asm_args args, hipStream_t stream);
// int fmha_bwd_ext(fmha_bwd_ext_traits fmha_ext_traits);
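One detail worth calling out in the header above: the `p2`/`p3` filler members pad every field of `fmha_bwd_asm_args` out to a 16-byte slot, which is presumably the argument-buffer layout the assembly kernel expects when the struct is passed through `HIP_LAUNCH_PARAM_BUFFER_POINTER`. A small compile-time check of that layout, as a sketch only (assumes a 64-bit build with `ASM_PRINT` undefined):

```cpp
#include <cstddef> // offsetof
#include "fmha_bwd_ext.hpp"

// Each pointer (8 B) + p2 (8 B) and each 4-byte scalar + p3 (12 B) fills one
// 16-byte slot: 9 pointer slots followed by 6 scalar slots = 240 bytes total.
static_assert(sizeof(p2) == 8 && sizeof(p3) == 12, "unexpected filler sizes");
static_assert(offsetof(fmha_bwd_asm_args, scalar) == 9 * 16,
              "scalar arguments should start after the 9 pointer slots");
#ifndef ASM_PRINT
static_assert(sizeof(fmha_bwd_asm_args) == 15 * 16,
              "packed asm argument block should be 240 bytes");
#endif
```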