OpenDAS / Megatron-LM

Commit 8f160844
Authored Oct 01, 2021 by mshoeybi
Parent: dee8707e

    simple pipelining works

Showing 3 changed files with 123 additions and 48 deletions:

    megatron/inference/forward_step.py    +114  -41
    megatron/inference/generation.py        +3   -2
    megatron/model/transformer.py           +6   -5
megatron/inference/forward_step.py  (view file @ 8f160844)
@@ -15,15 +15,33 @@
 """Forward step utilities."""

+from abc import ABC
+from abc import abstractmethod
 from collections.abc import Iterable
 from enum import Enum

-from megatron.p2p_communication import recv_forward, send_forward
-from megatron import get_args
+import torch
+
+from megatron import (
+    get_args,
+    mpu)
+from megatron.p2p_communication import (
+    recv_forward,
+    send_forward)


 class ForwardStepTypes(Enum):
     NO_PIPELINING = 1


+def forward_step_provider(model, batch_size, micro_batch_size,
+                          max_sequence_len):
+    args = get_args()
+    if args.pipeline_model_parallel_size == 1 or \
+       micro_batch_size >= batch_size:
+        return NoPipeliningForwardStep(model, batch_size, max_sequence_len)
+    return SimplePipeliningForwardStep(model, batch_size, micro_batch_size,
+                                       max_sequence_len)
@@ -37,12 +55,12 @@ class InferenceParams:
         self.micro_batch_size_list = micro_batch_size_list
         self.max_sequence_len = max_sequence_len
         self.allocate_key_value_memory = True
-        self.micro_batch_size_index = 0
+        self.micro_batch_index = 0


-class InferenceForwardStep:
+class ForwardStepBase(ABC):

-    def __init__(self, model, batch_size, max_sequence_len):
+    def __init__(self, model):

         if isinstance(model, Iterable):
             for this_model in model:
@@ -51,21 +69,100 @@ class InferenceForwardStep:
             model.eval()

         self.model = model

-        self.inference_params = InferenceParams([batch_size], max_sequence_len)
-        self.forward_step_type = ForwardStepTypes.NO_PIPELINING
+    @abstractmethod
+    def __call__(self, tokens, position_ids, attention_mask):
+        pass
+
+
+class SimplePipeliningForwardStep(ForwardStepBase):
+
+    def __init__(self, model, batch_size, micro_batch_size, max_sequence_len):
+        super().__init__(model)
+        self.batch_size = batch_size
+        # Divide the batch dimension into micro batches.
+        self.num_micro_batches, last_chunk = divmod(batch_size,
+                                                    micro_batch_size)
+        self.micro_batch_size_list = []
+        self.batch_dim_start_index = [0]
+        for i in range(self.num_micro_batches):
+            self.micro_batch_size_list.append(micro_batch_size)
+            self.batch_dim_start_index.append((i + 1) * micro_batch_size)
+        if last_chunk > 0:
+            self.num_micro_batches += 1
+            self.micro_batch_size_list.append(last_chunk)
+            self.batch_dim_start_index.append(batch_size)
+        self.inference_params = InferenceParams(self.micro_batch_size_list,
+                                                max_sequence_len)
+
+    def __call__(self, tokens, position_ids, attention_mask):
-        if self.forward_step_type == ForwardStepTypes.NO_PIPELINING:
-            return self._forward_step_no_pipelining(tokens, position_ids,
-                                                    attention_mask)
-        raise Exception('unknown forward step type {}'.format(
-            self.forward_step_type))
+        # Need to tell p2p_communicate functions the correct size.
+        args = get_args()
+        orig_seq_length = args.seq_length
+        args.seq_length = tokens.size(1)
+        assert args.seq_length <= self.inference_params.max_sequence_len
+        # Preallocate memory for output logits.
+        logits = None
+        if mpu.is_pipeline_last_stage():
+            logits = torch.empty(tokens.size(0), tokens.size(1),
+                                 args.padded_vocab_size,
+                                 dtype=torch.float32,
+                                 device=torch.cuda.current_device())
+        # Pipeline using micro batches.
+        for micro_batch_index in range(self.num_micro_batches):
+            # Set micro-batch size and index.
+            self.inference_params.micro_batch_index = micro_batch_index
+            args.micro_batch_size = self.micro_batch_size_list[micro_batch_index]
+            # Slice along the batch dimension.
+            start = self.batch_dim_start_index[micro_batch_index]
+            end = self.batch_dim_start_index[micro_batch_index + 1]
+            tokens2use = tokens[start:end, ...]
+            position_ids2use = position_ids[start:end, ...]
+            # Receive from previous stage.
+            input_tensor = recv_forward()
+            # Forward pass through the model.
+            self.model.set_input_tensor(input_tensor)
+            output_tensor = self.model(tokens2use, position_ids2use,
+                                       attention_mask,
+                                       inference_params=self.inference_params)
+            # Send output to the next stage.
+            send_forward(output_tensor)
+            # Make sure we do not allocate context memory anymore.
+            if self.inference_params.allocate_key_value_memory:
+                self.inference_params.allocate_key_value_memory = False
+            if mpu.is_pipeline_last_stage():
+                logits[start:end, ...] = output_tensor
+        # Adjust the sequence length back to whatever it was before.
+        args.seq_length = orig_seq_length
+        return logits
+
+
+class NoPipeliningForwardStep(ForwardStepBase):
+
+    def __init__(self, model, batch_size, max_sequence_len):
+        super().__init__(model)
+        self.inference_params = InferenceParams([batch_size],
+                                                max_sequence_len)

-    def _forward_step_no_pipelining(self, tokens, position_ids, attention_mask):
+    def __call__(self, tokens, position_ids, attention_mask):
         # Need to tell p2p_communicate functions the correct size.
         args = get_args()
@@ -74,7 +171,7 @@ class InferenceForwardStep:
         assert args.seq_length <= self.inference_params.max_sequence_len
         args.micro_batch_size = tokens.shape[0]
         assert self.inference_params.micro_batch_size_list[0] == tokens.shape[0]
-        assert self.inference_params.micro_batch_size_index == 0
+        assert self.inference_params.micro_batch_index == 0

         # Receive from previous stage.
         input_tensor = recv_forward()
@@ -94,27 +191,3 @@ class InferenceForwardStep:
             self.inference_params.allocate_key_value_memory = False

         return output_tensor
-
-
-def forward_step(model, tokens, position_ids, attention_mask, inference_params):
-    # Hidden size changes when not using recompute, need to tell p2p_communicate
-    # functions the correct size.
-    args = get_args()
-    orig_seq_length = args.seq_length
-    args.seq_length = tokens.shape[1]
-    args.micro_batch_size = tokens.shape[0]
-
-    input_tensor = recv_forward()
-
-    # Forward pass through the model.
-    model.set_input_tensor(input_tensor)
-    output_tensor = model(tokens, position_ids, attention_mask,
-                          inference_params=inference_params)
-
-    send_forward(output_tensor)
-
-    args.seq_length = orig_seq_length
-
-    return output_tensor
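The heart of the new SimplePipeliningForwardStep is the batch-splitting arithmetic in its constructor: divmod carves the batch into equally sized micro batches plus an optional smaller remainder chunk, and batch_dim_start_index records where each slice of the batch dimension starts. Below is a minimal, standalone sketch of that arithmetic; the split_batch helper name and the demo values are illustrative, not part of the commit.

# A minimal sketch of the batch-splitting arithmetic used by
# SimplePipeliningForwardStep.__init__ above; split_batch is a
# hypothetical helper name, not part of the commit.

def split_batch(batch_size, micro_batch_size):
    # Full micro batches first, then one smaller chunk for any remainder.
    num_micro_batches, last_chunk = divmod(batch_size, micro_batch_size)
    micro_batch_size_list = [micro_batch_size] * num_micro_batches
    batch_dim_start_index = [i * micro_batch_size
                             for i in range(num_micro_batches + 1)]
    if last_chunk > 0:
        num_micro_batches += 1
        micro_batch_size_list.append(last_chunk)
        batch_dim_start_index.append(batch_size)
    return num_micro_batches, micro_batch_size_list, batch_dim_start_index


if __name__ == '__main__':
    # batch_size=10, micro_batch_size=4 -> sizes [4, 4, 2],
    # start indices [0, 4, 8, 10].
    print(split_batch(10, 4))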
megatron/inference/generation.py  (view file @ 8f160844)
@@ -24,7 +24,7 @@ from .communication import (
     copy_from_last_to_first_pipeline_stage,
     broadcast_from_last_pipeline_stage,
     broadcast_from_last_to_first_pipeline_stage)
-from .forward_step import InferenceForwardStep
+from .forward_step import forward_step_provider
 from .sampling import sample

@@ -66,7 +66,8 @@ def generate_tokens_probs_and_return_on_first_stage(
     max_sequence_length = min(max_sequence_length, args.max_position_embeddings)

     # forward step.
-    forward_step = InferenceForwardStep(model, batch_size, max_sequence_length)
+    forward_step = forward_step_provider(model, batch_size, 4,
+                                         max_sequence_length)

     # Added termination_id to support the case that we want to terminate the
     # generation once that id is generated.
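This call site hard-codes a micro batch size of 4, so whether pipelining is actually used is decided entirely by forward_step_provider: with a single pipeline stage, or when the micro batch is at least as large as the batch, it falls back to NoPipeliningForwardStep. A hedged sketch of that dispatch rule, where choose_forward_step is a hypothetical stand-in that returns a label instead of constructing the real forward-step objects:

# Mirrors the selection logic of forward_step_provider above;
# choose_forward_step is illustrative, not part of the commit.

def choose_forward_step(pipeline_model_parallel_size, batch_size,
                        micro_batch_size):
    if pipeline_model_parallel_size == 1 or micro_batch_size >= batch_size:
        return 'NoPipeliningForwardStep'
    return 'SimplePipeliningForwardStep'


# With the hard-coded micro batch size of 4 used at this call site:
assert choose_forward_step(1, 16, 4) == 'NoPipeliningForwardStep'
assert choose_forward_step(2, 16, 4) == 'SimplePipeliningForwardStep'
assert choose_forward_step(2, 3, 4) == 'NoPipeliningForwardStep'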
megatron/model/transformer.py  (view file @ 8f160844)
@@ -269,18 +269,19 @@ class ParallelAttention(MegatronModule):
         # ==================================
         if inference_params:
-            inf_batch_index = inference_params.micro_batch_size_index
+            inf_batch_index = inference_params.micro_batch_index
             assert key_layer.size(1) == \
                 inference_params.micro_batch_size_list[inf_batch_index]
             # Adjust the range variables.
             start = self.inference_current_sequence_len_list[inf_batch_index]
             end = start + key_layer.size(0)
             assert end <= inference_params.max_sequence_len
             self.inference_current_sequence_len_list[inf_batch_index] = end
             # Copy key and values.
-            self.inference_key_memory_list[inf_batch_index][start:end, ...] = \
-                key_layer
-            self.inference_value_memory_list[inf_batch_index][start:end, ...] = \
-                value_layer
+            self.inference_key_memory_list[inf_batch_index][start:end, ...] \
+                = key_layer
+            self.inference_value_memory_list[inf_batch_index][start:end, ...] \
+                = value_layer
             key_layer = \
                 self.inference_key_memory_list[inf_batch_index][:end, ...]
             value_layer = \
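With the rename to micro_batch_index, ParallelAttention now indexes one preallocated key/value cache per micro batch and copies each step's keys and values into the start:end slice of the selected cache. The snippet below is a minimal sketch of that bookkeeping; the list names follow the commit, while the shapes, sizes, and the cache_keys wrapper are illustrative assumptions.

# A minimal sketch of the per-micro-batch key cache update performed in
# ParallelAttention above. Tensors follow Megatron's
# [sequence, batch, heads, head_dim] layout; the concrete sizes and the
# cache_keys wrapper are illustrative, not taken from the commit.
import torch

max_sequence_len, heads, head_dim = 32, 2, 8
micro_batch_size_list = [4, 4, 2]

# One preallocated cache tensor per micro batch.
inference_key_memory_list = [
    torch.empty(max_sequence_len, b, heads, head_dim)
    for b in micro_batch_size_list]
inference_current_sequence_len_list = [0] * len(micro_batch_size_list)

def cache_keys(inf_batch_index, key_layer):
    # key_layer: [new_tokens, micro_batch, heads, head_dim]
    assert key_layer.size(1) == micro_batch_size_list[inf_batch_index]
    start = inference_current_sequence_len_list[inf_batch_index]
    end = start + key_layer.size(0)
    assert end <= max_sequence_len
    inference_current_sequence_len_list[inf_batch_index] = end
    inference_key_memory_list[inf_batch_index][start:end, ...] = key_layer
    # Return the cache up to the current length, as the attention code does.
    return inference_key_memory_list[inf_batch_index][:end, ...]

# First call caches 5 prompt tokens, second call appends 1 generated token.
assert cache_keys(0, torch.randn(5, 4, heads, head_dim)).shape[0] == 5
assert cache_keys(0, torch.randn(1, 4, heads, head_dim)).shape[0] == 6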