OpenDAS / ColossalAI

Commit 6a318816, authored Mar 09, 2022 by Frank Lee
Parent: 3213554c

set criterion as optional in colossalai initialize (#336)
Showing 3 changed files with 47 additions and 40 deletions:

  colossalai/amp/torch_amp/__init__.py   +9  -6
  colossalai/engine/_base_engine.py      +13 -7
  colossalai/initialize.py               +25 -27
colossalai/amp/torch_amp/__init__.py  (+9 -6)

```diff
@@ -3,12 +3,13 @@ from torch.optim import Optimizer
 from torch.nn.modules.loss import _Loss
 from colossalai.context import Config
 from .torch_amp import TorchAMPOptimizer, TorchAMPModel, TorchAMPLoss
+from typing import Optional


 def convert_to_torch_amp(model: nn.Module,
                          optimizer: Optimizer,
-                         criterion: _Loss,
-                         amp_config: Config):
+                         criterion: Optional[_Loss] = None,
+                         amp_config: Optional[Config] = None):
     """A helper function to wrap training components with Torch AMP modules

     :param model: your model object
@@ -16,15 +17,17 @@ def convert_to_torch_amp(model: nn.Module,
     :param optimizer: your optimizer object
     :type optimizer: :class:`torch.optim.Optimzer`
     :param criterion: your loss function object
-    :type criterion: :class:`torch.nn.modules.loss._Loss`
+    :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param amp_config: configuration for different amp modes
-    :type amp_config: :class:`colossalai.context.Config` or dict
+    :type amp_config: :class:`colossalai.context.Config` or dict, optional
     :return: (model, optimizer, criterion)
     :rtype: Tuple
     """
     model = TorchAMPModel(model)
+    if amp_config is None:
+        amp_config = dict()
     optimizer = TorchAMPOptimizer(optimizer, **amp_config)
-    criterion = TorchAMPLoss(criterion)
+    if criterion:
+        criterion = TorchAMPLoss(criterion)
     return model, optimizer, criterion
```
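With both criterion and amp_config now optional, the helper can wrap just a model and an optimizer. A minimal sketch of the new call pattern; the toy model and optimizer below are illustrative and not part of the commit:

```python
import torch
import torch.nn as nn

from colossalai.amp.torch_amp import convert_to_torch_amp

# Toy components for illustration only; not from the commit.
model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# After this commit both criterion and amp_config may be omitted:
# amp_config falls back to an empty dict, and a missing criterion
# passes through unwrapped as None.
model, optimizer, criterion = convert_to_torch_amp(model, optimizer)
assert criterion is None
```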
colossalai/engine/_base_engine.py  (+13 -7)

```diff
@@ -9,6 +9,8 @@ from torch.optim import Optimizer
 from colossalai.logging import get_dist_logger
 from torch import Tensor
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from typing import Optional
+from colossalai.engine.gradient_handler import BaseGradientHandler


 class Engine:
@@ -21,9 +23,9 @@ class Engine:
     :param optimizer: Optimizer for updating the parameters
     :type optimizer: ``torch.optim.Optimizer``
     :param criterion: Loss function for calculating loss
-    :type criterion: ``torch.nn.modules.loss._Loss``
+    :type criterion: ``torch.nn.modules.loss._Loss``, optional
     :param gradient_handlers: A list of gradient handler used in backward
-    :type gradient_handlers: list
+    :type gradient_handlers: a list of ``BaseGradientHandler``, optional
     :param clip_grad_norm: The norm of gradient clipping
     :type clip_grad_norm: float, optional
     :param ophook_list: List of ophook
@@ -31,13 +33,14 @@ class Engine:
     :param verbose: whether to display log info
     :type verbose: bool
     """

     def __init__(self,
                  model: Module,
                  optimizer: Optimizer,
-                 criterion: _Loss,
-                 gradient_handlers: List = None,
+                 criterion: Optional[_Loss] = None,
+                 gradient_handlers: Optional[List[BaseGradientHandler]] = None,
                  clip_grad_norm: float = 0.0,
-                 ophook_list: List[BaseOpHook] = [],
+                 ophook_list: Optional[List[BaseOpHook]] = None,
                  verbose: bool = True):
         self._model = model
         self._optimizer = optimizer
@@ -55,6 +58,9 @@ class Engine:
         else:
             self._gradient_handlers = []

-        self._ophook_list = ophook_list
+        if ophook_list is None:
+            self._ophook_list = []
+        else:
+            self._ophook_list = ophook_list
         register_ophooks_recursively(self._model, self._ophook_list)
```
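Because criterion, gradient_handlers, and ophook_list all default to None now, an Engine can be built from just a model and an optimizer. A minimal sketch, assuming a launched ColossalAI context; the toy components are illustrative, and Engine is normally constructed for you by colossalai.initialize rather than by hand:

```python
import torch
import torch.nn as nn

from colossalai.engine import Engine

# Toy components for illustration only; not from the commit.
model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# criterion, gradient_handlers and ophook_list can all be left out;
# a missing ophook_list is normalized to a fresh [] inside __init__.
engine = Engine(model=model, optimizer=optimizer)
```

Swapping the mutable `[]` default for `None` also sidesteps Python's shared-default-argument pitfall: with `ophook_list: List[BaseOpHook] = []`, every Engine instance would have shared, and mutated, the same list object.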
colossalai/initialize.py  (+25 -27)

```diff
@@ -27,7 +27,7 @@ from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp,
                               is_using_pp, is_using_sequence, sync_model_param)
 from colossalai.zero import convert_to_zero, ShardedOptimizer
-from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from colossalai.engine.ophooks import BaseOpHook


 def get_default_parser():
@@ -216,15 +216,14 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
            verbose=verbose)


-def initialize(model: Union[nn.Module, List[nn.Module]],
-               optimizer: Union[Optimizer, List[Optimizer]],
-               criterion: Union[_Loss, List[_Loss]],
-               train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
-               test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
-               lr_scheduler: _LRScheduler = None,
-               ophooks: List[BaseOpHook] = [],
-               verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
+def initialize(model: nn.Module,
+               optimizer: Optimizer,
+               criterion: Optional[_Loss] = None,
+               train_dataloader: Optional[Iterable] = None,
+               test_dataloader: Optional[Iterable] = None,
+               lr_scheduler: Optional[_LRScheduler] = None,
+               ophooks: Optional[List[BaseOpHook]] = None,
+               verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.
@@ -233,12 +232,12 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     :param optimizer: Your optimizer instance
     :type optimizer: :class:`torch.optim.optimizer.Optimizer`
     :param criterion: Your criterion instance
-    :type criterion: :class:`torch.nn.modules.loss._Loss`
+    :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param train_dataloader: Dataloader for training
     :type train_dataloader: :class:`torch.utils.data.DataLoader`, optional
     :param test_dataloader: Dataloader for testing
     :type test_dataloader: :class:`torch.utils.data.DataLoader`, optional
-    :param lr_scheduler: Your lr scheduler instance
+    :param lr_scheduler: Your lr scheduler instance, optional
     :type lr_scheduler: :class:`torch.nn.lr_scheduler._LRScheduler`, optional
     :param verbose: Whether to print logs
     :type verbose: bool, optional
@@ -399,20 +398,19 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     # gradient accumulation
     grad_accum_size = gpc.config.get('gradient_accumulation', None)
     if grad_accum_size is not None:
-        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(model=model,
-                                                                                           optimizer=optimizer,
-                                                                                           dataloader=train_dataloader,
-                                                                                           accumulate_size=grad_accum_size,
-                                                                                           gradient_handlers=gradient_handlers,
-                                                                                           lr_scheduler=lr_scheduler)
+        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
+            model=model,
+            optimizer=optimizer,
+            dataloader=train_dataloader,
+            accumulate_size=grad_accum_size,
+            gradient_handlers=gradient_handlers,
+            lr_scheduler=lr_scheduler)

     engine = Engine(model=model,
                     optimizer=optimizer,
                     criterion=criterion,
                     gradient_handlers=gradient_handlers,
                     clip_grad_norm=clip_grad_norm,
                     ophook_list=ophooks)

     return engine, train_dataloader, test_dataloader, lr_scheduler
```
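The net effect is that colossalai.initialize no longer demands a criterion, dataloaders, scheduler, or hooks. A minimal sketch of the new call pattern, assuming colossalai.launch_from_torch (or a similar launcher) has already initialized the distributed context and loaded a config into gpc.config; the toy model and optimizer are illustrative, not from the commit:

```python
import torch
import torch.nn as nn

import colossalai

# Toy components for illustration only; not from the commit.
model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# criterion, train/test dataloaders, lr_scheduler and ophooks
# may now all be omitted; only model and optimizer are required.
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
    model=model,
    optimizer=optimizer)
```

Note that the signature also drops the `Union[..., List[...]]` variants, so initialize now accepts a single model and a single optimizer rather than lists of them.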