Unverified Commit c6ea6501 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[tutorial] fixed pipeline bug for sequence parallel (#1943)

parent e52f9d91
......@@ -35,6 +35,17 @@ def parse_args():
return parser.parse_args()
def pipeline_data_process_func(stage_output, micro_batch_data):
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
if gpc.is_first_rank(ParallelMode.PIPELINE):
data = (tokens, padding_mask, types, lm_labels)
label = (loss_mask, sentence_order)
else:
data = (stage_output, padding_mask, types, lm_labels)
label = (loss_mask, sentence_order)
return data, label
def main():
# initialize
args = parse_args()
......@@ -155,6 +166,7 @@ def main():
if use_pipeline:
train_data_iter = SequenceParallelDataIterator(trainloader)
valid_data_iter = SequenceParallelDataIterator(validloader)
engine.schedule.data_process_func = pipeline_data_process_func
logger.info("start training")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment