Commit 469509ba authored by Ye Wang's avatar Ye Wang
Browse files

debug 11HS dbias for TE

parent 8ef8a994
......@@ -501,7 +501,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue
if receipt == 3:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
cond &= bias in ['no', 'bias', 'alibi']
cond &= dpad == dvpad
cond &= deterministic == "f"
if not cond:
......@@ -801,4 +801,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
\ No newline at end of file
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
......@@ -321,7 +321,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<BiasGradDataType> dbias_host(
use_dbias
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> dq_acc_host(
i_perm
......@@ -454,7 +454,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k);
const ck_tile::index_t stride_dbias = (max_seqlen_k);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
......@@ -464,8 +464,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
const ck_tile::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
const ck_tile::index_t nhead_stride_dbias = 0;
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
......@@ -477,7 +476,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_dbias = 0;
const ck_tile::index_t split_stride_dq_acc =
(shape_batch * nhead * shape_seqlen_q * hdim_q);
......@@ -813,6 +812,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
dv_buf.FromDevice(dv_host.data());
dbias_buf.FromDevice(dbias_host.data());
ck_tile::HostTensor<BiasGradDataType> dbias_host_result(
{1, 1, max_seqlen_q, max_seqlen_k}); // dbias_g_m_n
if(use_dbias){
if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(idx[0], idx[1], idx[2], idx[3]); });
else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(idx[0], idx[2], idx[1], idx[3]); });
}
ck_tile::HostTensor<BiasGradDataType> dbias_host_ref(
{1, 1, max_seqlen_q, max_seqlen_k}); // dbias_g_m_n
dbias_host_ref.SetZero();
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
......@@ -830,8 +840,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
{nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision
ck_tile::HostTensor<AccDataType> dp_hp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision
ck_tile::HostTensor<BiasGradDataType> dbias_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
ck_tile::HostTensor<QGradDataType> dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
ck_tile::HostTensor<KGradDataType> dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
ck_tile::HostTensor<VGradDataType> dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
......@@ -870,7 +878,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(use_dbias)
{
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
dbias_host_ref(idx) = ck_tile::type_convert<BiasGradDataType>(self(idx));
dbias_host_ref(0, 0, idx[1], idx[2]) += ck_tile::type_convert<BiasGradDataType>(self(idx));
});
}
......@@ -912,8 +920,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
{nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
ck_tile::HostTensor<VGradDataType> dv_host_result(
{nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
ck_tile::HostTensor<BiasGradDataType> dbias_host_result(
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
// clang-format off
// permute
......@@ -926,11 +932,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(i_perm) dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[0], idx[1] + key_offset, idx[2]); });
else dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[1] + key_offset, idx[0], idx[2]); });
if(use_dbias)
{
if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2]); });
else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2]); });
}
// clang-format on
auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
......@@ -950,17 +951,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
rtol,
atol);
bool dbias_cur_pass = true;
if(use_dbias)
{
dbias_cur_pass = ck_tile::check_err(dbias_host_result,
dbias_host_ref,
std::string("Error: BiasGrad Incorrect results!"),
rtol,
atol);
}
pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass);
if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass))
pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass);
if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass))
{
std::cerr << "mismatch found at batch: " << wb << std::endl
<< "\tseqlen_q: " << real_seqlen_q << std::endl
......@@ -972,6 +964,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
bool dbias_cur_pass = true;
auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
if(use_dbias){
dbias_cur_pass = ck_tile::check_err(dbias_host_result,
dbias_host_ref,
std::string("Error: BiasGrad Incorrect results!"),
rtol,
atol);
}
pass &= dbias_cur_pass;
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
return pass;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment