OpenDAS / Megatron-LM · Commits

Commit 390ddef8
Authored Sep 29, 2021 by mshoeybi
Parent: b6555b71

    added multi-batch inference

Showing 5 changed files with 76 additions and 72 deletions (+76 -72):

    megatron/inference/forward_step.py    +16  -7
    megatron/inference/generation.py       +8  -6
    megatron/model/gpt_model.py            +2  -5
    megatron/model/language_model.py       +3  -6
    megatron/model/transformer.py         +47 -48
megatron/inference/forward_step.py

@@ -22,9 +22,20 @@ from megatron.p2p_communication import recv_forward, send_forward
 from megatron import get_args
 
 
-def forward_step(model, tokens, position_ids, attention_mask,
-                 set_inference_key_value_memory=False,
-                 inference_max_sequence_len=None):
+class InferenceParams:
+
+    def __init__(self, micro_batch_size_list, max_sequence_len):
+        assert isinstance(micro_batch_size_list, list)
+        assert max_sequence_len > 0
+        self.micro_batch_size_list = micro_batch_size_list
+        self.max_sequence_len = max_sequence_len
+        self.allocate_key_value_memory = False
+        self.micro_batch_size_index = 0
+
+
+def forward_step(model, tokens, position_ids, attention_mask,
+                 inference_params):
 
     # Hidden size changes when not using recompute, need to tell p2p_communicate
     # functions the correct size
 
@@ -37,10 +48,8 @@ def forward_step(model, tokens, position_ids, attention_mask,
 
     # Forward pass through the model.
     model.set_input_tensor(input_tensor)
-    output_tensor = model(tokens, position_ids, attention_mask,
-                          set_inference_key_value_memory=set_inference_key_value_memory,
-                          inference_max_sequence_len=inference_max_sequence_len)
+    output_tensor = model(tokens, position_ids, attention_mask,
+                          inference_params=inference_params)
 
     send_forward(output_tensor)
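The commit replaces the old set_inference_key_value_memory / inference_max_sequence_len keyword arguments with a single InferenceParams object that is threaded through every forward call. A minimal, hedged sketch of how a caller now drives forward_step (the model, token, and mask tensors are placeholders, not part of this commit):

    from megatron.inference.forward_step import InferenceParams, forward_step

    # One pre-allocated key-value cache per listed micro batch size
    # (assumed example values).
    inference_params = InferenceParams(micro_batch_size_list=[4],
                                       max_sequence_len=512)

    # Flip to True for the first forward pass so each attention layer
    # allocates its key-value buffers; leave it False afterwards.
    inference_params.allocate_key_value_memory = True
    # logits = forward_step(model, tokens, position_ids, attention_mask,
    #                       inference_params)
    inference_params.allocate_key_value_memory = False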
megatron/inference/generation.py

@@ -25,7 +25,7 @@ from .communication import (
     copy_from_last_to_first_pipeline_stage,
     broadcast_from_last_pipeline_stage,
     broadcast_from_last_to_first_pipeline_stage)
-from .forward_step import forward_step
+from .forward_step import forward_step, InferenceParams
 from .sampling import sample
 
@@ -109,6 +109,9 @@ def generate_tokens_probs_and_return_on_first_stage(
     attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
 
+    # Set inference params
+    inference_params = InferenceParams([batch_size], max_sequence_length)
+
     model.eval()
     with torch.no_grad():
         prev_context_length = 0
 
@@ -117,7 +120,8 @@ def generate_tokens_probs_and_return_on_first_stage(
             # If we are starting from scratch, allocate memory for the entire
             # context, otherwise set this to false so the memory is not
             # reallocated.
-            set_inference_key_value_memory = (prev_context_length == 0)
+            inference_params.allocate_key_value_memory = \
+                (prev_context_length == 0)
 
             # Pick the slice that we need to pass through the network.
             tokens2use = tokens[:, prev_context_length:context_length]
 
@@ -126,10 +130,8 @@ def generate_tokens_probs_and_return_on_first_stage(
                 ..., prev_context_length:context_length, :context_length]
 
             # logits will be meaningful only in the last pipeline stage.
-            logits = forward_step(model, tokens2use, positions2use,
-                                  attention_mask2use,
-                                  set_inference_key_value_memory=set_inference_key_value_memory,
-                                  inference_max_sequence_len=max_sequence_length)
+            logits = forward_step(model, tokens2use, positions2use,
                                   attention_mask2use, inference_params)
 
             if mpu.is_pipeline_last_stage():
                 # Always the last stage should have an output.
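For context, the surrounding generation routine pushes the whole prompt through the network on its first iteration and a single new token on each later one, so key-value memory is only (re)allocated when prev_context_length is zero. A simplified sketch of that loop shape, with sampling and pipeline communication omitted; the helper name decode_loop and its arguments are illustrative, not from the commit:

    import torch
    from megatron.inference.forward_step import forward_step

    def decode_loop(model, tokens, position_ids, attention_mask,
                    inference_params, prompt_length, max_sequence_length):
        prev_context_length = 0
        with torch.no_grad():
            for context_length in range(prompt_length, max_sequence_length):
                # Allocate key-value memory only when starting from scratch.
                inference_params.allocate_key_value_memory = \
                    (prev_context_length == 0)
                # Only the not-yet-processed slice goes through the network.
                tokens2use = tokens[:, prev_context_length:context_length]
                positions2use = position_ids[:,
                                             prev_context_length:context_length]
                mask2use = attention_mask[
                    ..., prev_context_length:context_length, :context_length]
                logits = forward_step(model, tokens2use, positions2use,
                                      mask2use, inference_params)
                # ... sample the next token from logits, write it into
                # tokens[:, context_length], then advance ...
                prev_context_length = context_length
        return tokens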
megatron/model/gpt_model.py

@@ -82,16 +82,13 @@ class GPTModel(MegatronModule):
         self.language_model.set_input_tensor(input_tensor)
 
     def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                tokentype_ids=None, inference_params=None):
 
         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
-            set_inference_key_value_memory=set_inference_key_value_memory,
-            inference_max_sequence_len=inference_max_sequence_len)
+            inference_params=inference_params)
 
         if self.post_process:
             return post_language_model_processing(
megatron/model/language_model.py

@@ -335,8 +335,7 @@ class TransformerLanguageModel(MegatronModule):
     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None,
+                inference_params=None,
                 pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
 
@@ -353,8 +352,7 @@ class TransformerLanguageModel(MegatronModule):
             encoder_output = self.encoder(
                 encoder_input,
                 enc_attn_mask,
-                set_inference_key_value_memory=set_inference_key_value_memory,
-                inference_max_sequence_len=inference_max_sequence_len)
+                inference_params=inference_params)
         else:
             encoder_output = enc_hidden_states.to(encoder_input.dtype)
 
@@ -381,8 +379,7 @@ class TransformerLanguageModel(MegatronModule):
             dec_attn_mask,
             encoder_output=encoder_output,
             enc_dec_attn_mask=enc_dec_attn_mask,
-            set_inference_key_value_memory=set_inference_key_value_memory,
-            inference_max_sequence_len=inference_max_sequence_len)
+            inference_params=inference_params)
 
         if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output
megatron/model/transformer.py

@@ -180,9 +180,9 @@ class ParallelAttention(MegatronModule):
             skip_bias_add=True)
 
         # Inference key-value memory
-        self.inference_key_memory = None
-        self.inference_value_memory = None
-        self.inference_current_sequence_len = 0
+        self.inference_key_memory_list = None
+        self.inference_value_memory_list = None
+        self.inference_current_sequence_len_list = None
 
     def _allocate_memory(self, inference_max_sequence_len, batch_size):
@@ -196,35 +196,32 @@ class ParallelAttention(MegatronModule):
 
     def forward(self, hidden_states, attention_mask,
-                encoder_output=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                encoder_output=None, inference_params=None):
         # hidden_states: [sq, b, h]
 
         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================
-        if set_inference_key_value_memory:
-            assert inference_max_sequence_len and inference_max_sequence_len > 0
-            self.inference_key_memory = self._allocate_memory(
-                inference_max_sequence_len, hidden_states.size(1))
-            self.inference_value_memory = self._allocate_memory(
-                inference_max_sequence_len, hidden_states.size(1))
-            self.inference_current_sequence_len = 0
-        # Some consistency check.
-        if inference_max_sequence_len:
-            assert self.inference_current_sequence_len < \
-                self.inference_key_memory.size(0)
-            assert inference_max_sequence_len == \
-                self.inference_key_memory.size(0)
-        # This is added for safety. In case inference_max_sequence_len
+        if inference_params:
+            if inference_params.allocate_key_value_memory:
+                inf_max_seq_len = inference_params.max_sequence_len
+                inf_batch_sizes = inference_params.micro_batch_size_list
+                self.inference_key_memory_list = [
+                    self._allocate_memory(inf_max_seq_len, inf_batch_size)
+                    for inf_batch_size in inf_batch_sizes]
+                self.inference_value_memory_list = [
+                    self._allocate_memory(inf_max_seq_len, inf_batch_size)
+                    for inf_batch_size in inf_batch_sizes]
+                self.inference_current_sequence_len_list = [
+                    0 for _ in inf_batch_sizes]
+        # This is added for safety. In case inference_params
         # is not provided, make sure there is no potential memory left
         # from previous inference.
-        if not inference_max_sequence_len:
-            self.inference_key_memory = None
-            self.inference_value_memory = None
+        else:
+            self.inference_key_memory_list = None
+            self.inference_value_memory_list = None
+            self.inference_current_sequence_len_list = None
 
         # =====================
         # Query, Key, and Value
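With this change, each ParallelAttention instance keeps one key buffer, one value buffer, and one sequence-length counter per entry in micro_batch_size_list. A standalone sketch of that bookkeeping; the [max_seq_len, batch, heads, head_dim] layout and the dtype/device arguments are assumptions standing in for what _allocate_memory does with the module's own projection sizes and parallel configuration:

    import torch

    def allocate_kv_cache_lists(micro_batch_size_list, max_sequence_len,
                                num_heads, head_dim,
                                dtype=torch.float16, device='cuda'):
        # One pre-allocated buffer per micro batch size, mirroring
        # inference_key_memory_list / inference_value_memory_list.
        key_list = [torch.empty(max_sequence_len, b, num_heads, head_dim,
                                dtype=dtype, device=device)
                    for b in micro_batch_size_list]
        value_list = [torch.empty(max_sequence_len, b, num_heads, head_dim,
                                  dtype=dtype, device=device)
                      for b in micro_batch_size_list]
        # Number of tokens cached so far for each micro batch.
        current_len_list = [0 for _ in micro_batch_size_list]
        return key_list, value_list, current_len_list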
@@ -267,20 +264,27 @@ class ParallelAttention(MegatronModule):
             query_layer = query_layer.view(*new_tensor_shape)
 
-        # ====================================================
-        # Adjust key, value, and attention mask for inference
-        # ====================================================
+        # ==================================
+        # Adjust key and value for inference
+        # ==================================
 
-        if inference_max_sequence_len:
+        if inference_params:
+            inf_batch_index = inference_params.micro_batch_size_index
+            assert key_layer.size(1) == \
+                inference_params.micro_batch_size_list[inf_batch_index]
             # Adjust the range variables.
-            start = self.inference_current_sequence_len
-            self.inference_current_sequence_len += key_layer.size(0)
-            end = self.inference_current_sequence_len
+            start = self.inference_current_sequence_len_list[inf_batch_index]
+            end = start + key_layer.size(0)
+            self.inference_current_sequence_len_list[inf_batch_index] = end
             # Copy key and values.
-            self.inference_key_memory[start:end, ...] = key_layer
-            self.inference_value_memory[start:end, ...] = value_layer
-            key_layer = self.inference_key_memory[:end, ...]
-            value_layer = self.inference_value_memory[:end, ...]
+            self.inference_key_memory_list[inf_batch_index][start:end, ...] = \
+                key_layer
+            self.inference_value_memory_list[inf_batch_index][start:end, ...] = \
+                value_layer
+            key_layer = \
+                self.inference_key_memory_list[inf_batch_index][:end, ...]
+            value_layer = \
+                self.inference_value_memory_list[inf_batch_index][:end, ...]
 
         # ===================================
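The per-step update picks the buffer for the active micro batch, writes the freshly computed key/value slice into rows [start:end], and hands attention the full [:end] prefix. A minimal sketch of that update for a single micro batch (buffer shapes follow the [seq, batch, heads, head_dim] convention above; the function name is illustrative):

    import torch

    def append_to_kv_cache(key_cache, value_cache, current_len,
                           key_layer, value_layer):
        # key_layer / value_layer hold only the tokens processed this step.
        start = current_len
        end = start + key_layer.size(0)
        key_cache[start:end, ...] = key_layer
        value_cache[start:end, ...] = value_layer
        # Attention then runs over everything cached so far.
        return key_cache[:end, ...], value_cache[:end, ...], end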
@@ -459,10 +463,8 @@ class ParallelTransformerLayer(MegatronModule):
             output_layer_init_method)
 
     def forward(self, hidden_states, attention_mask,
-                encoder_output=None, enc_dec_attn_mask=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                encoder_output=None, enc_dec_attn_mask=None,
+                inference_params=None):
         # hidden_states: [b, s, h]
 
         # Layer norm at the beginning of the transformer layer.
@@ -472,8 +474,7 @@ class ParallelTransformerLayer(MegatronModule):
             self.self_attention(layernorm_output,
                                 attention_mask,
-                                set_inference_key_value_memory=set_inference_key_value_memory,
-                                inference_max_sequence_len=inference_max_sequence_len)
+                                inference_params=inference_params)
 
         # Residual connection.
         if self.apply_residual_connection_post_layernorm:
@@ -686,13 +687,11 @@ class ParallelTransformer(MegatronModule):
         self.input_tensor = input_tensor
 
     def forward(self, hidden_states, attention_mask,
-                encoder_output=None, enc_dec_attn_mask=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                encoder_output=None, enc_dec_attn_mask=None,
+                inference_params=None):
 
         # Checks.
-        if inference_max_sequence_len:
+        if inference_params:
             assert self.activations_checkpoint_method is None, \
                 'inference does not work with activation checkpointing'
@@ -724,8 +723,8 @@ class ParallelTransformer(MegatronModule):
                 attention_mask,
                 encoder_output=encoder_output,
                 enc_dec_attn_mask=enc_dec_attn_mask,
-                set_inference_key_value_memory=set_inference_key_value_memory,
-                inference_max_sequence_len=inference_max_sequence_len)
+                inference_params=inference_params)
 
         # Final layer norm.
         if self.post_process:
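Taken together, a single model instance can now hold separate key-value caches for several micro batch sizes at once, and the caller selects which cache a forward pass uses via micro_batch_size_index. A hedged sketch of switching between two streams; the batch sizes, sequence length, and the commented-out forward_step calls are illustrative, only InferenceParams and its fields come from this commit:

    from megatron.inference.forward_step import InferenceParams

    # Two concurrent request streams with different micro batch sizes.
    inference_params = InferenceParams(micro_batch_size_list=[8, 2],
                                       max_sequence_len=1024)

    # First pass: every listed batch size gets its own cache allocated.
    inference_params.allocate_key_value_memory = True

    # Switch streams by changing the index before each forward_step call;
    # each ParallelAttention layer then reads and writes the matching cache.
    inference_params.micro_batch_size_index = 0   # the batch-size-8 stream
    # logits = forward_step(model, tokens_a, positions_a, mask_a,
    #                       inference_params)

    inference_params.micro_batch_size_index = 1   # the batch-size-2 stream
    # logits = forward_step(model, tokens_b, positions_b, mask_b,
    #                       inference_params)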