Unverified Commit c965d302 authored by bofeng huang, committed by GitHub

Fix comments for `_merge_heads` (#24855)

* Fix comments

* Fix comments
parent e4a52b6a
@@ -253,7 +253,7 @@ class BloomAttention(nn.Module):
     def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
         """
-        Merge heads together over the last dimenstion
+        Merge heads together over the last dimension
         Args:
             x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
@@ -344,7 +344,7 @@ class BloomAttention(nn.Module):
         # matmul: [batch_size * num_heads, q_length, head_dim]
         context_layer = torch.bmm(attention_probs_reshaped, value_layer)
-        # change view [batch_size, num_heads, q_length, head_dim]
+        # change view [batch_size, q_length, num_heads * head_dim]
         context_layer = self._merge_heads(context_layer)
         # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
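To see why the corrected comment reads [batch_size, q_length, num_heads * head_dim], here is a minimal, self-contained sketch of the reshape that `_merge_heads` performs. It is written as a free function with `num_heads` and `head_dim` passed explicitly (in the model they are module attributes), so it illustrates the shape bookkeeping rather than reproducing the exact library code.

```python
import torch


def merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
    """Merge heads together over the last dimension.

    Input:  [batch_size * num_heads, seq_length, head_dim]
    Output: [batch_size, seq_length, num_heads * head_dim]
    """
    batch_size_and_num_heads, seq_length, _ = x.shape
    batch_size = batch_size_and_num_heads // num_heads
    # Recover an explicit head axis: [batch_size, num_heads, seq_length, head_dim]
    x = x.view(batch_size, num_heads, seq_length, head_dim)
    # Move the head axis next to head_dim: [batch_size, seq_length, num_heads, head_dim]
    x = x.permute(0, 2, 1, 3)
    # Fold the heads into the hidden dimension: [batch_size, seq_length, num_heads * head_dim]
    return x.reshape(batch_size, seq_length, num_heads * head_dim)


# Toy shape check mirroring the Bloom hunk above: the bmm output is
# [batch_size * num_heads, q_length, head_dim], and the merged result is
# [batch_size, q_length, num_heads * head_dim] -- the corrected comment.
batch_size, num_heads, q_length, head_dim = 2, 4, 5, 8
context_layer = torch.rand(batch_size * num_heads, q_length, head_dim)
assert merge_heads(context_layer, num_heads, head_dim).shape == (
    batch_size,
    q_length,
    num_heads * head_dim,
)
```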
@@ -255,7 +255,7 @@ class FalconAttention(nn.Module):
     # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads
     def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
         """
-        Merge heads together over the last dimenstion
+        Merge heads together over the last dimension
         Args:
             x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
@@ -384,7 +384,7 @@ class FalconAttention(nn.Module):
         # matmul: [batch_size * num_heads, q_length, head_dim]
         context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
-        # change view [batch_size, num_heads, q_length, head_dim]
+        # change view [batch_size, q_length, num_heads * head_dim]
         context_layer = self._merge_heads(context_layer)
         output_tensor = self.dense(context_layer)
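The Falcon hunk reaches the same layout by a slightly different route: a 4-D matmul followed by `flatten(0, 1)`. The sketch below (toy sizes and illustrative tensor names, not the model's actual variables) checks that this also hands `_merge_heads` a [batch_size * num_heads, q_length, head_dim] tensor, so the corrected comment applies there as well.

```python
import torch

batch_size, num_heads, q_length, kv_length, head_dim = 2, 4, 5, 6, 8

# Toy attention probabilities and values, kept with an explicit head axis.
attention_probs_reshaped = torch.rand(batch_size, num_heads, q_length, kv_length)
value_layer_ = torch.rand(batch_size, num_heads, kv_length, head_dim)

# 4-D matmul, then fold batch and heads together:
# [batch_size * num_heads, q_length, head_dim]
context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
assert context_layer.shape == (batch_size * num_heads, q_length, head_dim)

# The same merge as in the Bloom sketch above:
# [batch_size, q_length, num_heads * head_dim], matching the fixed comment.
merged = (
    context_layer.view(batch_size, num_heads, q_length, head_dim)
    .permute(0, 2, 1, 3)
    .reshape(batch_size, q_length, num_heads * head_dim)
)
assert merged.shape == (batch_size, q_length, num_heads * head_dim)
```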