Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a9b8402d
Unverified
Commit
a9b8402d
authored
Mar 20, 2023
by
Frank Lee
Committed by
GitHub
Mar 20, 2023
Browse files
[booster] added the accelerator implementation (#3159)
parent
1ad3a636
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
72 additions
and
5 deletions
+72
-5
colossalai/booster/accelerator.py
colossalai/booster/accelerator.py
+44
-4
colossalai/booster/booster.py
colossalai/booster/booster.py
+14
-1
tests/test_booster/test_accelerator.py
tests/test_booster/test_accelerator.py
+13
-0
tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
...est_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
+1
-0
No files found.
colossalai/booster/accelerator.py
View file @
a9b8402d
...
@@ -3,12 +3,52 @@ import torch.nn as nn
...
@@ -3,12 +3,52 @@ import torch.nn as nn
__all__
=
[
'Accelerator'
]
__all__
=
[
'Accelerator'
]
_supported_devices
=
[
'cpu'
,
'cuda'
,
# To be supported
# 'xpu',
# 'npu',
# 'tpu',
]
class
Accelerator
:
class
Accelerator
:
"""
Accelerator is an abstraction for the hardware device that is used to run the model.
Args:
device (str): The device to be used. Currently only 'cpu' and 'cuda' are supported.
"""
def
__init__
(
self
,
device
:
torch
.
device
):
def
__init__
(
self
,
device
:
str
):
self
.
device
=
device
self
.
device
=
device
def
setup_model
(
self
,
model
:
nn
.
Module
)
->
nn
.
Module
:
assert
self
.
device
in
_supported_devices
,
f
"Device
{
self
.
device
}
is not supported yet, supported devices include
{
_supported_devices
}
"
# TODO: implement this method
pass
def bind(self):
    """
    Set the default device for the current process.

    Nothing is done when the device is 'cpu'. When the device is 'cuda',
    ``torch.cuda.set_device`` is called with ``torch.device('cuda')``.

    Raises:
        ValueError: If ``self.device`` is neither 'cpu' nor 'cuda'.
    """
    # CPU needs no process-level binding.
    if self.device == 'cpu':
        return

    # Anything other than CUDA is rejected from here on.
    if self.device != 'cuda':
        raise ValueError(f"Device {self.device} is not supported yet")

    # TODO(FrankLeeeee): use global environment to check if it is a dist job
    # if is_distributed:
    #     local_rank = EnvTable().get_local_rank()
    #     torch.cuda.set_device(torch.device(f'cuda:{local_rank}'))
    # NOTE(review): torch.device('cuda') carries no device index -- confirm that
    # torch.cuda.set_device accepts an index-less device on the targeted torch version.
    torch.cuda.set_device(torch.device('cuda'))
def configure_model(self, model: nn.Module) -> nn.Module:
    """
    Move the model to this accelerator's device.

    Args:
        model (nn.Module): The model to be moved.

    Returns:
        nn.Module: The result of ``model.to(...)`` for the device named by ``self.device``.
    """
    # Build the torch.device from the stored device string, then move the model onto it.
    target_device = torch.device(self.device)
    return model.to(target_device)
colossalai/booster/booster.py
View file @
a9b8402d
...
@@ -8,6 +8,7 @@ from torch.optim import Optimizer
...
@@ -8,6 +8,7 @@ from torch.optim import Optimizer
from
torch.optim.lr_scheduler
import
_LRScheduler
as
LRScheduler
from
torch.optim.lr_scheduler
import
_LRScheduler
as
LRScheduler
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
.accelerator
import
Accelerator
from
.mixed_precision
import
MixedPrecision
,
mixed_precision_factory
from
.mixed_precision
import
MixedPrecision
,
mixed_precision_factory
from
.plugin
import
Plugin
from
.plugin
import
Plugin
...
@@ -51,9 +52,16 @@ class Booster:
...
@@ -51,9 +52,16 @@ class Booster:
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
device
:
Union
[
str
,
torch
.
device
]
=
'cuda'
,
device
:
str
=
'cuda'
,
mixed_precision
:
Union
[
MixedPrecision
,
str
]
=
None
,
mixed_precision
:
Union
[
MixedPrecision
,
str
]
=
None
,
plugin
:
Optional
[
Plugin
]
=
None
)
->
None
:
plugin
:
Optional
[
Plugin
]
=
None
)
->
None
:
# TODO(FrankLeeeee): add plugin control logic
# if self.plugin is not None and self.plugin.control_accelerator:
# ...
# create accelerator
self
.
acceleartor
=
Accelerator
(
device
)
self
.
acceleartor
.
set_default_device
()
# validate and set precision
# validate and set precision
if
isinstance
(
MixedPrecision
,
str
):
if
isinstance
(
MixedPrecision
,
str
):
# the user will take the default arguments for amp training
# the user will take the default arguments for amp training
...
@@ -78,6 +86,11 @@ class Booster:
...
@@ -78,6 +86,11 @@ class Booster:
lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
dataloader (DataLoader): The dataloader to be boosted.
dataloader (DataLoader): The dataloader to be boosted.
"""
"""
# TODO(FrankLeeeee): add plugin control logic
# if self.plugin is not None and self.plugin.control_accelerator:
# ...
model
=
self
.
acceleartor
.
configure_model
(
model
)
# TODO(FrankLeeeee): consider multi-model and multi-optimizer case
# TODO(FrankLeeeee): consider multi-model and multi-optimizer case
# TODO(lsg): Add plugin control logic
# TODO(lsg): Add plugin control logic
# e.g.
# e.g.
...
...
tests/test_booster/test_accelerator.py
0 → 100644
View file @
a9b8402d
import
pytest
import
torch.nn
as
nn
from
torchvision.models
import
resnet18
from
colossalai.booster.accelerator
import
Accelerator
@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_accelerator(device):
    """Check that Accelerator.configure_model places a model's parameters on the requested device."""
    accelerator = Accelerator(device)

    # A small linear layer is enough to observe where the parameters end up.
    linear = accelerator.configure_model(nn.Linear(8, 8))

    first_param = next(linear.parameters())
    assert first_param.device.type == device
tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
View file @
a9b8402d
...
@@ -56,6 +56,7 @@ def test_torchrec_dlrm_models():
...
@@ -56,6 +56,7 @@ def test_torchrec_dlrm_models():
data
=
data_gen_fn
()
data
=
data_gen_fn
()
# dlrm_interactionarch is not supported
# dlrm_interactionarch is not supported
# TODO(FrankLeeeee): support this model
if
name
==
'dlrm_interactionarch'
:
if
name
==
'dlrm_interactionarch'
:
continue
continue
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment