Unverified Commit 1da802bd authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

Fix FA bwd alibi+causal NaN errors (#1352)

* fix bwd alibi nan error

* fix datatype

---------

Co-authored-by: danyao12 <danyao12>
parent 0162a5f6
...@@ -372,7 +372,7 @@ struct SimplifiedGenericAttentionMask ...@@ -372,7 +372,7 @@ struct SimplifiedGenericAttentionMask
// index_t x_end = min(i_y + x, x_total); // index_t x_end = min(i_y + x, x_total);
bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad
bool bottom_left_edge = i_y_end > (i_x + y); bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad
// bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now
return top_right_edge || bottom_left_edge; return top_right_edge || bottom_left_edge;
......
...@@ -501,9 +501,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -501,9 +501,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsBlockDescriptor()
{ {
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>; using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
constexpr index_t Banks = 32; // TODO: need change based on arch constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QGradDataType); constexpr index_t PixelsPerRow = Banks * 4 / sizeof(OGradDataType);
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>(); constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPerBlock = [&]() { constexpr index_t kKPerBlock = [&]() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment