Commit f867365a authored by Jie Zhu's avatar Jie Zhu Committed by Frank Lee
Browse files

bug fix: pass hook_list to engine (#273)

* bug fix: pass hook_list to engine

* change parameter name
parent 5a560a06
...@@ -27,6 +27,7 @@ from colossalai.utils import (accumulate_gradient, get_current_device, ...@@ -27,6 +27,7 @@ from colossalai.utils import (accumulate_gradient, get_current_device,
is_using_ddp, is_using_pp, is_using_sequence, is_using_ddp, is_using_pp, is_using_sequence,
sync_model_param) sync_model_param)
from colossalai.zero import convert_to_zero, ShardedOptimizer from colossalai.zero import convert_to_zero, ShardedOptimizer
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
def get_default_parser(): def get_default_parser():
...@@ -228,6 +229,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]], ...@@ -228,6 +229,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None, train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None, test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
lr_scheduler: _LRScheduler = None, lr_scheduler: _LRScheduler = None,
ophooks: List[BaseOpHook] = [],
verbose: bool = True verbose: bool = True
) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]: ) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
"""Core function to wrap the essential training components with our functionality based on the config which is """Core function to wrap the essential training components with our functionality based on the config which is
...@@ -412,7 +414,8 @@ def initialize(model: Union[nn.Module, List[nn.Module]], ...@@ -412,7 +414,8 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
optimizer=optimizer, optimizer=optimizer,
criterion=criterion, criterion=criterion,
gradient_handlers=gradient_handlers, gradient_handlers=gradient_handlers,
clip_grad_norm=clip_grad_norm clip_grad_norm=clip_grad_norm,
ophook_list=ophooks
) )
return engine, train_dataloader, test_dataloader, lr_scheduler return engine, train_dataloader, test_dataloader, lr_scheduler
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment