Commit 41631747 authored by Ye Wang
Browse files

reproducer for dbias in shape 1hss

parent 0c9012fb
...@@ -501,7 +501,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -501,7 +501,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue continue
if receipt == 3: if receipt == 3:
cond = dtype in ['fp16', 'bf16'] cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi'] cond &= bias in ['no', 'bias', 'alibi']
cond &= dpad == dvpad cond &= dpad == dvpad
cond &= deterministic == "f" cond &= deterministic == "f"
if not cond: if not cond:
......
...@@ -295,7 +295,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -295,7 +295,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)); get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v));
ck_tile::HostTensor<BiasDataType> bias_host( ck_tile::HostTensor<BiasDataType> bias_host(
bias.type == bias_enum::elementwise_bias bias.type == bias_enum::elementwise_bias
? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k) ? get_lengths(i_perm, 1, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */); : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> alibi_slope_host( ck_tile::HostTensor<AccDataType> alibi_slope_host(
bias.type == bias_enum::alibi bias.type == bias_enum::alibi
...@@ -321,7 +321,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -321,7 +321,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<BiasGradDataType> dbias_host( ck_tile::HostTensor<BiasGradDataType> dbias_host(
use_dbias use_dbias
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) ? get_lengths(i_perm, 1, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */); : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> dq_acc_host( ck_tile::HostTensor<AccDataType> dq_acc_host(
i_perm i_perm
...@@ -448,7 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -448,7 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = (i_perm ? hdim_v : nhead_k * hdim_v); const ck_tile::index_t stride_v = (i_perm ? hdim_v : nhead_k * hdim_v);
const ck_tile::index_t stride_bias = (max_seqlen_k); const ck_tile::index_t stride_bias = i_perm? shape_seqlen_k:nhead*shape_seqlen_k;
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_randval = (max_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v); const ck_tile::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v);
...@@ -459,7 +459,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -459,7 +459,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = (i_perm ? shape_seqlen_k * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_v = (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_bias = 0; const ck_tile::index_t nhead_stride_bias = i_perm? shape_seqlen_q * shape_seqlen_k: shape_seqlen_k;
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
...@@ -477,7 +477,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -477,7 +477,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q); const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v); const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_dbias = 0;
const ck_tile::index_t split_stride_dq_acc = const ck_tile::index_t split_stride_dq_acc =
(shape_batch * nhead * shape_seqlen_q * hdim_q); (shape_batch * nhead * shape_seqlen_q * hdim_q);
...@@ -657,12 +657,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -657,12 +657,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(bias.type == bias_enum::elementwise_bias) if(bias.type == bias_enum::elementwise_bias)
{ {
// elementwise bias // elementwise bias
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k}); ck_tile::HostTensor<BiasDataType> bias_host_ref({nhead, real_seqlen_q, real_seqlen_k});
// clang-format off // clang-format off
if(i_perm) if(i_perm)
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[0], i[1] + query_offset, i[2]); });
else else
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, i[0], i[2]); });
// clang-format on // clang-format on
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
...@@ -813,6 +813,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -813,6 +813,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
dv_buf.FromDevice(dv_host.data()); dv_buf.FromDevice(dv_host.data());
dbias_buf.FromDevice(dbias_host.data()); dbias_buf.FromDevice(dbias_host.data());
ck_tile::HostTensor<BiasGradDataType> dbias_host_result(
{nhead, max_seqlen_q, max_seqlen_k}); // dbias_g_m_n
if(use_dbias){
if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(0, idx[0], idx[1], idx[2]); });
else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(0, idx[1], idx[0], idx[2]); });
}
ck_tile::HostTensor<BiasGradDataType> dbias_host_ref(
{nhead, max_seqlen_q, max_seqlen_k}); // dbias_g_m_n
dbias_host_ref.SetZero();
for(ck_tile::index_t wb = 0; wb < batch; ++wb) for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{ {
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
...@@ -830,8 +840,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -830,8 +840,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
{nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision
ck_tile::HostTensor<AccDataType> dp_hp_host_ref( ck_tile::HostTensor<AccDataType> dp_hp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision {nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision
ck_tile::HostTensor<BiasGradDataType> dbias_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
ck_tile::HostTensor<QGradDataType> dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k ck_tile::HostTensor<QGradDataType> dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
ck_tile::HostTensor<KGradDataType> dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k ck_tile::HostTensor<KGradDataType> dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
ck_tile::HostTensor<VGradDataType> dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o ck_tile::HostTensor<VGradDataType> dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
...@@ -870,7 +878,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -870,7 +878,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(use_dbias) if(use_dbias)
{ {
ds_hp_host_ref.ForEach([&](auto& self, auto idx) { ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
dbias_host_ref(idx) = ck_tile::type_convert<BiasGradDataType>(self(idx)); dbias_host_ref(idx) += ck_tile::type_convert<BiasGradDataType>(self(idx));
}); });
} }
...@@ -912,8 +920,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -912,8 +920,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
{nhead, real_seqlen_k, hdim_q}); // dk_g_n_k {nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
ck_tile::HostTensor<VGradDataType> dv_host_result( ck_tile::HostTensor<VGradDataType> dv_host_result(
{nhead, real_seqlen_k, hdim_v}); // dv_g_n_o {nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
ck_tile::HostTensor<BiasGradDataType> dbias_host_result(
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
// clang-format off // clang-format off
// permute // permute
...@@ -926,11 +932,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -926,11 +932,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(i_perm) dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[0], idx[1] + key_offset, idx[2]); }); if(i_perm) dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[0], idx[1] + key_offset, idx[2]); });
else dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[1] + key_offset, idx[0], idx[2]); }); else dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[1] + key_offset, idx[0], idx[2]); });
if(use_dbias)
{
if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2]); });
else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2]); });
}
// clang-format on // clang-format on
auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v); auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
...@@ -950,17 +951,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -950,17 +951,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
rtol, rtol,
atol); atol);
bool dbias_cur_pass = true; pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass);
if(use_dbias) if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass))
{
dbias_cur_pass = ck_tile::check_err(dbias_host_result,
dbias_host_ref,
std::string("Error: BiasGrad Incorrect results!"),
rtol,
atol);
}
pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass);
if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass))
{ {
std::cerr << "mismatch found at batch: " << wb << std::endl std::cerr << "mismatch found at batch: " << wb << std::endl
<< "\tseqlen_q: " << real_seqlen_q << std::endl << "\tseqlen_q: " << real_seqlen_q << std::endl
...@@ -971,6 +963,18 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -971,6 +963,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
break; break;
} }
} }
bool dbias_cur_pass = true;
auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
if(use_dbias)
{
dbias_cur_pass = ck_tile::check_err(dbias_host_result,
dbias_host_ref,
std::string("Error: BiasGrad Incorrect results!"),
rtol,
atol);
}
pass &= dbias_cur_pass;
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment