Unverified commit 708404d5, authored by ver217 and committed by GitHub
Browse files

fix pipeline forward return tensors (#176)

parent 6fb550ac
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
from typing import List, Tuple, Union, Callable
import inspect import inspect
import torch.cuda from typing import Callable, List, Tuple, Union
import colossalai.communication as comm import colossalai.communication as comm
import torch.cuda
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.amp.naive_amp import NaiveAMPModel from colossalai.logging import get_dist_logger
from colossalai.utils import switch_virtual_pipeline_parallel_rank
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2, from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3) ZeroRedundancyOptimizer_Level_3)
from colossalai.utils import switch_virtual_pipeline_parallel_rank
from colossalai.logging import get_dist_logger
from ._base_schedule import BaseSchedule from ._base_schedule import BaseSchedule
...@@ -151,7 +152,7 @@ class PipelineSchedule(BaseSchedule): ...@@ -151,7 +152,7 @@ class PipelineSchedule(BaseSchedule):
if gpc.is_last_rank(ParallelMode.PIPELINE): if gpc.is_last_rank(ParallelMode.PIPELINE):
if return_output_label: if return_output_label:
return_tensors.append(tuple((output_tensor, label))) return_tensors.append((output_tensor, label))
if accum_loss is not None: if accum_loss is not None:
loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
accum_loss.add_(loss_reduced.detach()) accum_loss.add_(loss_reduced.detach())
...@@ -414,7 +415,7 @@ class InterleavedPipelineSchedule(PipelineSchedule): ...@@ -414,7 +415,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
if gpc.is_pipeline_last_stage(): if gpc.is_pipeline_last_stage():
if return_output_label: if return_output_label:
return_tensors.append(tuple(output_tensor, label)) return_tensors.append((output_tensor, label))
if accum_loss is not None: if accum_loss is not None:
loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
accum_loss.add_(loss_reduced.detach()) accum_loss.add_(loss_reduced.detach())
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment