OpenDAS / Megatron-LM

Commit a7539b0f authored Oct 07, 2021 by mshoeybi

pipelining works

Parent: 8f160844
Showing 3 changed files with 155 additions and 158 deletions.
megatron/inference/forward_step.py    +126 -122
megatron/inference/generation.py      +4 -5
megatron/model/transformer.py         +25 -31
megatron/inference/forward_step.py
@@ -15,8 +15,6 @@
 """Forward step utilities."""
 
-from abc import ABC
-from abc import abstractmethod
 from collections.abc import Iterable
 
 import torch
@@ -24,44 +22,27 @@ import torch
 from megatron import (
     get_args,
     mpu)
-from megatron.p2p_communication import (
-    recv_forward,
-    send_forward)
-
-
-def forward_step_provider(model, batch_size, micro_batch_size,
-                          max_sequence_len):
-    args = get_args()
-    if args.pipeline_model_parallel_size == 1 or \
-       micro_batch_size >= batch_size:
-        return NoPipeliningForwardStep(model, batch_size, max_sequence_len)
-    return SimplePipeliningForwardStep(model, batch_size, micro_batch_size,
-                                       max_sequence_len)
 
 
 class InferenceParams:
 
-    def __init__(self, micro_batch_size_list, max_sequence_len):
-        assert isinstance(micro_batch_size_list, list)
-        self.micro_batch_size_list = micro_batch_size_list
+    def __init__(self, max_batch_size, max_sequence_len):
+        assert max_sequence_len > 0
         self.max_sequence_len = max_sequence_len
+        self.max_batch_size = max_batch_size
+        self.sequence_len_offset = 0
+        self.batch_size_offset = 0
         self.allocate_key_value_memory = True
-        self.micro_batch_index = 0
 
 
-class ForwardStepBase(ABC):
+class ForwardStep:
 
-    def __init__(self, model):
+    def __init__(self, model, max_batch_size, max_sequence_len):
         # Make sure model is in eval mode.
         if isinstance(model, Iterable):
             for this_model in model:
                 this_model.eval()
@@ -69,125 +50,148 @@ class ForwardStepBase(ABC):
             model.eval()
         self.model = model
+        self.constant = 512
+        # Initialize inference parameters.
+        self.inference_params = InferenceParams(max_batch_size,
+                                                max_sequence_len)
 
-    @abstractmethod
     def __call__(self, tokens, position_ids, attention_mask):
-        pass
-
-
-class SimplePipeliningForwardStep(ForwardStepBase):
-
-    def __init__(self, model, batch_size, micro_batch_size,
-                 max_sequence_len):
-        super().__init__(model)
-        self.batch_size = batch_size
-        # Divide the batch dimension into micro batches.
-        self.num_micro_batches, last_chunk = divmod(batch_size,
-                                                    micro_batch_size)
-        self.micro_batch_size_list = []
-        self.batch_dim_start_index = [0]
-        for i in range(self.num_micro_batches):
-            self.micro_batch_size_list.append(micro_batch_size)
-            self.batch_dim_start_index.append((i + 1) * micro_batch_size)
-        if last_chunk > 0:
-            self.num_micro_batches += 1
-            self.micro_batch_size_list.append(last_chunk)
-            self.batch_dim_start_index.append(batch_size)
-        self.inference_params = InferenceParams(self.micro_batch_size_list,
-                                                max_sequence_len)
-
-    def __call__(self, tokens, position_ids, attention_mask):
-        # Need to tell p2p_communicate functions the correct size.
-        args = get_args()
-        orig_seq_length = args.seq_length
-        args.seq_length = tokens.size(1)
-        assert args.seq_length <= self.inference_params.max_sequence_len
-        # Preallocate memory for output logits.
-        logits = None
-        if mpu.is_pipeline_last_stage():
-            logits = torch.empty(tokens.size(0),
-                                 tokens.size(1),
-                                 args.padded_vocab_size,
-                                 dtype=torch.float32,
-                                 device=torch.cuda.current_device())
-        # Pileline using micro batches.
-        for micro_batch_index in range(self.num_micro_batches):
-            # Set micro-batch size and index.
-            self.inference_params.micro_batch_index = micro_batch_index
-            args.micro_batch_size = self.micro_batch_size_list[
-                micro_batch_index]
-            # Slice among the batch dimenion.
-            start = self.batch_dim_start_index[micro_batch_index]
-            end = self.batch_dim_start_index[micro_batch_index + 1]
-            tokens2use = tokens[start:end, ...]
-            position_ids2use = position_ids[start:end, ...]
-            # Receive from previous stage.
-            input_tensor = recv_forward()
-            # Forward pass through the model.
-            self.model.set_input_tensor(input_tensor)
-            output_tensor = self.model(tokens2use, position_ids2use,
-                                       attention_mask,
-                                       inference_params=self.inference_params)
-            # Send output to the next stage.
-            send_forward(output_tensor)
-            # Make sure we do not allocate context memory anymore.
-            if self.inference_params.allocate_key_value_memory:
-                self.inference_params.allocate_key_value_memory = False
-            if mpu.is_pipeline_last_stage():
-                logits[start:end, ...] = output_tensor
-        # Adjust the sequence length back to whatever it was before.
-        args.seq_length = orig_seq_length
-        return logits
-
-
-class NoPipeliningForwardStep(ForwardStepBase):
-
-    def __init__(self, model, batch_size, max_sequence_len):
-        super().__init__(model)
-        self.inference_params = InferenceParams([batch_size],
-                                                max_sequence_len)
-
-    def __call__(self, tokens, position_ids, attention_mask):
-        # Need to tell p2p_communicate functions the correct size.
-        args = get_args()
-        orig_seq_length = args.seq_length
-        args.seq_length = tokens.shape[1]
-        assert args.seq_length <= self.inference_params.max_sequence_len
-        args.micro_batch_size = tokens.shape[0]
-        assert self.inference_params.micro_batch_size_list[0] == \
-            tokens.shape[0]
-        assert self.inference_params.micro_batch_index == 0
-        # Receive from previous stage.
-        input_tensor = recv_forward()
-        # Forward pass through the model.
-        self.model.set_input_tensor(input_tensor)
-        output_tensor = self.model(tokens, position_ids, attention_mask,
-                                   inference_params=self.inference_params)
-        # Send output to the next stage.
-        send_forward(output_tensor)
-        # Reset the sequence lenght to whatwever it was before.
-        args.seq_length = orig_seq_length
-        # Make sure we do not allocate context memory anymore.
-        if self.inference_params.allocate_key_value_memory:
-            self.inference_params.allocate_key_value_memory = False
-        return output_tensor
+        if tokens.size(0) * tokens.size(1) >= self.constant:
+            micro_batch_size = max(1, self.constant // tokens.size(1))
+            return _with_pipelining_forward_step(self.model,
+                                                 tokens,
+                                                 position_ids,
+                                                 attention_mask,
+                                                 self.inference_params,
+                                                 micro_batch_size)
+        else:
+            return _no_pipelining_forward_step(self.model,
+                                               tokens,
+                                               position_ids,
+                                               attention_mask,
+                                               self.inference_params)
+
+
+def _get_recv_buffer_dtype(args):
+    """Receive happens between the layers."""
+    if args.fp32_residual_connection:
+        return torch.float
+    return args.params_dtype
+
+
+def _allocate_recv_buffer(batch_size, sequence_length):
+    """Receive happens between the layers with size [s, b, h]."""
+    if mpu.is_pipeline_first_stage():
+        return None
+    args = get_args()
+    recv_size = (sequence_length, batch_size, args.hidden_size)
+    return torch.empty(recv_size,
+                       dtype=_get_recv_buffer_dtype(args),
+                       device=torch.cuda.current_device())
+
+
+def _forward_step_helper(model, tokens, position_ids, attention_mask,
+                         inference_params, recv_buffer=None):
+    """Single forward step. Update the allocate memory flag so
+    only the first time the memory is allocated."""
+    batch_size = tokens.size(0)
+    sequence_length = tokens.size(1)
+    if recv_buffer is None:
+        recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
+
+    # Receive from previous stage.
+    if not mpu.is_pipeline_first_stage():
+        torch.distributed.recv(recv_buffer,
+                               src=mpu.get_pipeline_model_parallel_prev_rank())
+
+    # Forward pass through the model.
+    model.set_input_tensor(recv_buffer)
+    output_tensor = model(tokens, position_ids, attention_mask,
+                          inference_params=inference_params)
+
+    # Send output to the next stage.
+    if not mpu.is_pipeline_last_stage():
+        torch.distributed.send(output_tensor,
+                               mpu.get_pipeline_model_parallel_next_rank())
+
+    # Make sure we do not allocate context memory anymore.
+    if inference_params.allocate_key_value_memory:
+        inference_params.allocate_key_value_memory = False
+
+    return output_tensor
+
+
+def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
+                                inference_params, recv_buffer=None):
+    # Run a simple forward pass.
+    output_tensor = _forward_step_helper(model, tokens, position_ids,
+                                         attention_mask, inference_params,
+                                         recv_buffer=recv_buffer)
+    # Update the sequence length offset.
+    inference_params.sequence_len_offset += tokens.size(1)
+
+    logits = None
+    if mpu.is_pipeline_last_stage():
+        logits = output_tensor
+
+    return logits
+
+
+def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
+                                  inference_params, micro_batch_size):
+    sequence_length = tokens.size(1)
+    batch_size = tokens.size(0)
+
+    # Divide the batch dimension into micro batches.
+    num_micro_batches, last_chunk = divmod(batch_size, micro_batch_size)
+    if last_chunk > 0:
+        num_micro_batches += 1
+
+    # Preallocate memory for output logits.
+    logits = None
+    if mpu.is_pipeline_last_stage():
+        args = get_args()
+        logits = torch.empty(
+            (batch_size, sequence_length, args.padded_vocab_size),
+            dtype=torch.float32, device=torch.cuda.current_device())
+
+    # Preallocate recv buffer.
+    recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)
+
+    for micro_batch_index in range(num_micro_batches):
+        # Slice among the batch dimenion.
+        start = micro_batch_index * micro_batch_size
+        end = min(start + micro_batch_size, batch_size)
+        this_micro_batch_size = end - start
+        tokens2use = tokens[start:end, ...]
+        position_ids2use = position_ids[start:end, ...]
+
+        # Run a simple forward pass.
+        if this_micro_batch_size != micro_batch_size:
+            recv_buffer = None
+        output = _forward_step_helper(model, tokens2use, position_ids2use,
+                                      attention_mask, inference_params,
+                                      recv_buffer=recv_buffer)
+
+        # Adjust the batch size offset to account for the micro-batch.
+        inference_params.batch_size_offset += this_micro_batch_size
+
+        # Copy logits.
+        if mpu.is_pipeline_last_stage():
+            logits[start:end, ...] = output
+
+    # Once we are done with all the micro-batches, we can
+    # adjust the sequence length offset.
+    inference_params.sequence_len_offset += sequence_length
+    # and reset the batch size offset
+    inference_params.batch_size_offset = 0
+
+    return logits
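
A note on the dispatch rule introduced in ForwardStep.__call__ above: the forward pass is only pipelined when tokens.size(0) * tokens.size(1) reaches the constant of 512, and the micro-batch size is then max(1, 512 // sequence_length). The self-contained sketch below (hypothetical helper and variable names, not Megatron code) reproduces just that arithmetic and the resulting batch slices.

# Hypothetical sketch of the micro-batching rule used by the new ForwardStep.
# The threshold (512) and the max(1, threshold // sequence_length) formula are
# taken from the diff; the function and names here are made up for illustration.

def plan_micro_batches(batch_size, sequence_length, threshold=512):
    """Return (micro_batch_size, [(start, end), ...]) slices of the batch
    dimension, or (None, [(0, batch_size)]) when pipelining is skipped."""
    if batch_size * sequence_length < threshold:
        # Small input: a single non-pipelined forward step.
        return None, [(0, batch_size)]
    micro_batch_size = max(1, threshold // sequence_length)
    slices = []
    start = 0
    while start < batch_size:
        end = min(start + micro_batch_size, batch_size)
        slices.append((start, end))
        start = end
    return micro_batch_size, slices


if __name__ == "__main__":
    # 8 prompts of 128 tokens: 8 * 128 >= 512, so micro batches of 4 are used.
    print(plan_micro_batches(8, 128))   # (4, [(0, 4), (4, 8)])
    # 10 prompts of 200 tokens: micro batches of 2.
    print(plan_micro_batches(10, 200))  # (2, [(0, 2), (2, 4), (4, 6), (6, 8), (8, 10)])
    # A single short prompt stays on the non-pipelined path.
    print(plan_micro_batches(1, 32))    # (None, [(0, 1)])

The last slice can be smaller than the others, which is why _with_pipelining_forward_step drops the preallocated recv_buffer (recv_buffer = None) for an undersized final micro batch.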
megatron/inference/generation.py
@@ -24,7 +24,7 @@ from .communication import (
     copy_from_last_to_first_pipeline_stage,
     broadcast_from_last_pipeline_stage,
     broadcast_from_last_to_first_pipeline_stage)
-from .forward_step import forward_step_provider
+from .forward_step import ForwardStep
 from .sampling import sample
@@ -66,8 +66,7 @@ def generate_tokens_probs_and_return_on_first_stage(
     max_sequence_length = min(max_sequence_length, args.max_position_embeddings)
 
     # forward step.
-    forward_step = forward_step_provider(model, batch_size, 4,
-                                         max_sequence_length)
+    forward_step = ForwardStep(model, batch_size, max_sequence_length)
 
     # Added termination_id to support the case that we want to terminate the
     # generation once that id is generated.
@@ -190,8 +189,8 @@ def generate_tokens_probs_and_return_on_first_stage(
             done = torch.all(is_generation_done)
             done = broadcast_from_last_pipeline_stage(1, torch.uint8,
                                                       tensor=done)
-            if done:
-                break
+            #if done:
+            #    break
 
         # ===================================================
         # Update the length of based on max generated length.
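
For context on the offsets the new InferenceParams carries through the generation loop: ForwardStep is constructed once per generation call, the first invocation consumes the full prompt, and each later invocation consumes one new token. The toy loop below is an illustration under those assumptions (ToyInferenceParams and toy_decode are made-up names), not the repository's generation code.

# Illustrative stand-in showing how the offsets advance when a
# ForwardStep-like callable is invoked once on the prompt and then once per
# generated token. Names and shapes are assumptions.

class ToyInferenceParams:
    def __init__(self, max_batch_size, max_sequence_len):
        assert max_sequence_len > 0
        self.max_batch_size = max_batch_size
        self.max_sequence_len = max_sequence_len
        self.sequence_len_offset = 0
        self.batch_size_offset = 0


def toy_decode(prompt_length, tokens_to_generate, batch_size=2):
    params = ToyInferenceParams(batch_size, prompt_length + tokens_to_generate)
    # First call: the whole prompt is fed, so the cached sequence grows by
    # prompt_length (mirroring `inference_params.sequence_len_offset +=
    # tokens.size(1)` in _no_pipelining_forward_step).
    params.sequence_len_offset += prompt_length
    # Subsequent calls: one new token per step.
    for _ in range(tokens_to_generate):
        assert params.sequence_len_offset + 1 <= params.max_sequence_len
        params.sequence_len_offset += 1
    return params.sequence_len_offset


if __name__ == "__main__":
    print(toy_decode(prompt_length=5, tokens_to_generate=3))  # 8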
megatron/model/transformer.py
@@ -180,9 +180,8 @@ class ParallelAttention(MegatronModule):
             skip_bias_add=True)
 
         # Inference key-value memory
-        self.inference_key_memory_list = None
-        self.inference_value_memory_list = None
-        self.inference_current_sequence_len_list = None
+        self.inference_key_memory = None
+        self.inference_value_memory = None
 
     def _allocate_memory(self, inference_max_sequence_len, batch_size):
@@ -206,22 +205,17 @@ class ParallelAttention(MegatronModule):
         if inference_params:
             if inference_params.allocate_key_value_memory:
                 inf_max_seq_len = inference_params.max_sequence_len
-                inf_batch_sizes = inference_params.micro_batch_size_list
-                self.inference_key_memory_list = [
-                    self._allocate_memory(
-                        inf_max_seq_len, inf_batch_size)
-                    for inf_batch_size in inf_batch_sizes]
-                self.inference_value_memory_list = [
-                    self._allocate_memory(inf_max_seq_len, inf_batch_size)
-                    for inf_batch_size in inf_batch_sizes]
-                self.inference_current_sequence_len_list = [
-                    0 for _ in inf_batch_sizes]
+                inf_max_batch_size = inference_params.max_batch_size
+                self.inference_key_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
+                self.inference_value_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
             # This is added for safety. In case inference_params
             # is not provided, make sure there is no potential memory left
             # from previous inference.
             else:
-                self.inference_key_memory_list = None
-                self.inference_value_memory_list = None
-                self.inference_current_sequence_len_list = None
+                self.inference_value_memory = None
+                self.inference_current_sequence_len = None
 
         # =====================
         # Query, Key, and Value
@@ -269,23 +263,23 @@ class ParallelAttention(MegatronModule):
         # ==================================
 
         if inference_params:
-            inf_batch_index = inference_params.micro_batch_index
-            assert key_layer.size(1) == \
-                inference_params.micro_batch_size_list[inf_batch_index]
-            # Adjust the range variables.
-            start = self.inference_current_sequence_len_list[inf_batch_index]
-            end = start + key_layer.size(0)
-            assert end <= inference_params.max_sequence_len
-            self.inference_current_sequence_len_list[inf_batch_index] = end
+            batch_start = inference_params.batch_size_offset
+            batch_end = batch_start + key_layer.size(1)
+            assert batch_end <= self.inference_key_memory.size(1)
+            sequence_start = inference_params.sequence_len_offset
+            sequence_end = sequence_start + key_layer.size(0)
+            assert sequence_end <= self.inference_key_memory.size(0)
             # Copy key and values.
-            self.inference_key_memory_list[inf_batch_index][start:end, ...] \
-                = key_layer
-            self.inference_value_memory_list[inf_batch_index][start:end, ...] \
-                = value_layer
-            key_layer = \
-                self.inference_key_memory_list[inf_batch_index][:end, ...]
-            value_layer = \
-                self.inference_value_memory_list[inf_batch_index][:end, ...]
+            self.inference_key_memory[sequence_start:sequence_end,
+                                      batch_start:batch_end, ...] = key_layer
+            self.inference_value_memory[sequence_start:sequence_end,
+                                        batch_start:batch_end, ...] = value_layer
+            key_layer = self.inference_key_memory[
+                :sequence_end, batch_start:batch_end, ...]
+            value_layer = self.inference_value_memory[
+                :sequence_end, batch_start:batch_end, ...]
 
         # ===================================
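
The transformer.py change above replaces the per-micro-batch cache lists with a single preallocated [max_sequence_len, max_batch_size, ...] key/value buffer indexed by the sequence_len_offset and batch_size_offset from InferenceParams. Below is a small, hypothetical torch sketch of that slicing, with assumed tensor shapes and no Megatron dependencies; it is not ParallelAttention itself.

# Hypothetical sketch of the cache update in the diff above: new keys for the
# current micro batch and decoding step are written into a preallocated
# [max_sequence_len, max_batch_size, head_dim] buffer, then the cache up to the
# current position is read back for attention. Shapes and names are assumptions.
import torch

max_sequence_len, max_batch_size, head_dim = 16, 4, 8
inference_key_memory = torch.zeros(max_sequence_len, max_batch_size, head_dim)

# Pretend we are on the second micro batch (columns 2:4) at decoding step 3,
# processing one new token, i.e. key_layer has shape [1, 2, head_dim].
sequence_len_offset, batch_size_offset = 3, 2
key_layer = torch.randn(1, 2, head_dim)

batch_start = batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)

# Copy the new keys into the cache ...
inference_key_memory[sequence_start:sequence_end,
                     batch_start:batch_end, ...] = key_layer
# ... and read back everything cached so far for this micro batch.
key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
print(key_layer.shape)  # torch.Size([4, 2, 8])

Reading back inference_key_memory[:sequence_end, batch_start:batch_end] yields the keys for all positions processed so far by this micro batch, which is what the attention computation then consumes.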