Unverified Commit bb0a668f authored by Wenhao Chen's avatar Wenhao Chen Committed by GitHub
Browse files

[hotfix] set return_outputs=False in examples and polish code (#5404)

* fix: simplify merge_batch

* fix: use return_outputs=False to eliminate extra memory consumption

* feat: add return_outputs warning

* style: remove `return_outputs=False` as it is the default value
parent 5fcd7795
...@@ -238,7 +238,6 @@ def main(): ...@@ -238,7 +238,6 @@ def main():
lambda x, y: x.loss, lambda x, y: x.loss,
optimizer, optimizer,
return_loss=True, return_loss=True,
return_outputs=True,
) )
# Backward and optimize # Backward and optimize
if is_pp_last_stage: if is_pp_last_stage:
......
...@@ -1183,6 +1183,9 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -1183,6 +1183,9 @@ class HybridParallelPlugin(PipelinePluginBase):
) -> dict: ) -> dict:
assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled" assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
if return_outputs:
warnings.warn("return_outputs may lead to significant extra memory consumption.")
# Create a context for gradient synchronization based on the optimizer type. # Create a context for gradient synchronization based on the optimizer type.
# If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync(). # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
# This is to avoid redundant gradient reduction in pipeline parallelism (multiple microbatch values should be reduced once), # This is to avoid redundant gradient reduction in pipeline parallelism (multiple microbatch values should be reduced once),
......
...@@ -7,7 +7,7 @@ from torch.nn import Module ...@@ -7,7 +7,7 @@ from torch.nn import Module
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
...@@ -327,9 +327,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -327,9 +327,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.send_forward(output_obj) self.send_forward(output_obj)
if outputs is not None: if outputs is not None:
if isinstance(model, ModelWrapper): outputs = merge_batch(outputs)
model = model.unwrap()
outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
return {"loss": accum_loss, "outputs": outputs} return {"loss": accum_loss, "outputs": outputs}
def run_forward_backward( def run_forward_backward(
...@@ -412,9 +410,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -412,9 +410,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs) assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None: if outputs is not None:
if isinstance(model, ModelWrapper): outputs = merge_batch(outputs)
model = model.unwrap()
outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
return {"loss": accum_loss, "outputs": outputs} return {"loss": accum_loss, "outputs": outputs}
def forward_backward_step( def forward_backward_step(
......
...@@ -178,7 +178,7 @@ def train_epoch( ...@@ -178,7 +178,7 @@ def train_epoch(
for _ in pbar: for _ in pbar:
if use_pipeline: if use_pipeline:
outputs = booster.execute_pipeline( outputs = booster.execute_pipeline(
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True train_dataloader_iter, model, _criterion, optimizer, return_loss=True
) )
# Backward and optimize # Backward and optimize
if is_pp_last_stage: if is_pp_last_stage:
......
...@@ -231,7 +231,7 @@ def run_forward_backward( ...@@ -231,7 +231,7 @@ def run_forward_backward(
if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
# run pipeline forward backward when enabling pp in hybrid parallel plugin # run pipeline forward backward when enabling pp in hybrid parallel plugin
output_dict = booster.execute_pipeline( output_dict = booster.execute_pipeline(
data_iter, model, criterion, optimizer, return_loss=True, return_outputs=True data_iter, model, criterion, optimizer, return_loss=True
) )
loss, outputs = output_dict["loss"], output_dict["outputs"] loss, outputs = output_dict["loss"], output_dict["outputs"]
else: else:
......
...@@ -198,8 +198,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: ...@@ -198,8 +198,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion:
model, model,
_criterion, _criterion,
optimizer, optimizer,
return_loss=True, return_loss=True)
return_outputs=True)
# Backward and optimize # Backward and optimize
if is_pp_last_stage: if is_pp_last_stage:
loss = outputs['loss'] loss = outputs['loss']
......
...@@ -271,7 +271,7 @@ However, if pipeline parallel is enabled, there are several usages different fro ...@@ -271,7 +271,7 @@ However, if pipeline parallel is enabled, there are several usages different fro
3. Do forward and backward passing through calling `Booster.execute_pipeline` method: 3. Do forward and backward passing through calling `Booster.execute_pipeline` method:
```python ```python
outputs = booster.execute_pipeline( outputs = booster.execute_pipeline(
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True train_dataloader_iter, model, _criterion, optimizer, return_loss=True
) )
``` ```
Backward passing has been completed by this method, so there is no need to call `loss.backward()` after executing this method. Backward passing has been completed by this method, so there is no need to call `loss.backward()` after executing this method.
......
...@@ -175,7 +175,7 @@ def train_epoch( ...@@ -175,7 +175,7 @@ def train_epoch(
for _ in pbar: for _ in pbar:
if use_pipeline: if use_pipeline:
outputs = booster.execute_pipeline( outputs = booster.execute_pipeline(
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True train_dataloader_iter, model, _criterion, optimizer, return_loss=True
) )
# Backward and optimize # Backward and optimize
if is_pp_last_stage: if is_pp_last_stage:
......
...@@ -234,7 +234,7 @@ def run_forward_backward( ...@@ -234,7 +234,7 @@ def run_forward_backward(
if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
# run pipeline forward backward when enabling pp in hybrid parallel plugin # run pipeline forward backward when enabling pp in hybrid parallel plugin
output_dict = booster.execute_pipeline( output_dict = booster.execute_pipeline(
data_iter, model, criterion, optimizer, return_loss=True, return_outputs=True data_iter, model, criterion, optimizer, return_loss=True
) )
loss, outputs = output_dict["loss"], output_dict["outputs"] loss, outputs = output_dict["loss"], output_dict["outputs"]
else: else:
......
...@@ -193,8 +193,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: ...@@ -193,8 +193,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion:
model, model,
_criterion, _criterion,
optimizer, optimizer,
return_loss=True, return_loss=True)
return_outputs=True)
# Backward and optimize # Backward and optimize
if is_pp_last_stage: if is_pp_last_stage:
loss = outputs['loss'] loss = outputs['loss']
......
...@@ -264,7 +264,7 @@ elif args.plugin == "hybrid_parallel": ...@@ -264,7 +264,7 @@ elif args.plugin == "hybrid_parallel":
3. 通过调用`Booster.execute_pipeline` 方法来执行前向和后向传递: 3. 通过调用`Booster.execute_pipeline` 方法来执行前向和后向传递:
```python ```python
outputs = booster.execute_pipeline( outputs = booster.execute_pipeline(
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True train_dataloader_iter, model, _criterion, optimizer, return_loss=True
) )
``` ```
该方法会自动执行后向传递,所以在执行该方法后不需要再调用 `loss.backward()`方法。 该方法会自动执行后向传递,所以在执行该方法后不需要再调用 `loss.backward()`方法。
......
...@@ -120,7 +120,7 @@ def main(): ...@@ -120,7 +120,7 @@ def main():
# run pipeline forward backward # run pipeline forward backward
batch = iter([batch]) batch = iter([batch])
outputs = booster.execute_pipeline( outputs = booster.execute_pipeline(
batch, model, criterion, optimizer, return_loss=True, return_outputs=True batch, model, criterion, optimizer, return_loss=True
) )
else: else:
outputs = model(**batch) outputs = model(**batch)
......
...@@ -148,7 +148,7 @@ def train_epoch( ...@@ -148,7 +148,7 @@ def train_epoch(
for _ in pbar: for _ in pbar:
if use_pipeline: if use_pipeline:
outputs = booster.execute_pipeline( outputs = booster.execute_pipeline(
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True train_dataloader_iter, model, _criterion, optimizer, return_loss=True
) )
# Backward and optimize # Backward and optimize
if is_pp_last_device: if is_pp_last_device:
......
...@@ -145,7 +145,7 @@ def train_epoch( ...@@ -145,7 +145,7 @@ def train_epoch(
for _ in pbar: for _ in pbar:
if use_pipeline: if use_pipeline:
outputs = booster.execute_pipeline( outputs = booster.execute_pipeline(
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True train_dataloader_iter, model, _criterion, optimizer, return_loss=True
) )
# Backward and optimize # Backward and optimize
if is_pp_last_stage: if is_pp_last_stage:
......
...@@ -271,7 +271,7 @@ def main(): ...@@ -271,7 +271,7 @@ def main():
for step in pbar: for step in pbar:
if use_pipeline: if use_pipeline:
outputs = booster.execute_pipeline( outputs = booster.execute_pipeline(
dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True dataloader_iter, model, _criterion, optimizer, return_loss=True
) )
loss = outputs["loss"] loss = outputs["loss"]
else: else:
......
...@@ -185,7 +185,7 @@ def main(): ...@@ -185,7 +185,7 @@ def main():
microbatch_size=1, microbatch_size=1,
enable_jit_fused=False, enable_jit_fused=False,
zero_stage=0, zero_stage=0,
precision="fp32", precision=args.mixed_precision,
initial_scale=1, initial_scale=1,
) )
else: else:
...@@ -286,7 +286,7 @@ def main(): ...@@ -286,7 +286,7 @@ def main():
for step in pbar: for step in pbar:
if use_pipeline: if use_pipeline:
outputs = booster.execute_pipeline( outputs = booster.execute_pipeline(
dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True dataloader_iter, model, _criterion, optimizer, return_loss=True
) )
loss = outputs["loss"] loss = outputs["loss"]
else: else:
......
...@@ -270,7 +270,6 @@ def main(): ...@@ -270,7 +270,6 @@ def main():
lambda x, y: x.loss, lambda x, y: x.loss,
optimizer, optimizer,
return_loss=True, return_loss=True,
return_outputs=True,
) )
# Backward and optimize # Backward and optimize
if is_pp_last_stage: if is_pp_last_stage:
......
...@@ -340,7 +340,6 @@ def main(): ...@@ -340,7 +340,6 @@ def main():
lambda x, y: x.loss, lambda x, y: x.loss,
optimizer, optimizer,
return_loss=True, return_loss=True,
return_outputs=True,
) )
# Backward and optimize # Backward and optimize
if is_pp_last_stage: if is_pp_last_stage:
......
...@@ -42,7 +42,7 @@ def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, b ...@@ -42,7 +42,7 @@ def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, b
for _ in pbar: for _ in pbar:
if use_pipeline: if use_pipeline:
outputs = booster.execute_pipeline( outputs = booster.execute_pipeline(
dataloader, model, _criterion, optimizer, return_loss=True, return_outputs=True dataloader, model, _criterion, optimizer, return_loss=True
) )
# Backward and optimize # Backward and optimize
if is_pp_last_stage: if is_pp_last_stage:
......
...@@ -74,7 +74,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[ ...@@ -74,7 +74,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
loss = criterion(outputs[output_key]) loss = criterion(outputs[output_key])
return loss return loss
booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True, return_outputs=False) booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True)
optimizer.step() optimizer.step()
except Exception as e: except Exception as e:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment