Unverified Commit c008d4ad authored by Michelle's avatar Michelle Committed by GitHub
Browse files

[NFC] polish colossalai/engine/schedule/_pipeline_schedule.py code style (#2744)

parent 58abde28
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
import inspect import inspect
from typing import Callable, List, Tuple, Union from typing import Callable, List, Tuple, Union
import colossalai.communication as comm
import torch.cuda import torch.cuda
import colossalai.communication as comm
from colossalai.amp.naive_amp import NaiveAMPModel from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
...@@ -157,6 +158,7 @@ class PipelineSchedule(BaseSchedule): ...@@ -157,6 +158,7 @@ class PipelineSchedule(BaseSchedule):
def pre_processing(self, engine): def pre_processing(self, engine):
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
# TODO: remove this after testing new zero with pipeline parallelism # TODO: remove this after testing new zero with pipeline parallelism
model = engine.model model = engine.model
if isinstance(model, NaiveAMPModel): if isinstance(model, NaiveAMPModel):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment