"docs/source/vscode:/vscode.git/clone" did not exist on "a09fe140c1c059baf05c4f97e5b4e83c719608db"
Commit 96e7ee72 authored by Thomas Wolf, committed by GitHub

Merge pull request #1740 from huggingface/fix-ctrl-past

Fix CTRL past
parents 3c28a2da 8da47b07
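
This change fixes generation with cached `past` key/value states in the CTRL model: the causal mask is now built over the full `seq_len + past_length` window in `CTRLModel.forward`, and `scaled_dot_product_attention` slices out the rows corresponding to the new query tokens before applying the mask to the attention logits.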
@@ -63,7 +63,8 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
     scaled_attention_logits = matmul_qk / np.sqrt(dk)
 
     if mask is not None:
-        scaled_attention_logits += (mask * -1e4)
+        nd, ns = scaled_attention_logits.size(-2), scaled_attention_logits.size(-1)
+        scaled_attention_logits += (mask[ns-nd:ns, :ns] * -1e4)
 
     if attention_mask is not None:
         # Apply the attention mask
@@ -373,7 +374,7 @@ class CTRLModel(CTRLPreTrainedModel):
             inputs_embeds = self.w(input_ids)
         # inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
         seq_len = input_shape[-1]
-        mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(inputs_embeds.device)
+        mask = torch.triu(torch.ones(seq_len + past_length, seq_len + past_length), 1).to(inputs_embeds.device)
 
         inputs_embeds *= np.sqrt(self.d_model_size)
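
A minimal sketch of why the slice is needed (the toy sizes and standalone variable names below are assumptions for illustration, not part of the PR): with key/value pairs for `past_length` tokens cached, the attention logits have shape `(nd, ns)` with `nd = seq_len` new queries against `ns = seq_len + past_length` keys, so only the last `nd` rows of the full `ns x ns` triangular mask apply.

    import torch

    # Toy sizes (assumed for illustration).
    past_length = 3   # tokens already cached in `past`
    seq_len = 2       # new tokens in the current forward pass

    # Full causal mask over past + new tokens, as built in CTRLModel.forward:
    # entry (i, j) is 1 where position i must NOT attend to position j.
    total = seq_len + past_length
    mask = torch.triu(torch.ones(total, total), 1)

    # In scaled_dot_product_attention, the logits have shape (..., nd, ns):
    # nd query (new) tokens against ns = past + new key tokens.
    nd, ns = seq_len, total
    scaled_attention_logits = torch.zeros(nd, ns)

    # The fixed line: keep only the last nd rows of the ns x ns mask, so each
    # new query row is masked against all ns key positions.
    scaled_attention_logits += mask[ns - nd:ns, :ns] * -1e4

    print(scaled_attention_logits < -1e3)
    # tensor([[False, False, False, False,  True],
    #         [False, False, False, False, False]])

Without the slice, the old `(seq_len, seq_len)` mask no longer matches the `(nd, ns)` logits as soon as `past` is non-empty, which is why decoding with cached states was broken.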