Unverified commit b62abe87, authored by Thomas Wolf, committed by GitHub

Merge pull request #1249 from ziliwang/master

Fixed: hard-coded max and min values go out of range in fp16, which causes NaN.
parents 11ac4b95 8bdee1cb
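For context, a minimal standalone illustration (not part of the diff) of why a hard-coded -1e30 misbehaves in fp16: the constant overflows to -inf once the tensor is in half precision, and a row of scores that is entirely -inf turns into NaN under softmax.

```python
import torch
import torch.nn.functional as F

# The largest finite fp16 magnitude is 65504, so -1e30 is not representable:
print(torch.finfo(torch.float16).max)            # 65504.0
print(torch.tensor(-1e30, dtype=torch.float16))  # tensor(-inf, dtype=torch.float16)

# A fully masked attention row (all -inf) produces NaN under softmax (0/0):
row = torch.full((1, 4), float('-inf'))
print(F.softmax(row, dim=-1))                    # tensor([[nan, nan, nan, nan]])
```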
```diff
@@ -231,7 +231,7 @@ class PositionwiseFF(nn.Module):
 class MultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                  pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False):
         super(MultiHeadAttn, self).__init__()
@@ -451,11 +451,19 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
         if attn_mask is not None and torch.sum(attn_mask).item():
             attn_mask = (attn_mask == 1)  # Switch to bool
             if attn_mask.dim() == 2:
-                attn_score = attn_score.float().masked_fill(
-                    attn_mask[None,:,:,None], -1e30).type_as(attn_score)
+                if next(self.parameters()).dtype == torch.float16:
+                    attn_score = attn_score.float().masked_fill(
+                        attn_mask[None,:,:,None], -65000).type_as(attn_score)
+                else:
+                    attn_score = attn_score.float().masked_fill(
+                        attn_mask[None,:,:,None], -1e30).type_as(attn_score)
             elif attn_mask.dim() == 3:
-                attn_score = attn_score.float().masked_fill(
-                    attn_mask[:,:,:,None], -1e30).type_as(attn_score)
+                if next(self.parameters()).dtype == torch.float16:
+                    attn_score = attn_score.float().masked_fill(
+                        attn_mask[:,:,:,None], -65000).type_as(attn_score)
+                else:
+                    attn_score = attn_score.float().masked_fill(
+                        attn_mask[:,:,:,None], -1e30).type_as(attn_score)

         # [qlen x klen x bsz x n_head]
         attn_prob = F.softmax(attn_score, dim=1)
```
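A condensed sketch of the pattern this hunk introduces (a hypothetical helper for illustration, not code from the patch): the fill value is chosen from the parameters' dtype so it remains finite after `.type_as()` casts the scores back to fp16.

```python
import torch

def masking_value(module: torch.nn.Module) -> float:
    """Hypothetical helper mirroring the hunk above: pick a large negative
    fill value that stays finite in the module's parameter dtype."""
    if next(module.parameters()).dtype == torch.float16:
        return -65000.0  # within the fp16 range (max magnitude 65504)
    return -1e30         # fine for fp32/fp64

# Schematic use inside an attention forward pass:
# attn_score = attn_score.float().masked_fill(
#     attn_mask[None, :, :, None], masking_value(self)).type_as(attn_score)
```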
```diff
@@ -587,7 +595,7 @@ class DecoderLayer(nn.Module):
         super(DecoderLayer, self).__init__()

         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))

     def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None):
@@ -607,7 +615,7 @@ class RelLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
         self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))

     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None):
@@ -628,7 +636,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                          d_head, dropout, **kwargs)
         self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))

     def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
@@ -645,7 +653,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
 class AdaptiveEmbedding(nn.Module):
     def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                  sample_softmax=False):
         super(AdaptiveEmbedding, self).__init__()
@@ -683,7 +691,7 @@ class AdaptiveEmbedding(nn.Module):
         else:
             param = next(self.parameters())
             inp_flat = inp.view(-1)
             emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
                                    dtype=param.dtype, device=param.device)
             for i in range(len(self.cutoffs)):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
@@ -852,7 +860,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         self.n_head = config.n_head
         self.d_head = config.d_head

         self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
                                           div_val=config.div_val)

         self.drop = nn.Dropout(config.dropout)
@@ -1011,7 +1019,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         hids = []
         attentions = []
         if self.attn_type == 0:  # default
             pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
             if self.clamp_len > 0:
                 pos_seq.clamp_(max=self.clamp_len)
@@ -1165,7 +1173,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
             self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
         # use adaptive softmax (including standard softmax)
         else:
             self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
                                                     config.cutoffs, div_val=config.div_val)
         self.init_weights()
         self.tie_weights()
```
```diff
@@ -140,7 +140,7 @@ class PreTrainedModel(nn.Module):
         Arguments:

             new_num_tokens: (`optional`) int:
                 New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                 If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.

         Return: ``torch.nn.Embeddings``
```
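This docstring appears to belong to PreTrainedModel.resize_token_embeddings; a rough usage sketch under that assumption (the model choice and added token below are placeholders):

```python
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Add tokens to the tokenizer, then grow the embedding matrix to match.
num_added = tokenizer.add_tokens(['<my_new_token>'])
embeddings = model.resize_token_embeddings(new_num_tokens=len(tokenizer))

# Passing None (or omitting the argument) just returns the current
# input-token embedding module unchanged.
same_embeddings = model.resize_token_embeddings(None)
```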
```diff
@@ -434,7 +434,10 @@ class PoolerStartLogits(nn.Module):
         x = self.dense(hidden_states).squeeze(-1)

         if p_mask is not None:
-            x = x * (1 - p_mask) - 1e30 * p_mask
+            if next(self.parameters()).dtype == torch.float16:
+                x = x * (1 - p_mask) - 65500 * p_mask
+            else:
+                x = x * (1 - p_mask) - 1e30 * p_mask

         return x
```
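The same fp16 ceiling applies to this additive penalty, which is computed directly in the logits' dtype rather than through masked_fill. A rough illustration (the exact behavior of the overflowing variant can vary slightly across devices and kernels):

```python
import torch

logits = torch.randn(4, dtype=torch.float16)
p_mask = torch.tensor([0., 1., 0., 1.], dtype=torch.float16)

# 1e30 exceeds the fp16 maximum (65504): masked positions overflow to -inf,
# and on some kernels the inf * 0 product at unmasked positions can itself be
# NaN. Either way, later ops (softmax, layer norm, ...) may end up with NaN.
overflowed = logits * (1 - p_mask) - 1e30 * p_mask

# 65500 stays within the fp16 range (it rounds to 65504), so masked positions
# remain finite and downstream computations stay well defined.
bounded = logits * (1 - p_mask) - 65500 * p_mask
print(overflowed, bounded)
```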