"...git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "7d8e0338a40da196436bef866ebdc43d5ed1c677"
Unverified commit 725af3ee, authored by Wenhao Chen, committed by GitHub

[booster] make optimizer argument optional for boost (#3993)

* feat: make optimizer optional in Booster.boost

* test: skip unet test if diffusers version > 0.10.2
parent c9cff7e7
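The practical effect of this change is that a model can be boosted without an optimizer, e.g. for inference or evaluation. Below is a minimal sketch of the intended call pattern, assuming `colossalai.launch_from_torch` has already initialized the distributed environment and a CUDA device is available; the toy model and plugin choice are illustrative, not part of this commit.

```python
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Illustrative inference-only setup: no optimizer, criterion, dataloader or
# lr_scheduler is constructed.
model = nn.Linear(16, 4)
booster = Booster(plugin=TorchDDPPlugin())

# After this commit, boost() accepts just the model instead of requiring an
# optimizer; the other returned slots simply come back as None.
model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model)
```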
@@ -97,10 +97,10 @@ class Booster:
     def boost(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
     ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
         """
         Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
...
@@ -115,10 +115,12 @@ class FP16TorchMixedPrecision(MixedPrecision):
     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         model = TorchAMPModule(model)
-        optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
+        if optimizer is not None:
+            optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
         if criterion is not None:
             criterion = TorchAMPModule(criterion)
         return model, optimizer, criterion
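The mixed-precision change above follows a simple guard-then-wrap pattern: the optimizer is only wrapped when one was actually passed in, and `None` passes through untouched. A self-contained sketch of that pattern; the `AmpOptimizerStub` class is a placeholder, not the real `TorchAMPOptimizer`.

```python
from typing import Optional

import torch
from torch.optim import Optimizer


class AmpOptimizerStub:
    """Placeholder standing in for a wrapper such as TorchAMPOptimizer."""

    def __init__(self, optim: Optimizer):
        self.optim = optim


def configure(optimizer: Optional[Optimizer] = None):
    # Wrap only when an optimizer was supplied; otherwise return None unchanged.
    if optimizer is not None:
        optimizer = AmpOptimizerStub(optimizer)
    return optimizer


print(configure())  # None: nothing to wrap
sgd = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
print(configure(sgd))  # AmpOptimizerStub instance
```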
 from abc import ABC, abstractmethod
-from typing import Callable, Tuple
+from typing import Callable, Optional, Tuple
 import torch.nn as nn
 from torch.optim import Optimizer
@@ -15,7 +15,8 @@ class MixedPrecision(ABC):
     @abstractmethod
     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         # TODO: implement this method
         pass
@@ -274,11 +274,11 @@ class GeminiPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         if not isinstance(model, ModelWrapper):
             # convert model to sync bn
@@ -293,8 +293,12 @@ class GeminiPlugin(DPPluginBase):
             # wrap the model with Gemini
             model = GeminiModel(model, self.gemini_config, self.verbose)
-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
+            optimizer = GeminiOptimizer(model.unwrap(),
+                                        optimizer,
+                                        self.zero_optim_config,
+                                        self.optim_kwargs,
                                         self.verbose)
         return model, optimizer, criterion, dataloader, lr_scheduler
...
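With the Gemini plugin, the same relaxation means a model can be boosted for evaluation without ever building a `GeminiOptimizer` or its optimizer state. A hypothetical sketch, assuming `colossalai.launch_from_torch` has already initialized the process group and a CUDA device is available; the model and default plugin arguments are illustrative only.

```python
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))
booster = Booster(plugin=GeminiPlugin())

# Only the model is passed. Because configure() now guards on `optimizer is not None`,
# no GeminiOptimizer (and no optimizer state memory) is created for this run.
model, optimizer, *_ = booster.boost(model)
```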
@@ -197,17 +197,21 @@ class LowLevelZeroPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.stage, self.precision)
-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
+            optimizer = LowLevelZeroOptimizer(model.unwrap(),
+                                              optimizer,
+                                              self.zero_optim_config,
+                                              self.optim_kwargs,
                                               self.verbose)
         return model, optimizer, criterion, dataloader, lr_scheduler
...
 from abc import ABC, abstractmethod
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union
 import torch.nn as nn
 from torch.optim import Optimizer
@@ -38,11 +38,11 @@ class Plugin(ABC):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         # implement this method
         pass
...
@@ -138,11 +138,11 @@ class TorchDDPPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         # cast model to cuda
         model = model.cuda()
@@ -152,7 +152,8 @@ class TorchDDPPlugin(DPPluginBase):
         # wrap the model with PyTorch DDP
         model = TorchDDPModel(model, **self.ddp_kwargs)
-        if not isinstance(optimizer, OptimizerWrapper):
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
             optimizer = OptimizerWrapper(optimizer)
         return model, optimizer, criterion, dataloader, lr_scheduler
...
@@ -195,23 +195,24 @@ class TorchFSDPPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         # wrap the model with PyTorch FSDP
         fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
-        if len(optimizer.param_groups) > 1:
-            warnings.warn(
-                'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
-            )
-        optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
-        if not isinstance(optimizer, FSDPOptimizerWrapper):
-            optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
+        if optimizer is not None:
+            if len(optimizer.param_groups) > 1:
+                warnings.warn(
+                    'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
+                )
+            optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
+            if not isinstance(optimizer, FSDPOptimizerWrapper):
+                optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
         return fsdp_model, optimizer, criterion, dataloader, lr_scheduler
...
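The FSDP branch keeps its existing behaviour, now guarded: when an optimizer is supplied, it is re-initialized over `fsdp_model.parameters()`, because FSDP replaces the module's parameters with flattened, sharded ones, so an optimizer built on the original parameters would update stale tensors. A minimal stand-alone sketch of that re-initialization pattern in plain PyTorch, assuming the script is launched with `torchrun` so the process-group environment variables exist and a GPU is available.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Sketch only: assumes launch via torchrun and an available CUDA device.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Linear(32, 32).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

fsdp_model = FSDP(model, device_id=torch.cuda.current_device())

# FSDP flattens and shards the module's parameters, so the optimizer must be
# rebuilt over fsdp_model.parameters() (the plugin does this via
# optimizer.__init__) to keep optimizer state and wrapped model in sync.
optimizer = torch.optim.SGD(fsdp_model.parameters(), **optimizer.defaults)
```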
@@ -4,12 +4,15 @@ import pytest
 import torch
 try:
-    from diffusers import UNet2DModel
-    MODELS = [UNet2DModel]
+    import diffusers
+    MODELS = [diffusers.UNet2DModel]
     HAS_REPO = True
+    from packaging import version
+    SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2")
 except:
     MODELS = []
     HAS_REPO = False
+    SKIP_UNET_TEST = False
 from test_autochunk_diffuser_utils import run_test
@@ -32,6 +35,10 @@ def get_data(shape: tuple) -> Tuple[List, List]:
     return meta_args, concrete_args
+@pytest.mark.skipif(
+    SKIP_UNET_TEST,
+    reason="diffusers version > 0.10.2",
+)
 @pytest.mark.skipif(
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",
...
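The new skip condition relies on `packaging.version` for a semantic, rather than lexicographic, comparison of the installed diffusers version against 0.10.2. A small illustration of the comparison the guard performs; the version strings below are arbitrary examples, not real installed versions.

```python
from packaging import version

# packaging.version compares release segments numerically; naive string
# comparison would wrongly rank "0.9.0" above "0.10.2".
assert "0.9.0" > "0.10.2"                                # lexicographic: misleading
assert version.parse("0.9.0") < version.parse("0.10.2")  # semantic: correct

# This mirrors the guard added to the test: skip on any diffusers release
# newer than 0.10.2 (e.g. a hypothetical installed version 0.14.0).
SKIP_UNET_TEST = version.parse("0.14.0") > version.parse("0.10.2")
print(SKIP_UNET_TEST)  # True
```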