Unverified Commit 1559c0df authored by ver217's avatar ver217 Committed by GitHub
Browse files

fix attn mask shape of gpt (#472)

parent 3cb3fc27
......@@ -292,7 +292,7 @@ class GPT(nn.Module):
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# Adapted from huggingface
if attention_mask is not None:
batch_size = x.shape[0]
batch_size = input_ids.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = col_nn.partition_batch(attention_mask)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment