Commit 0f091062 authored by thomwolf's avatar thomwolf
Browse files

Merge branch 'glue-example' into tf2

parents c4acc3a8 e4022d96
...@@ -321,9 +321,17 @@ class RelPartialLearnableMultiHeadAttn(nn.Module): ...@@ -321,9 +321,17 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
if attn_mask is not None and torch.sum(attn_mask).item(): if attn_mask is not None and torch.sum(attn_mask).item():
attn_mask = (attn_mask == 1) # Switch to bool attn_mask = (attn_mask == 1) # Switch to bool
if attn_mask.dim() == 2: if attn_mask.dim() == 2:
if next(self.parameters()).dtype == torch.float16:
attn_score = attn_score.float().masked_fill(
attn_mask[None,:,:,None], -65000).type_as(attn_score)
else:
attn_score = attn_score.float().masked_fill( attn_score = attn_score.float().masked_fill(
attn_mask[None,:,:,None], -1e30).type_as(attn_score) attn_mask[None,:,:,None], -1e30).type_as(attn_score)
elif attn_mask.dim() == 3: elif attn_mask.dim() == 3:
if next(self.parameters()).dtype == torch.float16:
attn_score = attn_score.float().masked_fill(
attn_mask[:,:,:,None], -65000).type_as(attn_score)
else:
attn_score = attn_score.float().masked_fill( attn_score = attn_score.float().masked_fill(
attn_mask[:,:,:,None], -1e30).type_as(attn_score) attn_mask[:,:,:,None], -1e30).type_as(attn_score)
...@@ -547,7 +555,7 @@ TRANSFO_XL_INPUTS_DOCSTRING = r""" ...@@ -547,7 +555,7 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
""" """
@add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.", @add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING) TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
class TransfoXLModel(TransfoXLPreTrainedModel): class TransfoXLModel(TransfoXLPreTrainedModel):
r""" r"""
......
...@@ -457,6 +457,9 @@ class PoolerStartLogits(nn.Module): ...@@ -457,6 +457,9 @@ class PoolerStartLogits(nn.Module):
x = self.dense(hidden_states).squeeze(-1) x = self.dense(hidden_states).squeeze(-1)
if p_mask is not None: if p_mask is not None:
if next(self.parameters()).dtype == torch.float16:
x = x * (1 - p_mask) - 65500 * p_mask
else:
x = x * (1 - p_mask) - 1e30 * p_mask x = x * (1 - p_mask) - 1e30 * p_mask
return x return x
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment