OpenDAS / Megatron-LM · Commits

Commit d763f54a
Authored Oct 08, 2020 by Vijay Korthikanti

    support for different query key sequence lengths

Parent: 56e16cba

Changes: 1 changed file with 27 additions and 27 deletions (+27, -27)

megatron/model/transformer.py @ d763f54a
@@ -191,28 +191,28 @@ class ParallelSelfAttention(MegatronModule):

     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False):
-        # hidden_states: [s, b, h]
+        # hidden_states: [sq, b, h]

         # =====================
         # Query, Key, and Value
         # =====================

-        # Attention heads [s, b, hp] --> [s, b, hp * 3]
+        # Attention heads [sq, b, hp] --> [sq, b, hp * 3]
         mixed_x_layer, _ = self.query_key_value(hidden_states)

         checkpoint_version = get_checkpoint_version()
         if checkpoint_version is not None and \
                 checkpoint_version == 0:
-            # [s, b, 3 * hp] --> [s, b, hp * 3]
+            # [sq, b, 3 * hp] --> [sq, b, hp * 3]
             mixed_x_layer = self._transpose_last_dim(mixed_x_layer)

-        # [s, b, hp * 3] --> [s, b, np, hn, 3]
+        # [sq, b, hp * 3] --> [sq, b, np, hn, 3]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
             (self.num_attention_heads_per_partition,
              self.hidden_size_per_attention_head, 3)
         mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

-        # [s, b, np, hn, 3] --> 3 [s, b, np, hn]
+        # [sq, b, np, hn, 3] --> 3 [sq, b, np, hn]
         query_layer = mixed_x_layer[:,:,:,:,0]
         key_layer = mixed_x_layer[:,:,:,:,1]
         value_layer = mixed_x_layer[:,:,:,:,2]
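
Note: for orientation, here is a minimal standalone sketch of the reshape-and-split step in the hunk above. It is not the Megatron module itself; the sizes (sq, b, heads per partition np, per-head size hn) are made up for illustration. In self-attention all three tensors start with the same length; sq and sk diverge only later, presumably when cached keys/values from layer_past are concatenated in the lines collapsed between the hunks.

import torch

sq, b, num_heads, head_dim = 8, 2, 4, 16   # illustrative sizes (np and hn in the comments above)

# Stand-in for the fused QKV projection output: [sq, b, np * hn * 3]
mixed_x_layer = torch.randn(sq, b, num_heads * head_dim * 3)

# [sq, b, hp * 3] --> [sq, b, np, hn, 3]
new_tensor_shape = mixed_x_layer.size()[:-1] + (num_heads, head_dim, 3)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

# [sq, b, np, hn, 3] --> 3 x [sq, b, np, hn]
query_layer = mixed_x_layer[:, :, :, :, 0]
key_layer = mixed_x_layer[:, :, :, :, 1]
value_layer = mixed_x_layer[:, :, :, :, 2]

assert query_layer.shape == (sq, b, num_heads, head_dim)
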
@@ -235,19 +235,19 @@ class ParallelSelfAttention(MegatronModule):

         # Raw attention scores. [b, np, s, s]
         # ===================================

-        # [b, np, s, s]
+        # [b, np, sq, sk]
         output_size = (query_layer.size(1),
                        query_layer.size(2),
                        query_layer.size(0),
                        key_layer.size(0))

-        # [s, b, np, hn] -> [s, b * np, hn]
+        # [sq, b, np, hn] -> [sq, b * np, hn]
         query_layer = query_layer.view(output_size[2],
                                        output_size[0] * output_size[1], -1)
         key_layer = key_layer.view(output_size[3],
                                    output_size[0] * output_size[1], -1)

-        # preallocting result tensor: [b * np, s, s]
+        # preallocting result tensor: [b * np, sq, sk]
         matmul_result = torch.empty(
             output_size[0]*output_size[1],
             output_size[2],
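
Note: this bookkeeping is where the two lengths are decoupled — the score shape [b, np, sq, sk] now takes sq from query_layer and sk from key_layer. A minimal sketch of the same reshapes with sq != sk (sizes made up; sk exceeds sq when, presumably, cached keys from layer_past have been concatenated in the collapsed lines above):

import torch

b, num_heads, head_dim = 2, 4, 16
sq, sk = 1, 9   # one new query position attending over nine key positions

query_layer = torch.randn(sq, b, num_heads, head_dim)   # [sq, b, np, hn]
key_layer = torch.randn(sk, b, num_heads, head_dim)     # [sk, b, np, hn]

# [b, np, sq, sk]
output_size = (query_layer.size(1), query_layer.size(2),
               query_layer.size(0), key_layer.size(0))

# [sq, b, np, hn] -> [sq, b * np, hn] and [sk, b, np, hn] -> [sk, b * np, hn]
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

assert query_layer.shape == (sq, b * num_heads, head_dim)
assert key_layer.shape == (sk, b * num_heads, head_dim)
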
@@ -255,18 +255,18 @@ class ParallelSelfAttention(MegatronModule):

             dtype=query_layer.dtype,
             device=torch.cuda.current_device())

-        # Raw attention scores. [b * np, s, s]
+        # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(matmul_result,
-            query_layer.transpose(0, 1),   # [b * np, s, hn]
-            key_layer.transpose(0, 1).transpose(1, 2),  #[b * np, hn, s]
+            query_layer.transpose(0, 1),   # [b * np, sq, hn]
+            key_layer.transpose(0, 1).transpose(1, 2),  #[b * np, hn, sk]
             beta=0.0, alpha=(1.0/self.norm_factor))

-        # change view to [b, np, s, s]
+        # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)

         # ==================================================
-        # Update attention mask for inference. [b, np, s, s]
+        # Update attention mask for inference. [b, np, sq, sk]
         # ==================================================

         if get_key_value:
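
Note: continuing the toy tensors from the previous sketch, the batched matmul now yields a rectangular [b * np, sq, sk] score tensor. norm_factor is assumed here to be sqrt(hn), the usual scaled-dot-product factor:

import math
import torch

b, num_heads, head_dim, sq, sk = 2, 4, 16, 1, 9
query_layer = torch.randn(sq, b * num_heads, head_dim)   # [sq, b * np, hn]
key_layer = torch.randn(sk, b * num_heads, head_dim)     # [sk, b * np, hn]
norm_factor = math.sqrt(head_dim)                        # assumed scaling factor

# Preallocated result buffer: [b * np, sq, sk]
matmul_result = torch.empty(b * num_heads, sq, sk)

# [b * np, sq, hn] @ [b * np, hn, sk] -> [b * np, sq, sk]; beta=0.0 ignores the buffer contents
matmul_result = torch.baddbmm(
    matmul_result,
    query_layer.transpose(0, 1),                  # [b * np, sq, hn]
    key_layer.transpose(0, 1).transpose(1, 2),    # [b * np, hn, sk]
    beta=0.0, alpha=(1.0 / norm_factor))

# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(b, num_heads, sq, sk)
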
@@ -287,7 +287,7 @@ class ParallelSelfAttention(MegatronModule):

         # Attention probs and dropout
         # ===========================

-        # attention scores and attention mask [b, np, s, s]
+        # attention scores and attention mask [b, np, sq, sk]
         attention_probs = self.scale_mask_softmax(attention_scores,
                                                   attention_mask)
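
Note: only the documented mask/score shape changes here; the call itself is untouched. As a plain-PyTorch stand-in for scale_mask_softmax (not Megatron's fused kernel; the True-means-masked convention and the -10000.0 fill value are assumptions), the softmax runs over the key dimension sk:

import torch

b, num_heads, sq, sk = 2, 4, 1, 9
attention_scores = torch.randn(b, num_heads, sq, sk)

# Boolean mask broadcastable to [b, np, sq, sk]; True marks positions to hide.
attention_mask = torch.zeros(b, 1, sq, sk, dtype=torch.bool)

masked_scores = attention_scores.masked_fill(attention_mask, -10000.0)
attention_probs = torch.softmax(masked_scores, dim=-1)   # normalize over sk
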
@@ -298,43 +298,43 @@ class ParallelSelfAttention(MegatronModule):

         # =========================
-        # Context layer. [s, b, hp]
+        # Context layer. [sq, b, hp]
         # =========================

         # value_layer -> context layer.
-        # [s, b, np, hn] --> [b, np, s, hn]
+        # [sk, b, np, hn] --> [b, np, sq, hn]

-        # context layer shape: [b, np, s, hn]
+        # context layer shape: [b, np, sq, hn]
         output_size = (value_layer.size(1),
                        value_layer.size(2),
-                       value_layer.size(0),
+                       query_layer.size(0),
                        value_layer.size(3))

-        # change view [s, b * np, hn]
-        value_layer = value_layer.view(output_size[2],
+        # change view [sk, b * np, hn]
+        value_layer = value_layer.view(value_layer.size(0),
                                        output_size[0] * output_size[1], -1)

-        # change view [b * np, s, s]
+        # change view [b * np, sq, sk]
         attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                                output_size[2], -1)

-        # matmul: [b * np, s, hn]
+        # matmul: [b * np, sq, hn]
         context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

-        # change view [b, np, s, hn]
+        # change view [b, np, sq, hn]
         context_layer = context_layer.view(*output_size)

-        # [b, np, s, hn] --> [s, b, np, hn]
+        # [b, np, sq, hn] --> [sq, b, np, hn]
         context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

-        # [s, b, np, hn] --> [s, b, hp]
+        # [sq, b, np, hn] --> [sq, b, hp]
         new_context_layer_shape = context_layer.size()[:-2] + \
             (self.hidden_size_per_partition,)
         context_layer = context_layer.view(*new_context_layer_shape)

         # =================
-        # Output. [s, b, h]
+        # Output. [sq, b, h]
         # =================

         output, bias = self.dense(context_layer)
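
Note: the two genuine code changes in the commit sit in this hunk — output_size now takes its sequence entry from query_layer.size(0) (sq), and value_layer is flattened with its own length value_layer.size(0) (sk) instead of output_size[2]; with unequal lengths the old indices would mix the two up. A toy shape check, with the same made-up sizes as in the earlier sketches:

import torch

b, num_heads, head_dim, sq, sk = 2, 4, 16, 1, 9
query_layer = torch.randn(sq, b, num_heads, head_dim)
value_layer = torch.randn(sk, b, num_heads, head_dim)
attention_probs = torch.randn(b, num_heads, sq, sk).softmax(dim=-1)

# [b, np, sq, hn]: the sequence entry is the *query* length.
output_size = (value_layer.size(1), value_layer.size(2),
               query_layer.size(0), value_layer.size(3))

# [sk, b * np, hn]: flatten the values with their own length.
value_layer = value_layer.view(value_layer.size(0),
                               output_size[0] * output_size[1], -1)

# [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                       output_size[2], -1)

# [b * np, sq, sk] @ [b * np, sk, hn] -> [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
context_layer = context_layer.view(*output_size)              # [b, np, sq, hn]

# [b, np, sq, hn] --> [sq, b, np * hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
context_layer = context_layer.view(sq, b, num_heads * head_dim)
assert context_layer.shape == (sq, b, num_heads * head_dim)
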