OpenDAS / Megatron-LM: commit ed6d28b1
"sgl-router/src/git@developer.sourcefind.cn:change/sglang.git" did not exist on "963175d5c06a053399b2cf361295372ff1158e46"
Authored Oct 10, 2021 by mshoeybi
Commit message: merged main
Parents: 8c119d80, 1ec6b0e9
Showing 14 changed files with 449 additions and 321 deletions (+449 -321).
megatron/p2p_communication.py         +45   -38
megatron/package_info.py               +0   -30
megatron/schedules.py                +181   -25
megatron/text_generation_server.py    +49   -23
megatron/text_generation_utils.py     +72   -81
megatron/training.py                  +40   -10
pretrain_bert.py                       +4   -3
pretrain_gpt.py                        +4   -3
pretrain_ict.py                        +2   -0
pretrain_t5.py                        +47   -9
pretrain_vit.py                        +2   -0
requirements.txt                       +0   -5
setup.py                               +0   -91
tools/text_generation_cli.py           +3   -3
megatron/p2p_communication.py

@@ -22,8 +22,8 @@ from megatron import mpu
 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
-                 use_ring_exchange=False, tensor_shape=None,
-                 override_scatter_gather_tensors_in_pipeline=False,
+                 tensor_shape,
+                 use_ring_exchange=False,
                  dtype_=None):
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.
@@ -37,16 +37,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                    previous rank.
         recv_next: boolean for whether tensor should be received from
                    next rank.
+        tensor_shape: shape of tensor to receive (this method assumes that all
+                      tensors sent and received in a single function call are
+                      the same shape).
         use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                            API should be used.
-        tensor_shape: optional, use when the input sequence contains less
-                      tokens than the default sequence length
-        override_scatter_gather_tensors_in_pipeline: optional, this is used
-                                                     when tensor_shape is
-                                                     provided to overwide
-                                                     scatter gather tensors
-        dtype_: optional, this is used when tensor_shape is provied and what
-                is the type of tensor_shape
+        dtype_: optional, this is used when the tensor that needs to be
+                communicated is different from args.params_dtype.

     Returns:
         (tensor_recv_prev, tensor_recv_next)
     """
@@ -56,12 +53,22 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # if needed.
     tensor_recv_prev = None
     tensor_recv_next = None
+
+    # Some legacy inference code doesn't set the tensor shape, do so now
+    # for the normal values for gpt/bert. This could be removed if inference
+    # code is changed to provide tensor_shape.
     if tensor_shape is None:
         tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
-    if not override_scatter_gather_tensors_in_pipeline and \
-            args.scatter_gather_tensors_in_pipeline:
-        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
-            mpu.get_tensor_model_parallel_world_size()
+
+    override_scatter_gather_tensors_in_pipeline = False
+    if args.scatter_gather_tensors_in_pipeline:
+        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
+        if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
+            tensor_chunk_shape = tensor_chunk_shape // \
+                mpu.get_tensor_model_parallel_world_size()
+        else:
+            tensor_chunk_shape = tensor_shape
+            override_scatter_gather_tensors_in_pipeline = True
     else:
         tensor_chunk_shape = tensor_shape
     dtype = args.params_dtype
@@ -143,9 +150,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     return tensor_recv_prev, tensor_recv_next


 def recv_forward(tensor_shape=None,
-                 override_scatter_gather_tensors_in_pipeline=False,
                  dtype_=None,
                  timers=None):
     """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
@@ -159,15 +164,13 @@ def recv_forward(tensor_shape=None,
             recv_prev=True,
             recv_next=False,
             tensor_shape=tensor_shape,
-            override_scatter_gather_tensors_in_pipeline=\
-                override_scatter_gather_tensors_in_pipeline,
             dtype_=dtype_)
         if timers is not None:
             timers('forward-recv').stop()
     return input_tensor


-def recv_backward(timers=None):
+def recv_backward(tensor_shape=None, timers=None):
     """Receive tensor from next rank in pipeline (backward receive)."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -178,15 +181,14 @@ def recv_backward(timers=None):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=True)
+            recv_next=True,
+            tensor_shape=tensor_shape)
         if timers is not None:
             timers('backward-recv').stop()
     return output_tensor_grad


-def send_forward(output_tensor, timers=None,
-                 override_scatter_gather_tensors_in_pipeline=False,
-                 dtype_=None):
+def send_forward(output_tensor,
+                 tensor_shape=None,
+                 dtype_=None,
+                 timers=None):
     """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
@@ -197,14 +199,13 @@ def send_forward(output_tensor, timers=None,
             tensor_send_prev=None,
             recv_prev=False,
             recv_next=False,
-            override_scatter_gather_tensors_in_pipeline=\
-                override_scatter_gather_tensors_in_pipeline,
+            tensor_shape=tensor_shape,
             dtype_=dtype_)
         if timers is not None:
             timers('forward-send').stop()


-def send_backward(input_tensor_grad, timers=None):
+def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
     """Send tensor to previous rank in pipeline (backward send)."""
     if not mpu.is_pipeline_first_stage():
         if timers is not None:
@@ -213,12 +214,13 @@ def send_backward(input_tensor_grad, timers=None):
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,
             recv_prev=False,
-            recv_next=False)
+            recv_next=False,
+            tensor_shape=tensor_shape)
         if timers is not None:
             timers('backward-send').stop()


-def send_forward_recv_backward(output_tensor, timers=None):
+def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
     """Batched send and recv with next rank in pipeline."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -229,13 +231,14 @@ def send_forward_recv_backward(output_tensor, timers=None):
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=True)
+            recv_next=True,
+            tensor_shape=tensor_shape)
         if timers is not None:
             timers('forward-send-backward-recv').stop()
     return output_tensor_grad


-def send_backward_recv_forward(input_tensor_grad, timers=None):
+def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
     """Batched send and recv with previous rank in pipeline."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
@@ -246,13 +249,14 @@ def send_backward_recv_forward(input_tensor_grad, timers=None):
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,
             recv_prev=True,
-            recv_next=False)
+            recv_next=False,
+            tensor_shape=tensor_shape)
         if timers is not None:
             timers('backward-send-forward-recv').stop()
     return input_tensor


-def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
+def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
     """Batched recv from previous rank and send to next rank in pipeline."""
     if timers is not None:
         timers('forward-send-forward-recv').start()
@@ -260,13 +264,14 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
         tensor_send_next=output_tensor,
         tensor_send_prev=None,
         recv_prev=recv_prev,
-        recv_next=False)
+        recv_next=False,
+        tensor_shape=tensor_shape)
     if timers is not None:
         timers('forward-send-forward-recv').stop()
     return input_tensor


-def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
+def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
     """Batched recv from next rank and send to previous rank in pipeline."""
     if timers is not None:
         timers('backward-send-backward-recv').start()
@@ -274,7 +279,8 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
         tensor_send_next=None,
         tensor_send_prev=input_tensor_grad,
         recv_prev=False,
-        recv_next=recv_next)
+        recv_next=recv_next,
+        tensor_shape=tensor_shape)
     if timers is not None:
         timers('backward-send-backward-recv').stop()
     return output_tensor_grad

@@ -282,7 +288,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
 def send_forward_backward_recv_forward_backward(
         output_tensor, input_tensor_grad, recv_prev,
-        recv_next, timers=None):
+        recv_next, tensor_shape=None, timers=None):
     """Batched send and recv with previous and next ranks in pipeline."""
     if timers is not None:
         timers('forward-backward-send-forward-backward-recv').start()
@@ -290,7 +296,8 @@ def send_forward_backward_recv_forward_backward(
         tensor_send_next=output_tensor,
         tensor_send_prev=input_tensor_grad,
         recv_prev=recv_prev,
-        recv_next=recv_next)
+        recv_next=recv_next,
+        tensor_shape=tensor_shape)
     if timers is not None:
         timers('forward-backward-send-forward-backward-recv').stop()
     return input_tensor, output_tensor_grad
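The recurring change in this file is that every helper now threads an explicit tensor_shape through to _communicate() instead of assuming the default (seq_length, micro_batch_size, hidden_size) shape. As a rough standalone sketch of the chunking rule the new body applies, the snippet below mirrors that logic with made-up sizes and a made-up tensor-parallel world size; it is illustrative only, not code from the commit.

    from functools import reduce
    import operator

    # Hypothetical sizes; in Megatron these come from args.
    seq_length, micro_batch_size, hidden_size = 512, 4, 1024
    tensor_shape = (seq_length, micro_batch_size, hidden_size)
    tensor_model_parallel_world_size = 8   # hypothetical

    # Mirror of the new logic: scatter/gather the message across tensor-parallel
    # ranks only when the element count divides evenly; otherwise fall back to
    # sending the full tensor (override_scatter_gather_tensors_in_pipeline).
    numel = reduce(operator.mul, tensor_shape, 1)
    if numel % tensor_model_parallel_world_size == 0:
        tensor_chunk_numel = numel // tensor_model_parallel_world_size
    else:
        tensor_chunk_numel = numel
    print(tensor_shape, numel, tensor_chunk_numel)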
megatron/package_info.py (deleted, 100644 → 0)

-# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-MAJOR = 1
-MINOR = 1.5
-
-# Use the following formatting: (major, minor)
-VERSION = (MAJOR, MINOR)
-
-__version__ = '.'.join(map(str, VERSION))
-__package_name__ = 'megatron-lm'
-__contact_names__ = 'NVIDIA INC'
-__url__ = 'https://github.com/NVIDIA/Megatron-LM'
-__download_url__ = 'https://github.com/NVIDIA/Megatron-LM/releases'
-__description__ = 'Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism.'
-__license__ = 'See https://github.com/NVIDIA/Megatron-LM/blob/master/LICENSE'
-__keywords__ = 'deep learning, Megatron, gpu, NLP, nvidia, pytorch, torch, language'
megatron/schedules.py

@@ -25,6 +25,8 @@ from megatron import p2p_communication
 from megatron.utils import unwrap_model
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
+from megatron.model import ModelType
+

 def get_forward_backward_func():
     args = get_args()
@@ -48,11 +50,18 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
     passed-in input_tensor is used.

     Returns output tensor."""
+    args = get_args()
     timers = get_timers()

     timers('forward-compute').start()
     unwrapped_model = unwrap_model(
         model, (torchDDP, LocalDDP, Float16Module))
+    unwrap_output_tensor = False
+    if not isinstance(input_tensor, list):
+        input_tensor = [input_tensor]
+        unwrap_output_tensor = True
     unwrapped_model.set_input_tensor(input_tensor)
     output_tensor, loss_func = forward_step_func(data_iterator, model)
     if mpu.is_pipeline_last_stage():
@@ -62,7 +71,15 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
             losses_reduced.append(loss_reduced)
     timers('forward-compute').stop()
-    return output_tensor
+
+    # If T5 model (or other model with encoder and decoder)
+    # and in decoder stack, then send encoder_hidden_state
+    # downstream as well.
+    if mpu.is_pipeline_stage_after_split() and \
+            args.model_type == ModelType.encoder_and_decoder:
+        return [output_tensor, input_tensor[-1]]
+    if unwrap_output_tensor:
+        return output_tensor
+    return [output_tensor]


 def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
@@ -73,24 +90,53 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
     Returns gradient of loss with respect to input tensor (None if first
     stage)."""
+
+    # NOTE: This code currently can handle at most one skip connection. It
+    # needs to be modified slightly to support arbitrary numbers of skip
+    # connections.
+
     args = get_args()
     timers = get_timers()

     timers('backward-compute').start()

     # Retain the grad on the input_tensor.
-    if input_tensor is not None:
-        input_tensor.retain_grad()
+    unwrap_input_tensor_grad = False
+    if not isinstance(input_tensor, list):
+        input_tensor = [input_tensor]
+        unwrap_input_tensor_grad = True
+    for x in input_tensor:
+        if x is not None:
+            x.retain_grad()
+
+    if not isinstance(output_tensor, list):
+        output_tensor = [output_tensor]
+    if not isinstance(output_tensor_grad, list):
+        output_tensor_grad = [output_tensor_grad]

     # Backward pass.
-    if output_tensor_grad is None:
-        output_tensor = optimizer.scale_loss(output_tensor)
-    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
+    if output_tensor_grad[0] is None:
+        output_tensor = optimizer.scale_loss(output_tensor[0])
+    torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

     # Collect the grad of the input_tensor.
-    input_tensor_grad = None
+    input_tensor_grad = [None]
     if input_tensor is not None:
-        input_tensor_grad = input_tensor.grad
+        input_tensor_grad = []
+        for x in input_tensor:
+            if x is None:
+                input_tensor_grad.append(None)
+            else:
+                input_tensor_grad.append(x.grad)
+
+    # Handle single skip connection if it exists (encoder_hidden_state in
+    # model with encoder and decoder).
+    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
+            mpu.is_pipeline_stage_after_split() and \
+            args.model_type == ModelType.encoder_and_decoder:
+        if output_tensor_grad[1] is not None:
+            input_tensor_grad[-1].add_(output_tensor_grad[1])
+    if unwrap_input_tensor_grad:
+        input_tensor_grad = input_tensor_grad[0]

     timers('backward-compute').stop()
@@ -153,6 +199,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, ...):
     pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
     pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

+    args = get_args()
+    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+
     # Compute number of warmup and remaining microbatches.
     num_model_chunks = len(model)
     num_microbatches = get_num_microbatches() * num_model_chunks
@@ -237,7 +286,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, ...):
     # Run warmup forward passes.
     mpu.set_virtual_pipeline_model_parallel_rank(0)
     input_tensors[0].append(
-        p2p_communication.recv_forward(timers=timers))
+        p2p_communication.recv_forward(tensor_shape, timers=timers))
     for k in range(num_warmup_microbatches):
         output_tensor = forward_step_helper(k)
@@ -266,12 +315,15 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, ...):
                 p2p_communication.send_forward_backward_recv_forward_backward(
                     output_tensor, input_tensor_grad,
                     recv_prev=recv_prev, recv_next=recv_next,
+                    tensor_shape=tensor_shape,
                     timers=timers)
             output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
         else:
             input_tensor = \
                 p2p_communication.send_forward_recv_forward(
-                    output_tensor, recv_prev=recv_prev, timers=timers)
+                    output_tensor, recv_prev=recv_prev,
+                    tensor_shape=tensor_shape,
+                    timers=timers)
         input_tensors[next_forward_model_chunk_id].append(input_tensor)

     # Run 1F1B in steady state.
@@ -335,7 +387,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, ...):
             p2p_communication.send_forward_backward_recv_forward_backward(
                 output_tensor, input_tensor_grad,
                 recv_prev=recv_prev, recv_next=recv_next,
-                timers=timers)
+                tensor_shape=tensor_shape, timers=timers)

     # Put input_tensor and output_tensor_grad in data structures in the
     # right location.
@@ -349,7 +401,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, ...):
     if not forward_only:
         if all_warmup_microbatches:
             output_tensor_grads[num_model_chunks-1].append(
-                p2p_communication.recv_backward(timers=timers))
+                p2p_communication.recv_backward(tensor_shape, timers=timers))
         for k in range(num_microbatches_remaining, num_microbatches):
             input_tensor_grad = backward_step_helper(k)
             next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -361,11 +413,107 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, ...):
                 recv_next = False
             output_tensor_grads[next_backward_model_chunk_id].append(
                 p2p_communication.send_backward_recv_backward(
-                    input_tensor_grad, recv_next=recv_next, timers=timers))
+                    input_tensor_grad, recv_next=recv_next,
+                    tensor_shape=tensor_shape,
+                    timers=timers))

     return losses_reduced


+def get_tensor_shapes(rank, model_type):
+    # Determine right tensor sizes (based on position of rank with respect to split
+    # rank) and model size.
+    # Send two tensors if model is T5 and rank is in decoder stage:
+    #     first tensor is decoder (pre-transpose),
+    #     second tensor is encoder (post-transpose).
+    # If model is T5 and rank is at the boundary:
+    #     send one tensor (post-transpose from encoder).
+    # Otherwise, send one tensor (pre-transpose).
+    args = get_args()
+    tensor_shapes = []
+    if model_type == ModelType.encoder_and_decoder:
+        if mpu.is_pipeline_stage_before_split(rank):
+            # If next rank is after split, then need transpose for encoder_hidden_state.
+            if mpu.is_pipeline_stage_before_split(rank+1):
+                tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
+            else:
+                tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
+        else:
+            tensor_shapes.append((args.decoder_seq_length, args.micro_batch_size, args.hidden_size))
+            tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
+    else:
+        tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
+    return tensor_shapes
+
+
+def recv_forward(tensor_shapes, timers):
+    input_tensors = []
+    for tensor_shape in tensor_shapes:
+        if tensor_shape is None:
+            input_tensors.append(None)
+        else:
+            input_tensors.append(p2p_communication.recv_forward(tensor_shape, timers=timers))
+    return input_tensors
+
+
+def recv_backward(tensor_shapes, timers):
+    output_tensor_grads = []
+    for tensor_shape in tensor_shapes:
+        if tensor_shape is None:
+            output_tensor_grads.append(None)
+        else:
+            output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape, timers=timers))
+    return output_tensor_grads
+
+
+def send_forward(output_tensors, tensor_shapes, timers):
+    if not isinstance(output_tensors, list):
+        output_tensors = [output_tensors]
+    for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
+        if tensor_shape is None:
+            continue
+        p2p_communication.send_forward(output_tensor, tensor_shape, timers=timers)
+
+
+def send_backward(input_tensor_grads, tensor_shapes, timers):
+    if not isinstance(input_tensor_grads, list):
+        input_tensor_grads = [input_tensor_grads]
+    for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
+        if tensor_shape is None:
+            continue
+        p2p_communication.send_backward(input_tensor_grad, tensor_shape, timers=timers)
+
+
+def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
+    if not isinstance(output_tensors, list):
+        output_tensors = [output_tensors]
+    output_tensor_grads = []
+    for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
+        if tensor_shape is None:
+            output_tensor_grads.append(None)
+            continue
+        output_tensor_grad = p2p_communication.send_forward_recv_backward(
+            output_tensor, tensor_shape, timers=timers)
+        output_tensor_grads.append(output_tensor_grad)
+    return output_tensor_grads
+
+
+def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
+    if not isinstance(input_tensor_grads, list):
+        input_tensor_grads = [input_tensor_grads]
+    input_tensors = []
+    for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
+        if tensor_shape is None:
+            input_tensors.append(None)
+            continue
+        input_tensor = p2p_communication.send_backward_recv_forward(
+            input_tensor_grad, tensor_shape, timers=timers)
+        input_tensors.append(input_tensor)
+    return input_tensors
+
+
 def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
                                                      model, optimizer, timers,
                                                      forward_only):
@@ -389,6 +537,13 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator, ...):
     num_microbatches_remaining = \
         num_microbatches - num_warmup_microbatches

+    unwrapped_model = unwrap_model(
+        model, (torchDDP, LocalDDP, Float16Module))
+    model_type = unwrapped_model.model_type
+    rank = mpu.get_pipeline_model_parallel_rank()
+    recv_tensor_shapes = get_tensor_shapes(rank-1, model_type)
+    send_tensor_shapes = get_tensor_shapes(rank, model_type)
+
     # Input, output tensors only need to be saved when doing backward passes
     input_tensors = None
     output_tensors = None
@@ -399,10 +554,10 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator, ...):
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
-        input_tensor = p2p_communication.recv_forward(timers=timers)
+        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
-        p2p_communication.send_forward(output_tensor, timers=timers)
+        send_forward(output_tensor, send_tensor_shapes, timers=timers)

         if not forward_only:
             input_tensors.append(input_tensor)
@@ -412,7 +567,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator, ...):
     # If all microbatches are run in warmup / cooldown phase, then no need to
     # receive this tensor here.
     if num_microbatches_remaining > 0:
-        input_tensor = p2p_communication.recv_forward(timers=timers)
+        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)

     # Run 1F1B in steady state.
     for i in range(num_microbatches_remaining):
@@ -421,15 +576,16 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator, ...):
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
         if forward_only:
-            p2p_communication.send_forward(output_tensor, timers=timers)
+            send_forward(output_tensor, send_tensor_shapes, timers=timers)

             if not last_iteration:
-                input_tensor = p2p_communication.recv_forward(timers=timers)
+                input_tensor = recv_forward(recv_tensor_shapes, timers=timers)

         else:
             output_tensor_grad = \
-                p2p_communication.send_forward_recv_backward(output_tensor,
-                                                             timers=timers)
+                send_forward_recv_backward(output_tensor,
+                                           send_tensor_shapes, timers=timers)

             # Add input_tensor and output_tensor to end of list.
             input_tensors.append(input_tensor)
@@ -446,11 +602,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator, ...):
             if last_iteration:
                 input_tensor = None
-                p2p_communication.send_backward(input_tensor_grad, timers=timers)
+                send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
             else:
                 input_tensor = \
-                    p2p_communication.send_backward_recv_forward(
-                        input_tensor_grad, timers=timers)
+                    send_backward_recv_forward(
+                        input_tensor_grad, recv_tensor_shapes, timers=timers)

     # Run cooldown backward passes.
     if not forward_only:
@@ -458,12 +614,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator, ...):
             input_tensor = input_tensors.pop(0)
             output_tensor = output_tensors.pop(0)

-            output_tensor_grad = p2p_communication.recv_backward(timers=timers)
+            output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers)

             input_tensor_grad = \
                 backward_step(optimizer, input_tensor, output_tensor,
                               output_tensor_grad)

-            p2p_communication.send_backward(input_tensor_grad, timers=timers)
+            send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)

     return losses_reduced
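get_tensor_shapes() is what lets the non-interleaved schedule drive encoder-and-decoder models: a rank in the decoder stage exchanges two tensors per microbatch (its own activation plus the encoder hidden state), while other ranks exchange one. Below is a self-contained sketch of that selection logic with plain arguments standing in for Megatron's args and mpu helpers; the function name and example sizes are hypothetical.

    # Sketch only, not the Megatron API: reproduce the three shape cases of
    # get_tensor_shapes() with explicit booleans instead of mpu queries.
    def tensor_shapes_for(rank_is_before_split, next_rank_is_before_split,
                          is_encoder_and_decoder,
                          seq_length, decoder_seq_length, micro_batch_size, hidden_size):
        shapes = []
        if is_encoder_and_decoder:
            if rank_is_before_split:
                # Encoder stage: one tensor; layout is transposed once the split is crossed.
                if next_rank_is_before_split:
                    shapes.append((seq_length, micro_batch_size, hidden_size))
                else:
                    shapes.append((micro_batch_size, seq_length, hidden_size))
            else:
                # Decoder stage: decoder activation plus the encoder hidden state.
                shapes.append((decoder_seq_length, micro_batch_size, hidden_size))
                shapes.append((micro_batch_size, seq_length, hidden_size))
        else:
            shapes.append((seq_length, micro_batch_size, hidden_size))
        return shapes

    # A decoder-stage rank of a T5-style model sends/receives two tensors:
    print(tensor_shapes_for(False, False, True, 512, 128, 4, 1024))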
megatron/text_generation_server.py

@@ -39,10 +39,19 @@ class MegatronGenerate(Resource):
         print("request IP: " + str(request.remote_addr))
         print(json.dumps(request.get_json()), flush=True)
         print("current time: ", datetime.datetime.now())

-        sentences = request.get_json()["sentences"]
-        if len(sentences) > 128:
-            return "Maximum number of sentences is 128", 400
+        if not "prompts" in request.get_json():
+            return "prompts argument required", 400
+
+        if "max_len" in request.get_json():
+            return "max_len is no longer used. Replace with tokens_to_generate", 400
+
+        if "sentences" in request.get_json():
+            return "sentences is no longer used. Replace with prompts", 400
+
+        prompts = request.get_json()["prompts"]
+        if len(prompts) > 128:
+            return "Maximum number of prompts is 128", 400

         tokens_to_generate = 64  # Choosing hopefully sane default. Full sequence is slow
         if "tokens_to_generate" in request.get_json():
@@ -52,18 +61,35 @@ class MegatronGenerate(Resource):
             if tokens_to_generate < 1:
                 return "tokens_to_generate must be an integer greater than 0"

-        all_probs = False
-        if "all_probs" in request.get_json():
-            all_probs = request.get_json()["all_probs"]
-            if not isinstance(all_probs, bool):
-                return "all_probs must be a boolean value"
+        logprobs = False
+        if "logprobs" in request.get_json():
+            logprobs = request.get_json()["logprobs"]
+            if not isinstance(logprobs, bool):
+                return "logprobs must be a boolean value"

         temperature = args.temperature
         if "temperature" in request.get_json():
             temperature = request.get_json()["temperature"]
-            if not isinstance(temperature, float) or not \
-                0.0 < temperature <= 100.0:
-                return "temperature must be a positive float less than or equal to 100.0"
+            if not (type(temperature) == int or type(temperature) == float):
+                return "temperature must be a positive number less than or equal to 100.0"
+            if not (0.0 < temperature <= 100.0):
+                return "temperature must be a positive number less than or equal to 100.0"
+
+        top_k = args.top_k
+        if "top_k" in request.get_json():
+            top_k = request.get_json()["top_k"]
+            if not (type(top_k) == int):
+                return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
+            if not (0 < top_k <= 1000):
+                return "top_k must be equal to or greater than 0 and less than or equal to 1000"
+
+        top_p = args.top_p
+        if "top_p" in request.get_json():
+            top_p = request.get_json()["top_p"]
+            if not (type(top_p) == float):
+                return "top_p must be a positive float less than or equal to 1.0"
+            if not (0 < top_p <= 1.0):
+                return "top_p must be less than or equal to 1.0"

         add_BOS = False
         if "add_BOS" in request.get_json():
@@ -73,24 +99,24 @@ class MegatronGenerate(Resource):
         with lock:  # Need to get lock to keep multiple threads from hitting code
             MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-            resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature, add_BOS)
-            if all_probs:
-                return jsonify({"sentences": resp_sentences,
-                                "segments": resp_sentences_seg,
-                                "logits": output_logits,
-                                "all_logits": full_logits,
-                                "tokens": tokens})
-
-            return jsonify({"sentences": resp_sentences,
-                            "segments": resp_sentences_seg,
-                            "logits": output_logits})
+            response, response_seg, response_logprobs = generate(self.model, prompts,
+                                                                 tokens_to_generate,
+                                                                 logprobs,
+                                                                 temperature,
+                                                                 top_k,
+                                                                 top_p,
+                                                                 add_BOS)
+
+        return jsonify({"text": response,
+                        "segments": response_seg,
+                        "logprobs": response_logprobs})

 class MegatronServer(object):
     def __init__(self, model):
         self.app = Flask(__name__, static_url_path='')
         api = Api(self.app)
-        api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])
+        api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])

     def run(self, url):
         self.app.run(url, threaded=True, debug=False)
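For orientation, here is a hypothetical client call against the renamed resource. The JSON field names and the '/api' path come from the handler above; the host, the port, and the use of PUT are assumptions, since the HTTP verb is outside the hunks shown here.

    import json
    import requests   # third-party package, assumed installed

    payload = {
        "prompts": ["Megatron-LM is"],   # at most 128 prompts per request
        "tokens_to_generate": 32,
        "logprobs": True,
        "temperature": 0.9,
        "top_k": 40,
        "top_p": 0.9,
    }
    # Host, port, and HTTP verb are placeholders for illustration only.
    resp = requests.put("http://localhost:5000/api",
                        data=json.dumps(payload),
                        headers={"Content-Type": "application/json"})
    # The response carries "text", "segments", and "logprobs" per the jsonify() call above.
    print(resp.json()["text"])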
megatron/text_generation_utils.py

@@ -108,13 +108,13 @@ def tokenize_batch(sentences, max_len, add_BOS):
     context_length_tensor = torch.cuda.LongTensor(context_lengths)

     return context_tokens_tensor, context_length_tensor

-def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
+def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p):
     """
     Needs to be synced up with receive_generate_info
     """
     # Send the sizes of the tensors
-    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs]
-    input_info_tensor = torch.cuda.LongTensor(input_info)
+    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, logprobs, temperature, top_k, top_p]
+    input_info_tensor = torch.cuda.FloatTensor(input_info)
     torch.distributed.broadcast(input_info_tensor, 0)

     # Send variables to all ranks
@@ -125,12 +125,15 @@ def receive_generate_info():
     """
     Needs to be synced up with send_generate_info
     """
-    input_info_tensor = torch.empty(4, dtype=torch.int64, device=torch.cuda.current_device())
+    input_info_tensor = torch.empty(7, dtype=torch.float32, device=torch.cuda.current_device())
     torch.distributed.broadcast(input_info_tensor, 0)
-    batch_size = input_info_tensor[0].item()
-    seq_len = input_info_tensor[1].item()
-    tokens_to_generate = input_info_tensor[2].item()
-    all_probs = input_info_tensor[3].item()
+    batch_size = int(input_info_tensor[0].item())
+    seq_len = int(input_info_tensor[1].item())
+    tokens_to_generate = int(input_info_tensor[2].item())
+    logprobs = bool(input_info_tensor[3].item())
+    temperature = float(input_info_tensor[4].item())
+    top_k = int(input_info_tensor[5].item())
+    top_p = float(input_info_tensor[6].item())

     context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
     context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
@@ -139,65 +142,52 @@ def receive_generate_info():
     torch.distributed.broadcast(context_length_tensor, 0)
     torch.distributed.broadcast(context_tokens_tensor, 0)

-    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs
+    return context_length_tensor, context_tokens_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p

-def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
+def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
                                                  attention_mask, position_ids,
                                                  tokens_to_generate,
-                                                 all_probs,
-                                                 temperature=temperature)
-    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
+                                                 logprobs,
+                                                 temperature,
+                                                 top_k,
+                                                 top_p)
+    for tokens, lengths, output_logits in batch_token_iterator:
         context_length += 1

-        if mpu.is_pipeline_last_stage():
-            src = mpu.get_pipeline_model_parallel_last_rank()
-            group = mpu.get_embedding_group()
-            print('last rank output size {} {} | \n'.format(output_logits.size(0), output_logits.size(1)))
-            torch.distributed.broadcast(output_logits, src, group)
-            if all_probs:
-                print('last rank full size {} {} | \n'.format(full_logits.size(0), full_logits.size(1), full_logits.size(2)))
-                src = mpu.get_pipeline_model_parallel_last_rank()
-                group = mpu.get_embedding_group()
-                torch.distributed.broadcast(full_logits, src, group)
-        else:
-            if mpu.is_pipeline_first_stage():
-                src = mpu.get_pipeline_model_parallel_last_rank()
-                group = mpu.get_embedding_group()
-                output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
-                print('first rank output size {} {} | \n'.format(output_logits.size(0), output_logits.size(1)))
-                torch.distributed.broadcast(output_logits, src, group)
-                if all_probs:
-                    args = get_args()
-                    src = mpu.get_pipeline_model_parallel_last_rank()
-                    group = mpu.get_embedding_group()
-                    full_logits = torch.empty(tokens.size(0), context_length-1, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
-                    print('first rank full size {} {} | \n'.format(full_logits.size(0), full_logits.size(1), full_logits.size(2)))
-                    torch.distributed.broadcast(full_logits, src, group)
+    if logprobs:
+        if mpu.is_pipeline_last_stage():
+            src = mpu.get_pipeline_model_parallel_last_rank()
+            group = mpu.get_embedding_group()
+            torch.distributed.broadcast(output_logits, src, group)
+        else:
+            if mpu.is_pipeline_first_stage():
+                src = mpu.get_pipeline_model_parallel_last_rank()
+                group = mpu.get_embedding_group()
+                output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
+                torch.distributed.broadcast(output_logits, src, group)

     if tokens is not None:
-        return tokens[:, :context_length], output_logits, full_logits
+        return tokens[:, :context_length], output_logits

-def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0, add_BOS=False):
+def generate(model, sentences=None, tokens_to_generate=0, logprobs=False, temperature=1.0, top_k=0, top_p=0.0, add_BOS=False):
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
-        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
+        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p)
     else:
-        context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
+        context_length_tensor, context_tokens_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p = receive_generate_info()

-    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p)
     if output is not None:
-        decode_tokens, output_logits, full_logits = output
+        decode_tokens, output_logits = output
         args = get_args()
         tokenizer = get_tokenizer()
@@ -205,7 +195,8 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0, add_BOS=False):
         resp_sentences_seg = []

         decode_tokens = decode_tokens.cpu().numpy().tolist()
-        for decode_token in decode_tokens:
+        for i, decode_token in enumerate(decode_tokens):
             resp_sentences.append(tokenizer.detokenize(decode_token))
             words = []
             for token in decode_token:
@@ -213,12 +204,10 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0, add_BOS=False):
                 word = bytearray([tokenizer.tokenizer.byte_decoder[c] for c in word]).decode('utf-8', errors='replace')
                 words.append(word)
             resp_sentences_seg.append(words)

-        output_logits = output_logits.cpu().numpy().tolist()
-        if all_probs:
-            full_logits = full_logits.cpu().numpy()  #.tolist()
-
-        return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens
+        if logprobs:
+            output_logits = output_logits.cpu().numpy().tolist()
+        return resp_sentences, resp_sentences_seg, output_logits

 def generate_samples_eval(model, context, max_gen_length, eos_token_id):
     """
@@ -268,9 +257,17 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids, ...):
     return output_tensor

-def sample_sequence_batch(model, context_tokens, context_lengths,
-                          attention_mask, position_ids,
-                          tokens_to_generate, all_probs=False, type_ids=None, temperature=None):
+def sample_sequence_batch(model,
+                          context_tokens,
+                          context_lengths,
+                          attention_mask,
+                          position_ids,
+                          tokens_to_generate,
+                          logprobs,
+                          temperature,
+                          top_k,
+                          top_p,
+                          type_ids=None):

     args = get_args()
     tokenizer = get_tokenizer()
@@ -340,8 +337,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths, ...):
             else:
                 logits = logits.float()
                 logits /= temperature
-                logits = top_k_logits(logits, top_k=args.top_k,
-                                      top_p=args.top_p)
+                logits = top_k_logits(logits, top_k=top_k,
+                                      top_p=top_p)
                 log_probs = F.softmax(logits, dim=-1)
                 prev = torch.multinomial(log_probs, num_samples=1).view(-1)
             started = context_lengths <= context_length
@@ -353,22 +350,19 @@ def sample_sequence_batch(model, context_tokens, context_lengths, ...):
             new_tokens = switch(
                 tokens[:, context_length].view(-1), prev, started)
             tokens[:, context_length] = new_tokens

-            if output_logits is None:
-                output_context = F.log_softmax(output[:, :context_length, :], 2)
-                indices = torch.unsqueeze(tokens[:, 1:context_length+1], 2)
-                output_logits = torch.gather(output_context, 2, indices).squeeze(2)
-                if all_probs:
-                    full_logits = output_context
-            else:
-                output_context = F.log_softmax(output, 2)
-                indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
-                new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)
-
-                # TODO(rprenger) we're copying output_logits every time. Should pre-allocate
-                output_logits = torch.cat([output_logits, new_output_logits], 1)
-                if all_probs:
-                    full_logits = torch.cat([full_logits, output_context], 1)
+            if logprobs:
+                if output_logits is None:
+                    output_context = F.log_softmax(output[:, :context_length, :], 2)
+                    indices = torch.unsqueeze(tokens[:, 1:context_length+1], 2)
+                    output_logits = torch.gather(output_context, 2, indices).squeeze(2)
+                else:
+                    output_context = F.log_softmax(output, 2)
+                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
+                    new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)
+
+                    # TODO(rprenger) we're copying output_logits every time. Should pre-allocate
+                    output_logits = torch.cat([output_logits, new_output_logits], 1)

             src = mpu.get_pipeline_model_parallel_last_rank()
             group = mpu.get_embedding_group()
@@ -383,10 +377,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, ...):
             src = mpu.get_pipeline_model_parallel_last_rank()
             group = mpu.get_pipeline_model_parallel_group()
             torch.distributed.broadcast(done, src, group)

-            if all_probs:
-                yield tokens, lengths, output_logits, full_logits
-            else:
-                yield tokens, lengths, output_logits, None
+            yield tokens, lengths, output_logits

         else:
             if mpu.is_pipeline_first_stage():
@@ -395,9 +386,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths, ...):
                 new_tokens = torch.empty_like(tokens[:, context_length])
                 torch.distributed.broadcast(new_tokens, src, group)
                 tokens[:, context_length] = new_tokens
-                yield tokens, None, None, None
+                yield tokens, None, None
             else:
-                yield None, None, None, None
+                yield None, None, None

             done = torch.cuda.ByteTensor([0])
             src = mpu.get_pipeline_model_parallel_last_rank()
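A single-process illustration of the packing scheme send_generate_info() and receive_generate_info() now share: all control scalars travel in one float32 tensor so that a single broadcast distributes them, and the receiver casts each slot back to its intended type. Only torch is required; the values are invented for the example.

    import torch

    batch_size, seq_len = 2, 16
    tokens_to_generate, logprobs = 32, True
    temperature, top_k, top_p = 0.9, 40, 0.9

    # Pack everything into one float tensor (booleans and ints survive the round trip
    # for the small magnitudes used here).
    packed = torch.tensor([batch_size, seq_len, tokens_to_generate,
                           logprobs, temperature, top_k, top_p],
                          dtype=torch.float32)
    # On a real group: torch.distributed.broadcast(packed, 0)

    unpacked = dict(batch_size=int(packed[0].item()),
                    seq_len=int(packed[1].item()),
                    tokens_to_generate=int(packed[2].item()),
                    logprobs=bool(packed[3].item()),
                    temperature=float(packed[4].item()),
                    top_k=int(packed[5].item()),
                    top_p=float(packed[6].item()))
    print(unpacked)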
megatron/training.py
View file @
ed6d28b1
...
@@ -38,6 +38,7 @@ from megatron import print_rank_last
...
@@ -38,6 +38,7 @@ from megatron import print_rank_last
from
megatron.checkpointing
import
load_checkpoint
from
megatron.checkpointing
import
load_checkpoint
from
megatron.checkpointing
import
save_checkpoint
from
megatron.checkpointing
import
save_checkpoint
from
megatron.model
import
Float16Module
from
megatron.model
import
Float16Module
from
megatron.model
import
ModelType
from
megatron.optimizer
import
get_megatron_optimizer
from
megatron.optimizer
import
get_megatron_optimizer
from
megatron.initialize
import
initialize_megatron
from
megatron.initialize
import
initialize_megatron
from
megatron.initialize
import
write_args_to_tensorboard
from
megatron.initialize
import
write_args_to_tensorboard
...
@@ -61,6 +62,7 @@ def print_datetime(string):
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             model_type,
             forward_step_func,
             extra_args_provider=None,
             args_defaults={}):
...
@@ -77,6 +79,7 @@ def pretrain(train_valid_test_dataset_provider,
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        model_type: an enum that specifies the type of model being trained.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
...
@@ -109,7 +112,8 @@ def pretrain(train_valid_test_dataset_provider,
    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, model_type)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')
...
@@ -189,13 +193,16 @@ def update_train_iters(args):
    print_rank_0('setting training iterations to {}'.format(args.train_iters))


def get_model(model_provider_func, wrap_with_ddp=True):
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder,
              wrap_with_ddp=True):
    """Build the model."""
    args = get_args()
    args.model_type = model_type

    # Build model.
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
       args.virtual_pipeline_model_parallel_size is not None:
        assert model_type != ModelType.encoder_and_decoder, \
            "Interleaved schedule not supported for model with both encoder and decoder"
        model = []
        for i in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
...
@@ -206,14 +213,36 @@ def get_model(model_provider_func, wrap_with_ddp=True):
                pre_process=pre_process,
                post_process=post_process
            )
            this_model.model_type = model_type
            model.append(this_model)
    else:
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
        model = model_provider_func(
            pre_process=pre_process,
            post_process=post_process
        )
        add_encoder = True
        add_decoder = True
        if model_type == ModelType.encoder_and_decoder:
            if mpu.get_pipeline_model_parallel_world_size() > 1:
                assert args.pipeline_model_parallel_split_rank is not None, \
                    "Split rank needs to be specified for model with both encoder and decoder"
                rank = mpu.get_pipeline_model_parallel_rank()
                split_rank = args.pipeline_model_parallel_split_rank
                world_size = mpu.get_pipeline_model_parallel_world_size()
                pre_process = rank == 0 or rank == split_rank
                post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1))
                add_encoder = mpu.is_pipeline_stage_before_split()
                add_decoder = mpu.is_pipeline_stage_after_split()
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process,
                add_encoder=add_encoder,
                add_decoder=add_decoder)
        else:
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process
            )
        model.model_type = model_type
    if not isinstance(model, list):
        model = [model]
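The split-rank arithmetic above decides, for every pipeline rank of an encoder-and-decoder model, whether that stage owns an input embedding (pre_process), an output head (post_process), and which half of the model it builds. A self-contained sketch of the same logic in plain Python; treating rank < split_rank as the meaning of mpu.is_pipeline_stage_before_split() is an assumption, though it matches the T5 notes further down:

def stage_role(rank, world_size, split_rank):
    # Mirrors the ModelType.encoder_and_decoder branch above.
    pre_process = rank == 0 or rank == split_rank
    post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1))
    add_encoder = rank < split_rank      # assumed equivalent of mpu.is_pipeline_stage_before_split()
    add_decoder = rank >= split_rank     # assumed equivalent of mpu.is_pipeline_stage_after_split()
    return pre_process, post_process, add_encoder, add_decoder

# With 4 pipeline stages split at rank 2, ranks 0-1 build the encoder and ranks 2-3 the decoder.
for rank in range(4):
    print(rank, stage_role(rank, world_size=4, split_rank=2))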
...
@@ -306,11 +335,11 @@ def get_learning_rate_scheduler(optimizer):
    return lr_scheduler


def setup_model_and_optimizer(model_provider_func):
def setup_model_and_optimizer(model_provider_func, model_type):
    """Setup model and optimizer."""
    args = get_args()

    model = get_model(model_provider_func)
    model = get_model(model_provider_func, model_type)
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
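unwrap_model strips wrapper modules (DDP, fp16) off the model before the optimizer and checkpoint code look at it. A standalone sketch of that idea; Fp16Wrapper below is a hypothetical stand-in, not Megatron's Float16Module:

import torch.nn as nn

class Fp16Wrapper(nn.Module):
    # Hypothetical wrapper that stores the real model under .module,
    # the way DDP-style wrappers do.
    def __init__(self, module):
        super().__init__()
        self.module = module

def unwrap(model, wrapper_types):
    # Keep following .module until the object is no longer a known wrapper.
    while isinstance(model, wrapper_types):
        model = model.module
    return model

base = nn.Linear(4, 4)
wrapped = Fp16Wrapper(base)
assert unwrap(wrapped, (Fp16Wrapper,)) is base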
...
@@ -379,13 +408,14 @@ def train_step(forward_step_func, data_iterator,
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
    if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
            mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
        else:
            # We do not support the interleaved schedule for T5 yet.
            unwrapped_model = model[0]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
...
pretrain_bert.py
View file @
ed6d28b1
...
@@ -25,7 +25,7 @@ from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel
from megatron.model import BertModel, ModelType
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
...
@@ -143,5 +143,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    pretrain(train_valid_test_datasets_provider, model_provider,
             ModelType.encoder_or_decoder, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
pretrain_gpt.py
View file @
ed6d28b1
...
@@ -23,7 +23,7 @@ from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel
from megatron.model import GPTModel, ModelType
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
...
@@ -121,5 +121,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
    pretrain(train_valid_test_datasets_provider, model_provider,
             ModelType.encoder_or_decoder, forward_step,
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
pretrain_ict.py
View file @
ed6d28b1
...
@@ -28,6 +28,7 @@ from megatron import get_timers
from megatron import mpu
from megatron.data.biencoder_dataset_utils import get_ict_batch
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import ModelType
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
...
@@ -174,5 +175,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider,
             pretrain_ict_model_provider,
             ModelType.encoder_or_decoder,
             forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
pretrain_t5.py
View file @
ed6d28b1
...
@@ -26,18 +26,58 @@ from megatron import (
    print_rank_0
)
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import T5Model
from megatron.model import T5Model, ModelType
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""
Pipeline parallelism for T5
===========================
T5 is a model architecture with both encoder and decoder blocks.
Consequently, pipeline parallelism is implemented slightly differently
compared to architectures like GPT and BERT.
In particular, when pipeline_model_parallel_world_size > 1, each stage
either executes an encoder block or a decoder block. The
--pipeline-model-parallel-split-rank argument controls the rank at which
the split happens: all ranks lower than this argument execute the
encoder block, and all ranks equal to or higher than this argument value
execute the decoder block.
In the encoder section of the model, only one tensor is sent downstream:
the intermediate encoder_hidden_state. In the decoder section of the
model, two tensors are sent downstream in the forward pass: the fully
computed encoder_hidden_state, and the intermediate decoder_hidden_state.
In particular, these are the shapes of the tensors sent between
different workers:
If rank is in decoder section:
intermediate decoder_hidden_state (pre-transpose),
complete encoder_hidden_state (post-transpose).
If rank is at boundary between encoder and decoder sections:
complete encoder_hidden_state (post-transpose).
If rank is in encoder section:
intermediate encoder_hidden_state (pre-transpose).
Additionally, we have code in the backward_step function in schedules.py
to accumulate the encoder_hidden_state gradient across skip connections
(encoder_hidden_state fed in as input to each layer in the decoder).
"""
def model_provider(pre_process=True, post_process=True,
                   add_encoder=True, add_decoder=True):
    """Build the model."""
    assert pre_process and post_process, "T5 doesn't yet support pipelining"
    print_rank_0('building T5 model ...')
    model = T5Model(num_tokentypes=0,
                    parallel_output=True)
    model = T5Model(num_tokentypes=0,
                    parallel_output=True,
                    pre_process=pre_process,
                    post_process=post_process,
                    add_encoder=add_encoder,
                    add_decoder=add_decoder)
    return model
...
@@ -70,9 +110,7 @@ def get_batch(data_iterator):
def loss_func(loss_mask, output_tensor):
    lm_loss_, _ = output_tensor
    lm_loss_ = output_tensor.float()
    lm_loss_ = lm_loss_.float()
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
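The loss above is a masked average: per-token losses are summed only where loss_mask is set and divided by the number of unmasked positions. A tiny numeric check of the same formula:

import torch

lm_loss_ = torch.tensor([[0.5, 1.0, 2.0],
                         [1.5, 0.5, 3.0]])
loss_mask = torch.tensor([[1., 1., 0.],
                          [1., 0., 0.]])
lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
print(lm_loss)   # (0.5 + 1.0 + 1.5) / 3 = 1.0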
...
@@ -130,5 +168,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    pretrain(train_valid_test_datasets_provider, model_provider,
             ModelType.encoder_and_decoder,
             forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
pretrain_vit.py
View file @
ed6d28b1
...
@@ -20,6 +20,7 @@ import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model import ModelType
from megatron.model.vit_model import VitModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
...
@@ -92,6 +93,7 @@ if __name__ == "__main__":
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'dataloader_type': 'cyclic'}
    )
requirements.txt
deleted
100644 → 0
View file @
8c119d80
pybind11
torch
six
regex
numpy
setup.py
deleted
100644 → 0
View file @
8c119d80
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Setup for pip package."""
import os
import sys
import setuptools

if sys.version_info < (3,):
    raise Exception("Python 2 is not supported by Megatron.")

from megatron.package_info import (
    __description__,
    __contact_names__,
    __url__,
    __download_url__,
    __keywords__,
    __license__,
    __package_name__,
    __version__,
)

with open("README.md", "r") as fh:
    long_description = fh.read()
###############################################################################
# Dependency Loading #
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
def req_file(filename):
    with open(filename) as f:
        content = f.readlines()
    return [x.strip() for x in content]


install_requires = req_file("requirements.txt")

setuptools.setup(
    name=__package_name__,
    # Versions should comply with PEP440. For a discussion on single-sourcing
    # the version across setup.py and the project code, see
    # https://packaging.python.org/en/latest/single_source_version.html
    version=__version__,
    description=__description__,
    long_description=long_description,
    long_description_content_type="text/markdown",
    # The project's main homepage.
    url=__url__,
    author=__contact_names__,
    maintainer=__contact_names__,
    # The licence under which the project is released
    license=__license__,
    classifiers=[
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'Intended Audience :: Information Technology',
        # Indicate what your project relates to
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Software Development :: Libraries :: Python Modules',
        # Supported python versions
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        # Additional Setting
        'Environment :: Console',
        'Natural Language :: English',
        'Operating System :: OS Independent',
    ],
    python_requires='>=3.6',
    packages=setuptools.find_packages(),
    install_requires=install_requires,
    # Add in any packaged data.
    include_package_data=True,
    zip_safe=False,
    # PyPI package information.
    keywords=__keywords__
)
tools/text_generation_cli.py
View file @
ed6d28b1
...
@@ -25,10 +25,10 @@ if __name__ == "__main__":
    url = sys.argv[1]
    while True:
        sentence = raw_input("Enter prompt: ")
        max_len = int(input("Enter number tokens output: "))
        tokens_to_generate = int(input("Enter number of tokens to generate: "))
        data = json.dumps({"sentences": [sentence], "max_len": max_len})
        data = json.dumps({"prompts": [sentence], "tokens_to_generate": tokens_to_generate})
        req = PutRequest(url, data, {'Content-Type': 'application/json'})
        response = urllib2.urlopen(req)
        resp_sentences = json.load(response)
        print("Megatron Response: ")
        print(resp_sentences["sentences"][0])
        print(resp_sentences["text"][0])
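For reference, a rough Python 3 equivalent of the request the updated CLI sends (the script above is Python 2: raw_input and urllib2). The address and endpoint below are assumptions; point it at wherever the text generation server is actually listening:

import json
from urllib import request

url = "http://localhost:5000/api"        # assumed server address and endpoint
payload = json.dumps({"prompts": ["The quick brown fox"],
                      "tokens_to_generate": 32}).encode()
req = request.Request(url, data=payload,
                      headers={"Content-Type": "application/json"},
                      method="PUT")
with request.urlopen(req) as response:
    resp = json.load(response)
print(resp["text"][0])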