"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "299ac0a300fe029efe2468e552132ad87e50c033"
Unverified commit 6b242812, authored by Prajjwal Bhargava, committed by GitHub
Browse files

fix typo in comments (#6838)

parent 7351ef83
...@@ -803,8 +803,8 @@ class BertModel(BertPreTrainedModel): ...@@ -803,8 +803,8 @@ class BertModel(BertPreTrainedModel):
# ourselves in which case we just need to make it broadcastable to all heads. # ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
# If a 2D ou 3D attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None: if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment