Unverified commit 0f5f5dd5, authored by Frank Lee, committed by GitHub
Browse files

fixed gpt attention mask in pipeline (#430)

parent f9c762df
...@@ -51,18 +51,6 @@ class GPTEmbedding(nn.Module): ...@@ -51,18 +51,6 @@ class GPTEmbedding(nn.Module):
x = x + self.tokentype_embeddings(tokentype_ids) x = x + self.tokentype_embeddings(tokentype_ids)
x = self.dropout(x) x = self.dropout(x)
# We create a 4D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length],
# so we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length].
# Adapted from Hugging Face.
if attention_mask is not None:
batch_size = input_ids.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = col_nn.partition_batch(attention_mask)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = attention_mask.to(dtype=x.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
return x, attention_mask return x, attention_mask
...@@ -355,6 +343,21 @@ class PipelineGPT(nn.Module): ...@@ -355,6 +343,21 @@ class PipelineGPT(nn.Module):
if self.first: if self.first:
x, attention_mask = self.embed(input_ids, attention_mask) x, attention_mask = self.embed(input_ids, attention_mask)
# We create a 4D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length],
# so we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length].
# Adapted from Hugging Face.
if attention_mask is not None:
if self.first:
batch_size = input_ids.shape[0]
else:
batch_size = x.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = col_nn.partition_batch(attention_mask)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = attention_mask.to(dtype=x.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
for block in self.blocks: for block in self.blocks:
x, attention_mask = block(x, attention_mask) x, attention_mask = block(x, attention_mask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment