Commit 9be97c01
authored May 23, 2025 by dongcl

update ParallelAttention to core v0.12.0

parent 698bfd4d

Showing 1 changed file with 18 additions and 13 deletions:

dcu_megatron/legacy/model/transformer.py  (+18, -13)
@@ -5,6 +5,7 @@ from functools import wraps
 from megatron.training import get_args
 from megatron.core import tensor_parallel
+from megatron.core.utils import deprecate_inference_params
 from megatron.legacy.model.enums import AttnType
 from megatron.core.models.common.embeddings import apply_rotary_pos_emb
 from megatron.legacy.model.module import MegatronModule
@@ -92,20 +93,23 @@ class ParallelAttentionPatch(MegatronModule):
     Self-attention layer takes input with size [s, b, h]
     and returns output of the same size.
     """
     # query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head) =>
     # query_layer = query_layer.contiguous().view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
     def forward(self, hidden_states, attention_mask,
-                encoder_output=None, inference_params=None,
-                rotary_pos_emb=None):
+                encoder_output=None, inference_context=None,
+                rotary_pos_emb=None, *, inference_params=None):
         # hidden_states: [sq, b, h]
+        inference_context = deprecate_inference_params(inference_context, inference_params)

         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================
         is_first_step = False
-        if inference_params:
-            if self.layer_number not in inference_params.key_value_memory_dict:
-                inf_max_seq_len = inference_params.max_sequence_length
-                inf_max_batch_size = inference_params.max_batch_size
+        if inference_context:
+            if self.layer_number not in inference_context.key_value_memory_dict:
+                inf_max_seq_len = inference_context.max_sequence_length
+                inf_max_batch_size = inference_context.max_batch_size
                 inference_key_memory = self._allocate_memory(
                     inf_max_seq_len, inf_max_batch_size,
                     self.num_query_groups_per_partition)
@@ -113,12 +117,12 @@ class ParallelAttentionPatch(MegatronModule):
                     inf_max_seq_len, inf_max_batch_size,
                     self.num_query_groups_per_partition)
-                inference_params.key_value_memory_dict[self.layer_number] = (
+                inference_context.key_value_memory_dict[self.layer_number] = (
                     inference_key_memory, inference_value_memory)
                 is_first_step = True
             else:
                 inference_key_memory, inference_value_memory = \
-                    inference_params.key_value_memory_dict[self.layer_number]
+                    inference_context.key_value_memory_dict[self.layer_number]

         # =====================
         # Query, Key, and Value
@@ -188,13 +192,14 @@ class ParallelAttentionPatch(MegatronModule):
         else:
             rotary_pos_emb = ((rotary_pos_emb,) * 2)

-        if inference_params:
-            batch_start = inference_params.batch_size_offset
+        if inference_context:
+            batch_start = inference_context.batch_size_offset
             batch_end = batch_start + key_layer.size(1)
             assert batch_end <= inference_key_memory.size(1)
-            sequence_start = inference_params.sequence_len_offset
+            sequence_start = inference_context.sequence_len_offset
             sequence_end = sequence_start + key_layer.size(0)
-            assert sequence_end <= inference_key_memory.size(0)
+            assert sequence_end <= inference_key_memory.size(0), (
+                "Current sequence length is "
+                "longer than expected maximum sequence length! Increase inference_max_seq_length."
+            )
             # Copy key and values.
             inference_key_memory[sequence_start:sequence_end,
                                  batch_start:batch_end, ...] = key_layer
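For reference, here is a minimal sketch of the deprecation-shim pattern this commit leans on. The body below is an assumption for illustration only; the real megatron.core.utils.deprecate_inference_params may differ in warning text and edge-case handling. The point is that the patched forward keeps accepting the legacy keyword-only inference_params argument while new callers pass inference_context.

import warnings

def deprecate_inference_params(inference_context, inference_params):
    # Illustrative sketch only -- not the actual Megatron-LM implementation.
    # If a caller still passes the legacy keyword, warn and fall back to it;
    # otherwise hand the new inference_context object through unchanged.
    if inference_params is not None and inference_context is None:
        warnings.warn(
            "'inference_params' is deprecated; pass 'inference_context' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return inference_params
    return inference_context

Because inference_params now sits after the bare * in the new signature, existing keyword call sites such as attn(hidden_states, attention_mask, inference_params=params) keep working unchanged, while the body of the method only ever reads from inference_context.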