OpenDAS / Megatron-LM · Commits

Commit ff2f0a05: "further refactoring"
authored Sep 29, 2021 by mshoeybi
parent 390ddef8

Showing 2 changed files with 71 additions and 22 deletions:

    megatron/inference/forward_step.py    +66  -8
    megatron/inference/generation.py       +5  -14
megatron/inference/forward_step.py  (view file @ ff2f0a05)
@@ -15,15 +15,20 @@
"""Forward step utilities."""

import torch

from collections.abc import Iterable
from enum import Enum

from megatron.p2p_communication import recv_forward, send_forward
from megatron import get_args


class ForwardStepTypes(Enum):
    NO_PIPELINING = 1


class InferenceParams:

    def __init__(self, micro_batch_size_list, max_sequence_len):

        assert isinstance(micro_batch_size_list, list)
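As a quick usage note (not part of the commit), the assertion above means the micro batch sizes must always be passed as a list, even when there is only one; a minimal sketch, assuming the Megatron-LM package at this revision is importable:

# Usage sketch only; assumes megatron.inference.forward_step is on the Python path.
from megatron.inference.forward_step import InferenceParams

params = InferenceParams([8], 1024)    # OK: a list of micro batch sizes
# InferenceParams(8, 1024)             # would fail the isinstance(..., list) assert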
@@ -31,10 +36,67 @@ class InferenceParams:

        self.micro_batch_size_list = micro_batch_size_list
        self.max_sequence_len = max_sequence_len
-       self.allocate_key_value_memory = False
+       self.allocate_key_value_memory = True
        self.micro_batch_size_index = 0


class InferenceForwardStep:

    def __init__(self, model, batch_size, max_sequence_len):

        if isinstance(model, Iterable):
            for this_model in model:
                this_model.eval()
        else:
            model.eval()
        self.model = model

        self.inference_params = InferenceParams([batch_size], max_sequence_len)
        self.forward_step_type = ForwardStepTypes.NO_PIPELINING


    def __call__(self, tokens, position_ids, attention_mask):

        if self.forward_step_type == ForwardStepTypes.NO_PIPELINING:
            return self._forward_step_no_pipelining(tokens, position_ids,
                                                    attention_mask)

        raise Exception('unknown forward step type {}'.format(
            self.forward_step_type))


    def _forward_step_no_pipelining(self, tokens, position_ids, attention_mask):

        # Need to tell p2p_communicate functions the correct size.
        args = get_args()
        orig_seq_length = args.seq_length
        args.seq_length = tokens.shape[1]
        assert args.seq_length <= self.inference_params.max_sequence_len
        args.micro_batch_size = tokens.shape[0]
        assert self.inference_params.micro_batch_size_list[0] == tokens.shape[0]
        assert self.inference_params.micro_batch_size_index == 0

        # Receive from previous stage.
        input_tensor = recv_forward()

        # Forward pass through the model.
        self.model.set_input_tensor(input_tensor)
        output_tensor = self.model(tokens, position_ids, attention_mask,
                                   inference_params=self.inference_params)

        # Send output to the next stage.
        send_forward(output_tensor)

        # Reset the sequence length to whatever it was before.
        args.seq_length = orig_seq_length

        # Make sure we do not allocate context memory anymore.
        if self.inference_params.allocate_key_value_memory:
            self.inference_params.allocate_key_value_memory = False

        return output_tensor


def forward_step(model, tokens, position_ids, attention_mask, inference_params):

    # Hidden size changes when not using recompute, need to tell p2p_communicate
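To show the calling pattern the new InferenceForwardStep class establishes, here is a minimal, self-contained sketch. The model is a stand-in torch module, and the pipeline communication (recv_forward/send_forward) and global-args bookkeeping are omitted since they require a running Megatron job; only the construct-once, call-per-step interface and the key/value-cache flag behavior are mirrored.

# Hypothetical stand-ins, not the commit's code; mirrors the interface only.
import torch


class ToyModel(torch.nn.Module):
    """Stand-in for a Megatron model: token ids -> fake logits."""

    def __init__(self, vocab_size=16, hidden=8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.head = torch.nn.Linear(hidden, vocab_size)

    def forward(self, tokens, position_ids, attention_mask, inference_params=None):
        # A real Megatron model also reads/writes the key/value cache held by
        # inference_params; the toy ignores it.
        return self.head(self.embed(tokens))


class ToyInferenceForwardStep:
    """Mimics InferenceForwardStep: construct once, then call per decode step."""

    def __init__(self, model, batch_size, max_sequence_len):
        model.eval()                              # real class also handles a list of model chunks
        self.model = model
        self.max_sequence_len = max_sequence_len
        self.allocate_key_value_memory = True     # mirrors InferenceParams

    def __call__(self, tokens, position_ids, attention_mask):
        assert tokens.shape[1] <= self.max_sequence_len
        with torch.no_grad():
            output = self.model(tokens, position_ids, attention_mask)
        # After the first step the cache must not be reallocated.
        self.allocate_key_value_memory = False
        return output


forward_step = ToyInferenceForwardStep(ToyModel(), batch_size=2, max_sequence_len=32)
tokens = torch.randint(0, 16, (2, 4))
position_ids = torch.arange(4).unsqueeze(0).expand(2, -1)
attention_mask = torch.ones(2, 1, 4, 4, dtype=torch.bool)
logits = forward_step(tokens, position_ids, attention_mask)   # shape (2, 4, 16)
assert forward_step.allocate_key_value_memory is False

The real class additionally dispatches on ForwardStepTypes and saves/restores args.seq_length and args.micro_batch_size around the p2p calls, as the hunk above shows.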
@@ -56,7 +118,3 @@ def forward_step(model, tokens, position_ids, attention_mask, inference_params):

    args.seq_length = orig_seq_length

    return output_tensor
megatron/inference/generation.py  (view file @ ff2f0a05)
@@ -15,7 +15,6 @@
"""Generation utilities."""

import torch
import torch.nn.functional as F

@@ -25,7 +24,7 @@ from .communication import (
    copy_from_last_to_first_pipeline_stage,
    broadcast_from_last_pipeline_stage,
    broadcast_from_last_to_first_pipeline_stage)
-from .forward_step import forward_step, InferenceParams
+from .forward_step import InferenceForwardStep
from .sampling import sample
@@ -66,6 +65,9 @@ def generate_tokens_probs_and_return_on_first_stage(
    max_sequence_length = tokens.size(1)
    max_sequence_length = min(max_sequence_length, args.max_position_embeddings)

+   # forward step.
+   forward_step = InferenceForwardStep(model, batch_size, max_sequence_length)
+
    # Added termination_id to support the case that we want to terminate the
    # generation once that id is generated.
    if hasattr(args, 'eos_id'):
@@ -109,20 +111,10 @@ def generate_tokens_probs_and_return_on_first_stage(
    attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)

-   # Set inference params
-   inference_params = InferenceParams([batch_size], max_sequence_length)
-
-   model.eval()
    with torch.no_grad():
        prev_context_length = 0
        for context_length in range(min_prompt_length, max_sequence_length):

-           # If we are starting from scratch, allocate memory for the entire
-           # context, otherwise set this to false so the memory is not
-           # reallocated.
-           inference_params.allocate_key_value_memory = \
-               (prev_context_length == 0)

            # Pick the slice that we need to pass through the network.
            tokens2use = tokens[:, prev_context_length:context_length]
            positions2use = position_ids[:, prev_context_length:context_length]
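For readers unfamiliar with the incremental decoding loop, a tiny illustration (not from the diff) of how these slices advance: the prompt is consumed in one chunk, then one token per step; prev_context_length is updated later in the loop, outside this hunk.

# Illustrative shapes only; the values and lengths here are made up.
import torch

tokens = torch.arange(12).reshape(2, 6)       # batch_size=2, max_sequence_length=6
min_prompt_length, max_sequence_length = 3, 6

prev_context_length = 0
for context_length in range(min_prompt_length, max_sequence_length):
    tokens2use = tokens[:, prev_context_length:context_length]
    print(context_length, tuple(tokens2use.shape))
    # -> 3 (2, 3)   first step: the whole prompt
    # -> 4 (2, 1)   afterwards: just the newly generated position
    # -> 5 (2, 1)
    prev_context_length = context_length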
@@ -130,8 +122,7 @@ def generate_tokens_probs_and_return_on_first_stage(
                ..., prev_context_length:context_length, :context_length]

            # logits will be meaningful only in the last pipeline stage.
-           logits = forward_step(model, tokens2use, positions2use,
-                                 attention_mask2use, inference_params)
+           logits = forward_step(tokens2use, positions2use, attention_mask2use)

            if mpu.is_pipeline_last_stage():
                # Always the last stage should have an output.
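Putting the hunks together: the refactored call site no longer passes the model or the InferenceParams on every step; both are bound into the forward_step object created earlier in the function. A self-contained sketch of the resulting loop shape, with a stand-in model and greedy selection in place of Megatron's sample(), and with pipeline communication omitted:

# Rough, hypothetical shape of the refactored loop; not the commit's code.
import torch

vocab_size, batch_size = 16, 2
max_sequence_length, min_prompt_length = 8, 3

model = torch.nn.Sequential(torch.nn.Embedding(vocab_size, 8),
                            torch.nn.Linear(8, vocab_size))
model.eval()

def forward_step(tokens2use, positions2use, attention_mask2use):
    # Stand-in for InferenceForwardStep.__call__: the model and its inference
    # state are captured here instead of being passed in on every call.
    return model(tokens2use)

tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length))
position_ids = torch.arange(max_sequence_length).expand(batch_size, -1)

with torch.no_grad():
    prev_context_length = 0
    for context_length in range(min_prompt_length, max_sequence_length):
        tokens2use = tokens[:, prev_context_length:context_length]
        positions2use = position_ids[:, prev_context_length:context_length]
        # New call signature: no model, no inference_params.
        logits = forward_step(tokens2use, positions2use, None)
        tokens[:, context_length] = logits[:, -1, :].argmax(dim=-1)   # greedy stand-in
        prev_context_length = context_length

Everything removed from this loop by the commit (the InferenceParams construction, model.eval(), and the allocate_key_value_memory toggle) now happens inside InferenceForwardStep, per the forward_step.py hunks above.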