Unverified Commit 660eed91 authored by Baizhou Zhang's avatar Baizhou Zhang Committed by GitHub
Browse files

[pipeline] set optimizer to optional in execute_pipeline (#4630)

* set optimizer to optional in execute_pipeline

* arrange device and mixed precision in booster init

* fix execute_pipeline in booster.py
parent c3d5fa3b
...@@ -49,7 +49,9 @@ class Booster: ...@@ -49,7 +49,9 @@ class Booster:
``` ```
Args: Args:
device (str or torch.device): The device to run the training. Default: 'cuda'. device (str or torch.device): The device to run the training. Default: None.
If plugin is not used or plugin doesn't control the device,
this argument will be set as training device ('cuda' will be used if argument is None).
mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None. mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None.
If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'. If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
'fp16' would use PyTorch AMP while `fp16_apex` would use Nvidia Apex. 'fp16' would use PyTorch AMP while `fp16_apex` would use Nvidia Apex.
...@@ -57,7 +59,7 @@ class Booster: ...@@ -57,7 +59,7 @@ class Booster:
""" """
def __init__(self, def __init__(self,
device: str = 'cuda', device: Optional[str] = None,
mixed_precision: Union[MixedPrecision, str] = None, mixed_precision: Union[MixedPrecision, str] = None,
plugin: Optional[Plugin] = None) -> None: plugin: Optional[Plugin] = None) -> None:
if plugin is not None: if plugin is not None:
...@@ -68,13 +70,16 @@ class Booster: ...@@ -68,13 +70,16 @@ class Booster:
# set accelerator # set accelerator
if self.plugin and self.plugin.control_device(): if self.plugin and self.plugin.control_device():
self.accelerator = None self.accelerator = None
warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.') if device is not None:
warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
else: else:
device = device or 'cuda'
self.accelerator = Accelerator(device) self.accelerator = Accelerator(device)
# set precision # set precision
if self.plugin and self.plugin.control_precision(): if self.plugin and self.plugin.control_precision():
warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.') if mixed_precision is not None:
warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
self.mixed_precision = None self.mixed_precision = None
elif mixed_precision is None: elif mixed_precision is None:
self.mixed_precision = None self.mixed_precision = None
...@@ -146,7 +151,7 @@ class Booster: ...@@ -146,7 +151,7 @@ class Booster:
data_iter: Iterator, data_iter: Iterator,
model: nn.Module, model: nn.Module,
criterion: Callable[[Any, Any], torch.Tensor], criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Optimizer, optimizer: Optional[Optimizer] = None,
return_loss: bool = True, return_loss: bool = True,
return_outputs: bool = False) -> dict: return_outputs: bool = False) -> dict:
# run pipeline forward backward pass # run pipeline forward backward pass
......
...@@ -443,15 +443,15 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -443,15 +443,15 @@ class HybridParallelPlugin(PipelinePluginBase):
data_iter: Iterator, data_iter: Iterator,
model: HybridParallelModule, model: HybridParallelModule,
criterion: Callable[[Any, Any], torch.Tensor], criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, optimizer: Optional[Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
HybridParallelZeroOptimizer], HybridParallelZeroOptimizer]] = None,
return_loss: bool = True, return_loss: bool = True,
return_outputs: bool = False) -> dict: return_outputs: bool = False) -> dict:
assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled' assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
# return loss or outputs if needed # return loss or outputs if needed
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx: with ctx:
outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss, outputs = self.schedule.forward_backward_step(model, data_iter, criterion, optimizer, return_loss,
return_outputs) return_outputs)
model.sync_shared_params() model.sync_shared_params()
if isinstance(optimizer, HybridParallelZeroOptimizer): if isinstance(optimizer, HybridParallelZeroOptimizer):
......
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Callable, Iterator from typing import Any, Callable, Iterator, Optional
import torch import torch
...@@ -15,7 +15,7 @@ class PipelinePluginBase(Plugin): ...@@ -15,7 +15,7 @@ class PipelinePluginBase(Plugin):
data_iter: Iterator, data_iter: Iterator,
model: ModelWrapper, model: ModelWrapper,
criterion: Callable[[Any, Any], torch.Tensor], criterion: Callable[[Any, Any], torch.Tensor],
optimizer: OptimizerWrapper, optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = True, return_loss: bool = True,
return_outputs: bool = False) -> dict: return_outputs: bool = False) -> dict:
pass pass
from typing import Any, Callable, Iterable from typing import Any, Callable, Iterable, Optional
from torch import Tensor from torch import Tensor
from torch.nn import Module from torch.nn import Module
...@@ -14,18 +14,18 @@ class PipelineSchedule: ...@@ -14,18 +14,18 @@ class PipelineSchedule:
def forward_backward_step(self, def forward_backward_step(self,
model: Module, model: Module,
optimizer: OptimizerWrapper,
data_iter: Iterable, data_iter: Iterable,
criterion: Callable[[Any, Any], Tensor], criterion: Callable[[Any, Any], Tensor],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False, return_loss: bool = False,
return_outputs: bool = False) -> dict: return_outputs: bool = False) -> dict:
"""Forward and backward step for pipeline training. """Forward and backward step for pipeline training.
Args: Args:
model (Module): Model to be trained. model (Module): Model to be trained.
optimizer (OptimizerWrapper): Optimizer to be used.
data_iter (Iterable): Data iterator. data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor. criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss. return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs. return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
......
...@@ -237,18 +237,18 @@ class InterleavedSchedule(PipelineSchedule): ...@@ -237,18 +237,18 @@ class InterleavedSchedule(PipelineSchedule):
def forward_backward_step(self, def forward_backward_step(self,
model_chunk: Module, model_chunk: Module,
optimizer: OptimizerWrapper,
data_iter: Iterable, data_iter: Iterable,
criterion: Callable[..., Any], criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False, return_loss: bool = False,
return_outputs: bool = False) -> dict: return_outputs: bool = False) -> dict:
"""Runs interleaved 1F1B schedule, with communication between pipeline stages. """Runs interleaved 1F1B schedule, with communication between pipeline stages.
Args: Args:
model_chunk (List[Module]): Model Chunk to be trained. model_chunk (List[Module]): Model Chunk to be trained.
optimizer (OptimizerWrapper): Optimizer to be used.
data_iter (Iterable): Data iterator. data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor. criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss. return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs. return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
...@@ -256,6 +256,8 @@ class InterleavedSchedule(PipelineSchedule): ...@@ -256,6 +256,8 @@ class InterleavedSchedule(PipelineSchedule):
dict: A dict with keys: 'loss' and 'outputs'. dict: A dict with keys: 'loss' and 'outputs'.
""" """
forward_only = not torch.is_grad_enabled() forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert forward_only, "Optimizer should be passed when doing backward."
self.load_batch(data_iter) self.load_batch(data_iter)
num_model_chunks = len(model_chunk) num_model_chunks = len(model_chunk)
......
...@@ -210,18 +210,18 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -210,18 +210,18 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
def forward_backward_step(self, def forward_backward_step(self,
model: Module, model: Module,
optimizer: OptimizerWrapper,
data_iter: Iterable, data_iter: Iterable,
criterion: Callable[..., Any], criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False, return_loss: bool = False,
return_outputs: bool = False) -> dict: return_outputs: bool = False) -> dict:
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages. """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Args: Args:
model (Module): Model to be trained. model (Module): Model to be trained.
optimizer (OptimizerWrapper): Optimizer to be used.
data_iter (Iterable): Data iterator. data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor. criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss. return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs. return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
...@@ -229,6 +229,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -229,6 +229,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
dict: A dict with keys: 'loss' and 'outputs'. dict: A dict with keys: 'loss' and 'outputs'.
""" """
forward_only = not torch.is_grad_enabled() forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert forward_only, "Optimizer should be passed when doing backward."
self.load_batch(data_iter) self.load_batch(data_iter)
......
...@@ -46,7 +46,6 @@ def move_to_cuda(batch): ...@@ -46,7 +46,6 @@ def move_to_cuda(batch):
@torch.no_grad() @torch.no_grad()
def evaluate_model( def evaluate_model(
model: nn.Module, model: nn.Module,
optimizer,
criterion, criterion,
test_dataloader: Union[DataLoader, List[DataLoader]], test_dataloader: Union[DataLoader, List[DataLoader]],
num_labels: int, num_labels: int,
...@@ -71,12 +70,7 @@ def evaluate_model( ...@@ -71,12 +70,7 @@ def evaluate_model(
current_rank = dist.get_rank() current_rank = dist.get_rank()
#TODO pass dataloader to execute_pipeline directly #TODO pass dataloader to execute_pipeline directly
batch = iter([batch]) batch = iter([batch])
outputs = booster.execute_pipeline(batch, outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True)
model,
criterion,
optimizer,
return_loss=True,
return_outputs=True)
if booster.plugin.stage_manager.is_last_stage(): if booster.plugin.stage_manager.is_last_stage():
val_loss = outputs["loss"] val_loss = outputs["loss"]
...@@ -304,7 +298,7 @@ def main(): ...@@ -304,7 +298,7 @@ def main():
for epoch in range(NUM_EPOCHS): for epoch in range(NUM_EPOCHS):
train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
results = evaluate_model(model, optimizer, _criterion, test_dataloader, data_builder.num_labels, args.task, results = evaluate_model(model, _criterion, test_dataloader, data_builder.num_labels, args.task,
data_builder.eval_splits, booster, coordinator) data_builder.eval_splits, booster, coordinator)
if coordinator.is_master(): if coordinator.is_master():
......
...@@ -110,9 +110,9 @@ def examine_pp(num_micro_batches): ...@@ -110,9 +110,9 @@ def examine_pp(num_micro_batches):
torch_loss.backward() torch_loss.backward()
pp_ret = schedule.forward_backward_step(sharded_model, pp_ret = schedule.forward_backward_step(sharded_model,
pp_optimizer,
iter(input_list), iter(input_list),
criterion, criterion,
pp_optimizer,
return_loss=True, return_loss=True,
return_outputs=True) return_outputs=True)
......
...@@ -90,9 +90,9 @@ def examine_pp(): ...@@ -90,9 +90,9 @@ def examine_pp():
torch_loss.backward() torch_loss.backward()
pp_ret = schedule.forward_backward_step(sharded_model, pp_ret = schedule.forward_backward_step(sharded_model,
pp_optimizer,
iter(input_list), iter(input_list),
criterion, criterion,
pp_optimizer,
return_loss=True, return_loss=True,
return_outputs=True) return_outputs=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment