Commit ea128da5 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

t5_pipeline_fix

parent 3ae12a47
......@@ -177,7 +177,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]]
if pipeline_model_parallel_split_rank_ is not None and \
pipeline_model_parallel_split_rank_ not in embedding_ranks:
ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank_],
ranks[-1]]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment