Commit d763f54a authored by Vijay Korthikanti

support for different query key sequence lengths

parent 56e16cba
@@ -191,28 +191,28 @@ class ParallelSelfAttention(MegatronModule):
     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False):
-        # hidden_states: [s, b, h]
+        # hidden_states: [sq, b, h]

         # =====================
         # Query, Key, and Value
         # =====================

-        # Attention heads [s, b, hp] --> [s, b, hp * 3]
+        # Attention heads [sq, b, hp] --> [sq, b, hp * 3]
         mixed_x_layer, _ = self.query_key_value(hidden_states)

         checkpoint_version = get_checkpoint_version()
         if checkpoint_version is not None and \
                 checkpoint_version == 0:
-            # [s, b, 3 * hp] --> [s, b, hp * 3]
+            # [sq, b, 3 * hp] --> [sq, b, hp * 3]
             mixed_x_layer = self._transpose_last_dim(mixed_x_layer)

-        # [s, b, hp * 3] --> [s, b, np, hn, 3]
+        # [sq, b, hp * 3] --> [sq, b, np, hn, 3]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
             (self.num_attention_heads_per_partition,
              self.hidden_size_per_attention_head, 3)
         mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

-        # [s, b, np, hn, 3] --> 3 [s, b, np, hn]
+        # [sq, b, np, hn, 3] --> 3 [sq, b, np, hn]
         query_layer = mixed_x_layer[:,:,:,:,0]
         key_layer = mixed_x_layer[:,:,:,:,1]
         value_layer = mixed_x_layer[:,:,:,:,2]
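The hunk above renames the sequence axis of the fused QKV projection from s to sq, since only the query length is guaranteed to match hidden_states. The reshape and slice can be reproduced on their own; in the sketch below the sizes sq, b, np_, hn and the random tensor standing in for the query_key_value output are illustrative assumptions, not part of the commit.

import torch

# Illustrative sizes: query length, batch, heads per partition, head size.
sq, b, np_, hn = 7, 2, 4, 16
# Stand-in for the fused query_key_value output: [sq, b, hp * 3] with hp = np * hn.
mixed_x_layer = torch.randn(sq, b, np_ * hn * 3)

# [sq, b, hp * 3] --> [sq, b, np, hn, 3]
mixed_x_layer = mixed_x_layer.view(sq, b, np_, hn, 3)

# [sq, b, np, hn, 3] --> 3 x [sq, b, np, hn]
query_layer = mixed_x_layer[..., 0]
key_layer = mixed_x_layer[..., 1]
value_layer = mixed_x_layer[..., 2]

assert query_layer.shape == (sq, b, np_, hn)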
@@ -235,19 +235,19 @@ class ParallelSelfAttention(MegatronModule):
         # ===================================
         # Raw attention scores. [b, np, s, s]
         # ===================================

-        # [b, np, s, s]
+        # [b, np, sq, sk]
         output_size = (query_layer.size(1),
                        query_layer.size(2),
                        query_layer.size(0),
                        key_layer.size(0))

-        # [s, b, np, hn] -> [s, b * np, hn]
+        # [sq, b, np, hn] -> [sq, b * np, hn]
         query_layer = query_layer.view(output_size[2],
                                        output_size[0] * output_size[1], -1)
         key_layer = key_layer.view(output_size[3],
                                    output_size[0] * output_size[1], -1)

-        # preallocting result tensor: [b * np, s, s]
+        # preallocting result tensor: [b * np, sq, sk]
         matmul_result = torch.empty(
             output_size[0]*output_size[1],
             output_size[2],
@@ -255,18 +255,18 @@ class ParallelSelfAttention(MegatronModule):
             dtype=query_layer.dtype,
             device=torch.cuda.current_device())

-        # Raw attention scores. [b * np, s, s]
+        # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(matmul_result,
-            query_layer.transpose(0, 1),   # [b * np, s, hn]
-            key_layer.transpose(0,1).transpose(1, 2),  #[b * np, hn, s]
+            query_layer.transpose(0, 1),   # [b * np, sq, hn]
+            key_layer.transpose(0,1).transpose(1, 2),  #[b * np, hn, sk]
             beta=0.0, alpha=(1.0/self.norm_factor))

-        # change view to [b, np, s, s]
+        # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)

         # ==================================================
-        # Update attention mask for inference. [b, np, s, s]
+        # Update attention mask for inference. [b, np, sq, sk]
         # ==================================================

         if get_key_value:
@@ -287,7 +287,7 @@ class ParallelSelfAttention(MegatronModule):
         # ===========================
         # Attention probs and dropout
         # ===========================

-        # attention scores and attention mask [b, np, s, s]
+        # attention scores and attention mask [b, np, sq, sk]
         attention_probs = self.scale_mask_softmax(attention_scores,
                                                   attention_mask)
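The hunks above are where the separate lengths pay off: queries keep length sq while the keys, which may include entries carried in via layer_past, have length sk, so the raw scores are [b, np, sq, sk] rather than square. A minimal standalone sketch of that scaled baddbmm, assuming illustrative sizes and a plain 1/sqrt(hn) norm_factor:

import math
import torch

# Illustrative sizes; sq != sk, e.g. a single new query against cached keys.
sq, sk, b, np_, hn = 1, 9, 2, 4, 16
query_layer = torch.randn(sq, b, np_, hn)   # [sq, b, np, hn]
key_layer = torch.randn(sk, b, np_, hn)     # [sk, b, np, hn]
norm_factor = math.sqrt(hn)                 # assumed scaling, 1/sqrt(head size)

output_size = (b, np_, sq, sk)              # [b, np, sq, sk]

# [s, b, np, hn] -> [s, b * np, hn]
q = query_layer.view(sq, b * np_, hn)
k = key_layer.view(sk, b * np_, hn)

# Preallocated result; with beta=0.0 its contents are ignored by baddbmm.
matmul_result = torch.empty(b * np_, sq, sk, dtype=q.dtype)

# [b * np, sq, hn] x [b * np, hn, sk] -> [b * np, sq, sk]
matmul_result = torch.baddbmm(matmul_result,
                              q.transpose(0, 1),
                              k.transpose(0, 1).transpose(1, 2),
                              beta=0.0, alpha=(1.0 / norm_factor))

# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
assert attention_scores.shape == (b, np_, sq, sk)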
@@ -298,43 +298,43 @@ class ParallelSelfAttention(MegatronModule):
         # =========================
-        # Context layer. [s, b, hp]
+        # Context layer. [sq, b, hp]
         # =========================

         # value_layer -> context layer.
-        # [s, b, np, hn] --> [b, np, s, hn]
+        # [sk, b, np, hn] --> [b, np, sq, hn]

-        # context layer shape: [b, np, s, hn]
+        # context layer shape: [b, np, sq, hn]
         output_size = (value_layer.size(1),
                        value_layer.size(2),
-                       value_layer.size(0),
+                       query_layer.size(0),
                        value_layer.size(3))

-        # change view [s, b * np, hn]
-        value_layer = value_layer.view(output_size[2],
+        # change view [sk, b * np, hn]
+        value_layer = value_layer.view(value_layer.size(0),
                                        output_size[0] * output_size[1], -1)

-        # change view [b * np, s, s]
+        # change view [b * np, sq, sk]
         attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                                output_size[2], -1)

-        # matmul: [b * np, s, hn]
+        # matmul: [b * np, sq, hn]
         context_layer = torch.bmm(attention_probs, value_layer.transpose(0,1))

-        # change view [b, np, s, hn]
+        # change view [b, np, sq, hn]
         context_layer = context_layer.view(*output_size)

-        # [b, np, s, hn] --> [s, b, np, hn]
+        # [b, np, sq, hn] --> [sq, b, np, hn]
         context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

-        # [s, b, np, hn] --> [s, b, hp]
+        # [sq, b, np, hn] --> [sq, b, hp]
         new_context_layer_shape = context_layer.size()[:-2] + \
             (self.hidden_size_per_partition,)
         context_layer = context_layer.view(*new_context_layer_shape)

         # =================
-        # Output. [s, b, h]
+        # Output. [sq, b, h]
         # =================

         output, bias = self.dense(context_layer)
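The final contraction keeps the two lengths distinct until the end: attention_probs is [b, np, sq, sk], value_layer is [sk, b, np, hn], and only sq survives into the [sq, b, hp] context layer fed to the dense projection. A standalone sketch under the same illustrative sizes:

import torch

# Illustrative sizes only; probs stand in for the masked, dropped-out softmax output.
sq, sk, b, np_, hn = 3, 9, 2, 4, 16
attention_probs = torch.softmax(torch.randn(b, np_, sq, sk), dim=-1)
value_layer = torch.randn(sk, b, np_, hn)         # [sk, b, np, hn]

output_size = (b, np_, sq, hn)                    # [b, np, sq, hn]

# [sk, b, np, hn] -> [sk, b * np, hn]
v = value_layer.view(sk, b * np_, hn)
# [b, np, sq, sk] -> [b * np, sq, sk]
probs = attention_probs.view(b * np_, sq, sk)

# [b * np, sq, sk] x [b * np, sk, hn] -> [b * np, sq, hn]
context_layer = torch.bmm(probs, v.transpose(0, 1))

# [b * np, sq, hn] -> [b, np, sq, hn] -> [sq, b, np, hn] -> [sq, b, hp]
context_layer = context_layer.view(*output_size)
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
context_layer = context_layer.view(sq, b, np_ * hn)
assert context_layer.shape == (sq, b, np_ * hn)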