"...git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "19e1a5cf16ead982eb8818cd69e41b06a5d23b20"
Unverified Commit 0ed7042f authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[pipeline] refactor pipeline (#679)

* refactor pipeline---put runtime schedule into engine.

* add type hint for schedule Optional[BaseSchedule]

* preprocess schedule during engine initializing

* infer pipeline schedule params from config
parent eace6938
from ._base_schedule import BaseSchedule
from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule
from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
from ._non_pipeline_schedule import NonPipelineSchedule
__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule']
__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape']
......@@ -16,6 +16,29 @@ from colossalai.zero.sharded_model import ShardedModelV2
from ._base_schedule import BaseSchedule
def get_tensor_shape():
if hasattr(gpc.config, 'TENSOR_SHAPE'):
return gpc.config.TENSOR_SHAPE
if not gpc.is_initialized(ParallelMode.PIPELINE):
return None
if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'):
if gpc.is_initialized(ParallelMode.DATA):
dp_size = gpc.get_world_size(ParallelMode.DATA)
else:
dp_size = 1
if gpc.is_initialized(ParallelMode.SEQUENCE):
seq_size = gpc.get_world_size(ParallelMode.SEQUENCE)
else:
seq_size = 1
tensor_shape = (gpc.config.SEQ_LENGTH // seq_size,
gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES,
gpc.config.HIDDEN_SIZE)
return tensor_shape
else:
return None
def pack_return_tensors(return_tensors):
output, label = tuple(zip(*return_tensors))
......
......@@ -20,7 +20,7 @@ from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.engine import Engine
......@@ -391,14 +391,18 @@ def initialize(model: nn.Module,
# initialize schedule for engine
if is_using_pp():
tensor_shape = getattr(gpc.config, 'TENSOR_SHAPE', None)
tensor_shape = get_tensor_shape()
use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
if gpc.is_initialized(ParallelMode.PARALLEL_1D):
scatter_gather = True
else:
scatter_gather = False
if use_interleaved:
schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=True)
gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
else:
schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
tensor_shape=tensor_shape, scatter_gather_tensors=True)
tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
else:
schedule = NonPipelineSchedule()
......
......@@ -312,7 +312,7 @@ class AccuracyHook(MetricHook):
def after_test_iter(self, trainer, logits, targets, *args):
if self._is_stage_to_compute:
batch_size = trainer.schedule.batch_size
batch_size = trainer.engine.schedule.batch_size
self.metric.update(logits, targets, batch_size)
......@@ -392,7 +392,7 @@ class ThroughputHook(MetricHook):
def after_train_iter(self, trainer, *args):
if self._is_stage_to_compute:
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
def before_test(self, trainer):
if self._is_stage_to_compute:
......@@ -400,4 +400,4 @@ class ThroughputHook(MetricHook):
def after_test_iter(self, trainer, *args):
if self._is_stage_to_compute:
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment