Unverified commit 5f0d07b3, authored by Xu Zhao, committed by GitHub

Make BigBird model compatible with fp16 dtype. (#16034)

* Make BigBird model compatible with fp16 dtype.

* Use tree_map instead of map

* Reformat the code

* Fix import order

* Convert masks to the correct dtype

* Fix format issue

* Address comments.
parent 1cf28da6
...
@@ -1378,6 +1378,13 @@ class BigBirdAttention(nn.Module):
         from_blocked_mask=None,
         to_blocked_mask=None,
     ):
+        # fp16 compatibility
+        if band_mask is not None:
+            band_mask = band_mask.to(hidden_states.dtype)
+        if from_mask is not None:
+            from_mask = from_mask.to(hidden_states.dtype)
+        if to_mask is not None:
+            to_mask = to_mask.to(hidden_states.dtype)
         if self.attention_type == "original_full":
             self_outputs = self.self(
...
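
For context, the cast above matters once the model itself runs in half precision: the attention masks are built in fp32, so they must match the dtype of hidden_states before being combined with fp16 attention scores. A minimal sketch of exercising this path, assuming the google/bigbird-roberta-base checkpoint and a CUDA device are available:

import torch
from transformers import BigBirdModel, BigBirdTokenizer

tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
# torch_dtype=torch.float16 loads the weights directly in half precision
model = BigBirdModel.from_pretrained(
    "google/bigbird-roberta-base", torch_dtype=torch.float16
).to("cuda")

inputs = tokenizer("BigBird handles long sequences.", return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs)

# With the masks cast to hidden_states.dtype, the fp16 forward pass completes
print(outputs.last_hidden_state.dtype)  # torch.float16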