OpenDAS / Megatron-LM / Commits / efc750b6

Commit efc750b6 authored Oct 29, 2021 by mshoeybi

made model stateless with respect to inference

parent d33460df
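In short, this commit moves the pre-allocated key/value inference cache out of each ParallelAttention instance (the self.inference_key_memory and self.inference_value_memory attributes) and into the caller-owned inference_params.key_value_memory_dict, keyed by self.layer_number, so the model modules no longer carry inference state between generation calls.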
Showing 2 changed files with 17 additions and 29 deletions (+17 −29):

  megatron/model/transformer.py            +16 −23
  megatron/text_generation/forward_step.py  +1  −6
megatron/model/transformer.py  View file @ efc750b6
@@ -179,10 +179,6 @@ class ParallelAttention(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)
 
-        # Inference key-value memory
-        self.inference_key_memory = None
-        self.inference_value_memory = None
-
     def _allocate_memory(self, inference_max_sequence_len, batch_size):
         return torch.empty(
@@ -203,19 +199,18 @@ class ParallelAttention(MegatronModule):
         # Pre-allocate memory for key-values for inference.
         # =================================================
         if inference_params:
-            if inference_params.allocate_key_value_memory:
+            if self.layer_number not in inference_params.key_value_memory_dict:
                 inf_max_seq_len = inference_params.max_sequence_len
                 inf_max_batch_size = inference_params.max_batch_size
-                self.inference_key_memory = self._allocate_memory(
+                inference_key_memory = self._allocate_memory(
                     inf_max_seq_len, inf_max_batch_size)
-                self.inference_value_memory = self._allocate_memory(
+                inference_value_memory = self._allocate_memory(
                     inf_max_seq_len, inf_max_batch_size)
-        # This is added for safety. In case inference_params
-        # is not provided, make sure there is no potential memory left
-        # from previous inference.
-        else:
-            self.inference_key_memory = None
-            self.inference_value_memory = None
+                inference_params.key_value_memory_dict[self.layer_number] = (
+                    inference_key_memory, inference_value_memory)
+            else:
+                inference_key_memory, inference_value_memory = \
+                    inference_params.key_value_memory_dict[self.layer_number]
 
         # =====================
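Restated outside the diff, the allocation hunk above amounts to a lazy, per-layer cache lookup. The sketch below is illustrative only: get_layer_kv_cache and allocate_fn are hypothetical names, and the diff elides the buffer shape and dtype that the real _allocate_memory passes to torch.empty.

    def get_layer_kv_cache(inference_params, layer_number, allocate_fn):
        """Hypothetical helper mirroring the hunk above: allocate each
        layer's key/value buffers once, on first use, then reuse them."""
        if layer_number not in inference_params.key_value_memory_dict:
            key_mem = allocate_fn(inference_params.max_sequence_len,
                                  inference_params.max_batch_size)
            value_mem = allocate_fn(inference_params.max_sequence_len,
                                    inference_params.max_batch_size)
            inference_params.key_value_memory_dict[layer_number] = (
                key_mem, value_mem)
        return inference_params.key_value_memory_dict[layer_number]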
@@ -266,20 +261,18 @@ class ParallelAttention(MegatronModule):
         if inference_params:
             batch_start = inference_params.batch_size_offset
             batch_end = batch_start + key_layer.size(1)
-            assert batch_end <= self.inference_key_memory.size(1)
+            assert batch_end <= inference_key_memory.size(1)
             sequence_start = inference_params.sequence_len_offset
             sequence_end = sequence_start + key_layer.size(0)
-            assert sequence_end <= self.inference_key_memory.size(0)
+            assert sequence_end <= inference_key_memory.size(0)
             # Copy key and values.
-            self.inference_key_memory[sequence_start:sequence_end,
-                                      batch_start:batch_end, ...] = key_layer
-            self.inference_value_memory[sequence_start:sequence_end,
-                                        batch_start:batch_end, ...] = value_layer
-            key_layer = self.inference_key_memory[
+            inference_key_memory[sequence_start:sequence_end,
+                                 batch_start:batch_end, ...] = key_layer
+            inference_value_memory[sequence_start:sequence_end,
+                                   batch_start:batch_end, ...] = value_layer
+            key_layer = inference_key_memory[
                 :sequence_end, batch_start:batch_end, ...]
-            value_layer = self.inference_value_memory[
+            value_layer = inference_value_memory[
                 :sequence_end, batch_start:batch_end, ...]
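To make the copy-and-slice step concrete, here is a self-contained illustration. The [sequence, batch, ...] memory layout and the offset fields come from the diff; the concrete shapes and the single-token decode step are made-up placeholders.

    import torch

    # Stand-ins for the pre-allocated buffers: [max_seq, max_batch, heads, dim].
    max_seq, max_batch, heads, dim = 8, 2, 4, 16
    inference_key_memory = torch.zeros(max_seq, max_batch, heads, dim)
    inference_value_memory = torch.zeros(max_seq, max_batch, heads, dim)

    # One decode step: a single new token's keys/values for the whole batch.
    key_layer = torch.randn(1, max_batch, heads, dim)
    value_layer = torch.randn(1, max_batch, heads, dim)
    sequence_len_offset, batch_size_offset = 3, 0  # 3 tokens already cached

    batch_start = batch_size_offset
    batch_end = batch_start + key_layer.size(1)
    sequence_start = sequence_len_offset
    sequence_end = sequence_start + key_layer.size(0)

    # Write the new entries at the current offsets...
    inference_key_memory[sequence_start:sequence_end,
                         batch_start:batch_end, ...] = key_layer
    inference_value_memory[sequence_start:sequence_end,
                           batch_start:batch_end, ...] = value_layer
    # ...then attend over everything cached so far.
    key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
    value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
    assert key_layer.size(0) == sequence_end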
megatron/text_generation/forward_step.py  View file @ efc750b6
@@ -40,7 +40,7 @@ class InferenceParams:
         self.max_batch_size = max_batch_size
         self.sequence_len_offset = 0
         self.batch_size_offset = 0
-        self.allocate_key_value_memory = True
+        self.key_value_memory_dict = {}
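Pieced together from this hunk, the post-change InferenceParams initializer plausibly reads as below. The constructor signature and the max_sequence_len assignment are inferred (max_sequence_len is read by transformer.py above), not shown in the diff.

    class InferenceParams:
        """Sketch of the inference-state container after this commit
        (constructor signature assumed, not shown in the diff)."""

        def __init__(self, max_batch_size, max_sequence_len):
            self.max_sequence_len = max_sequence_len
            self.max_batch_size = max_batch_size
            self.sequence_len_offset = 0
            self.batch_size_offset = 0
            # Replaces the old allocate_key_value_memory flag: each attention
            # layer inserts its (key, value) buffers here, keyed by layer number.
            self.key_value_memory_dict = {}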
@@ -132,11 +132,6 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
     # Send output to the next stage.
     send_to_next_pipeline_rank(output_tensor)
 
-    # Make sure we do not allocate context memory anymore.
-    if inference_params.allocate_key_value_memory:
-        inference_params.allocate_key_value_memory = False
-
     return output_tensor
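Because the dict-membership check in transformer.py now decides when to allocate, the allocate_key_value_memory flag, and this step that flipped it to False after the first forward pass, no longer serve any purpose. One practical consequence of the statelessness, sketched under the assumed constructor above:

    # Fresh inference state per generation request; per-layer KV buffers
    # accumulate in key_value_memory_dict and are freed with the object.
    inference_params = InferenceParams(max_batch_size=4, max_sequence_len=1024)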