Unverified Commit da27c4b3 authored by Brad Jascob's avatar Brad Jascob Committed by GitHub
Browse files

Update modeling_longt5.py (#17777)

On line 180, `torch.tensor(-1.0, xxx)` gives the error "TypeError: 'float' object cannot be interpreted as an integer" 
This is because the dtype here is `int64`.  For `dtype=int64`, this needs to simply be `-1`.  
This impacts the long-t5-tglogbal-x model.  It does not impact the long-t5-local-x version which does not appear to call this line.
parent d3cb2888
......@@ -177,7 +177,7 @@ def _make_global_fixed_block_ids(
fixed_block_mask = torch.cumsum(fixed_block_mask, axis=1) - fixed_block_mask
mask = torch.where(attention_mask != 0.0, 1.0, -1000.0).type(attention_mask.dtype)
global_block_ids = torch.floor(mask + fixed_block_mask - 1.0).type(attention_mask.dtype)
_global_block_ids_lower_bound = torch.tensor(-1.0, dtype=global_block_ids.dtype, device=global_block_ids.device)
_global_block_ids_lower_bound = torch.tensor(-1, dtype=global_block_ids.dtype, device=global_block_ids.device)
global_block_ids = torch.where(
global_block_ids > _global_block_ids_lower_bound, global_block_ids, _global_block_ids_lower_bound
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment