Commit 54282071 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Merge branch 'vijay/diff_query_key_lengths' into 'main'

support for different query key sequence lengths

See merge request ADLR/megatron-lm!151
parents 56e16cba d763f54a
......@@ -191,28 +191,28 @@ class ParallelSelfAttention(MegatronModule):
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
# hidden_states: [s, b, h]
# hidden_states: [sq, b, h]
# =====================
# Query, Key, and Value
# =====================
# Attention heads [s, b, hp] --> [s, b, hp * 3]
# Attention heads [sq, b, hp] --> [sq, b, hp * 3]
mixed_x_layer, _ = self.query_key_value(hidden_states)
checkpoint_version = get_checkpoint_version()
if checkpoint_version is not None and \
checkpoint_version == 0:
# [s, b, 3 * hp] --> [s, b, hp * 3]
# [sq, b, 3 * hp] --> [sq, b, hp * 3]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer)
# [s, b, hp * 3] --> [s, b, np, hn, 3]
# [sq, b, hp * 3] --> [sq, b, np, hn, 3]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head, 3)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [s, b, np, hn, 3] --> 3 [s, b, np, hn]
# [sq, b, np, hn, 3] --> 3 [sq, b, np, hn]
query_layer = mixed_x_layer[:,:,:,:,0]
key_layer = mixed_x_layer[:,:,:,:,1]
value_layer = mixed_x_layer[:,:,:,:,2]
......@@ -235,19 +235,19 @@ class ParallelSelfAttention(MegatronModule):
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, s, s]
# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))
# [s, b, np, hn] -> [s, b * np, hn]
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
# preallocting result tensor: [b * np, s, s]
# preallocting result tensor: [b * np, sq, sk]
matmul_result = torch.empty(
output_size[0]*output_size[1],
output_size[2],
......@@ -255,18 +255,18 @@ class ParallelSelfAttention(MegatronModule):
dtype=query_layer.dtype,
device=torch.cuda.current_device())
# Raw attention scores. [b * np, s, s]
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(matmul_result,
query_layer.transpose(0, 1), # [b * np, s, hn]
key_layer.transpose(0,1).transpose(1, 2), #[b * np, hn, s]
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0,1).transpose(1, 2), #[b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, s, s]
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ==================================================
# Update attention mask for inference. [b, np, s, s]
# Update attention mask for inference. [b, np, sq, sk]
# ==================================================
if get_key_value:
......@@ -287,7 +287,7 @@ class ParallelSelfAttention(MegatronModule):
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, s, s]
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores,
attention_mask)
......@@ -298,43 +298,43 @@ class ParallelSelfAttention(MegatronModule):
# =========================
# Context layer. [s, b, hp]
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [s, b, np, hn] --> [b, np, s, hn]
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, s, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
value_layer.size(0),
query_layer.size(0),
value_layer.size(3))
# change view [s, b * np, hn]
value_layer = value_layer.view(output_size[2],
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0),
output_size[0] * output_size[1], -1)
# change view [b * np, s, s]
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
# matmul: [b * np, s, hn]
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0,1))
# change view [b, np, s, hn]
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, s, hn] --> [s, b, np, hn]
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [s, b, np, hn] --> [s, b, hp]
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [s, b, h]
# Output. [sq, b, h]
# =================
output, bias = self.dense(context_layer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment