Commit c552cefa authored by Jianghai's avatar Jianghai Committed by Hongxin Liu
Browse files

[pipeline]add pipeline policy and bert forward (#4130)

* add pipeline policy and bert forward to be done

* add bertmodel pipeline forward and make tests

* add Bert_Policy and test for policy

* update formatting

* update formatting

* update the code

* fix bugs

* fix name confilt
parent 5c897ddb
...@@ -21,7 +21,7 @@ def check_stage_manager(): ...@@ -21,7 +21,7 @@ def check_stage_manager():
1: [0, 1], 1: [0, 1],
2: [2, 3], 2: [2, 3],
3: [2, 3], 3: [2, 3],
} }
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM) stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank() rank = dist.get_rank()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment