Unverified Commit 1da802bd authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

Fix FA bwd alibi+causal NaN errors (#1352)

* fix bwd alibi nan error

* fix datatype

---------

Co-authored-by: danyao12 <danyao12>
parent 0162a5f6
......@@ -372,7 +372,7 @@ struct SimplifiedGenericAttentionMask
// index_t x_end = min(i_y + x, x_total);
bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad
bool bottom_left_edge = i_y_end > (i_x + y);
bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad
// bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now
return top_right_edge || bottom_left_edge;
......
......@@ -501,9 +501,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsBlockDescriptor()
{
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QGradDataType);
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(OGradDataType);
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPerBlock = [&]() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment