Commit 41631747 authored by Ye Wang
Browse files

Add reproducer for dbias with broadcast shape 1hss (batch dim of 1): enable the `bias` case in the receipt-3 kernel filter, size the bias/dbias host tensors and their strides for the [1, nhead, seqlen_q, seqlen_k] layout (batch_stride_dbias = 0), and accumulate the dbias reference across batches before a single global check.

parent 0c9012fb
......@@ -501,7 +501,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue
if receipt == 3:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
cond &= bias in ['no', 'bias', 'alibi']
cond &= dpad == dvpad
cond &= deterministic == "f"
if not cond:
......@@ -801,4 +801,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
\ No newline at end of file
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
......@@ -295,7 +295,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v));
ck_tile::HostTensor<BiasDataType> bias_host(
bias.type == bias_enum::elementwise_bias
? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k)
? get_lengths(i_perm, 1, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> alibi_slope_host(
bias.type == bias_enum::alibi
......@@ -321,7 +321,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<BiasGradDataType> dbias_host(
use_dbias
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
? get_lengths(i_perm, 1, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> dq_acc_host(
i_perm
......@@ -448,7 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = (i_perm ? hdim_v : nhead_k * hdim_v);
const ck_tile::index_t stride_bias = (max_seqlen_k);
const ck_tile::index_t stride_bias = i_perm? shape_seqlen_k:nhead*shape_seqlen_k;
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v);
......@@ -459,7 +459,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_bias = 0;
const ck_tile::index_t nhead_stride_bias = i_perm? shape_seqlen_q * shape_seqlen_k: shape_seqlen_k;
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
......@@ -477,7 +477,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_dbias = 0;
const ck_tile::index_t split_stride_dq_acc =
(shape_batch * nhead * shape_seqlen_q * hdim_q);
......@@ -657,12 +657,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(bias.type == bias_enum::elementwise_bias)
{
// elementwise bias
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
ck_tile::HostTensor<BiasDataType> bias_host_ref({nhead, real_seqlen_q, real_seqlen_k});
// clang-format off
if(i_perm)
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); });
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[0], i[1] + query_offset, i[2]); });
else
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); });
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, i[0], i[2]); });
// clang-format on
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
......@@ -813,6 +813,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
dv_buf.FromDevice(dv_host.data());
dbias_buf.FromDevice(dbias_host.data());
ck_tile::HostTensor<BiasGradDataType> dbias_host_result(
{nhead, max_seqlen_q, max_seqlen_k}); // dbias_g_m_n
if(use_dbias){
if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(0, idx[0], idx[1], idx[2]); });
else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(0, idx[1], idx[0], idx[2]); });
}
ck_tile::HostTensor<BiasGradDataType> dbias_host_ref(
{nhead, max_seqlen_q, max_seqlen_k}); // dbias_g_m_n
dbias_host_ref.SetZero();
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
......@@ -830,8 +840,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
{nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision
ck_tile::HostTensor<AccDataType> dp_hp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision
ck_tile::HostTensor<BiasGradDataType> dbias_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
ck_tile::HostTensor<QGradDataType> dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
ck_tile::HostTensor<KGradDataType> dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
ck_tile::HostTensor<VGradDataType> dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
......@@ -870,7 +878,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(use_dbias)
{
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
dbias_host_ref(idx) = ck_tile::type_convert<BiasGradDataType>(self(idx));
dbias_host_ref(idx) += ck_tile::type_convert<BiasGradDataType>(self(idx));
});
}
......@@ -912,8 +920,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
{nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
ck_tile::HostTensor<VGradDataType> dv_host_result(
{nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
ck_tile::HostTensor<BiasGradDataType> dbias_host_result(
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
// clang-format off
// permute
......@@ -926,11 +932,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(i_perm) dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[0], idx[1] + key_offset, idx[2]); });
else dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[1] + key_offset, idx[0], idx[2]); });
if(use_dbias)
{
if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2]); });
else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2]); });
}
// clang-format on
auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
......@@ -950,17 +951,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
rtol,
atol);
bool dbias_cur_pass = true;
if(use_dbias)
{
dbias_cur_pass = ck_tile::check_err(dbias_host_result,
dbias_host_ref,
std::string("Error: BiasGrad Incorrect results!"),
rtol,
atol);
}
pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass);
if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass))
pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass);
if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass))
{
std::cerr << "mismatch found at batch: " << wb << std::endl
<< "\tseqlen_q: " << real_seqlen_q << std::endl
......@@ -971,6 +963,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
break;
}
}
bool dbias_cur_pass = true;
auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
if(use_dbias)
{
dbias_cur_pass = ck_tile::check_err(dbias_host_result,
dbias_host_ref,
std::string("Error: BiasGrad Incorrect results!"),
rtol,
atol);
}
pass &= dbias_cur_pass;
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment