OpenDAS / ColossalAI
Unverified commit 370f567e, authored Mar 14, 2022 by Jiarui Fang, committed by GitHub on Mar 14, 2022
Parent: a9c27be4

[zero] new interface for ShardedOptimv2 (#406)

Showing 9 changed files with 51 additions and 35 deletions:
    colossalai/zero/sharded_optim/sharded_optim_v2.py                      +31 -6
    tests/components_to_test/bert.py                                         +1 -4
    tests/components_to_test/nested_model.py                                 +1 -4
    tests/components_to_test/repeated_computed_layer.py                      +1 -4
    tests/components_to_test/resnet.py                                       +1 -4
    tests/test_engine/test_engine.py                                         +3 -3
    tests/test_trainer/test_trainer_with_non_pipe_schedule.py                +2 -2
    tests/test_zero_data_parallel/test_sharded_optim_v2.py                   +7 -5
    tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py     +4 -3
colossalai/zero/sharded_optim/sharded_optim_v2.py

 from enum import Enum
-from typing import Dict, Optional
+from typing import Callable, Dict, Optional, Union
 import torch
 import torch.distributed as dist

@@ -15,7 +15,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
+from typing import Type, Any
 from ._utils import has_inf_or_nan
@@ -27,8 +27,8 @@ class OptimState(Enum):
 class ShardedOptimizerV2(ColossalaiOptimizer):

     def __init__(self,
-                 optimizer: Optimizer,
                  sharded_model: ShardedModelV2,
+                 optimizer_class: Type[Optimizer],
                  shard_strategy: BaseShardStrategy,
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,

@@ -39,9 +39,34 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  hysteresis: float = 2,
                  max_scale: int = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
-                 mp_process_group: Optional[ProcessGroup] = None) -> None:
-        super().__init__(optimizer)
+                 mp_process_group: Optional[ProcessGroup] = None,
+                 **defaults: Any) -> None:
+        """
+        :param sharded_model: A sharded model initialized by class ShardedModelV2
+        :type sharded_model: ShardedModelV2
+        :param optimizer_class: A type of Optimizer
+        :type optimizer_class: Type[Optimizer]
+        :param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
+        :type shard_strategy: BaseShardStrategy
+        :param cpu_offload: Whether to offload the optimizer states to CPU.
+        :type cpu_offload: bool
+        :param defaults: Any trailing keyword arguments, which are forwarded to the local optimizer.
+        :type defaults: dict
+        """
+        assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
+        self._optim_defaults = defaults
+        # initialize the M, V as zero tensors and initialize the fp32 params from sharded_model.parameters()
+        self.optimizer = optimizer_class(sharded_model.parameters(), **self._optim_defaults)
+        super().__init__(self.optimizer)
         self.shard_strategy = shard_strategy
         self.model: ShardedModelV2 = sharded_model
         if cpu_offload and not sharded_model.cpu_offload:

@@ -65,7 +90,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Store fp32 param shards
         self.master_params: Dict[Parameter, Tensor] = {}
-        for group in optimizer.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
                 is_param_sharded = p.col_attr.data.is_sharded
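In short, callers no longer hand ShardedOptimizerV2 a pre-built optimizer; they pass the sharded model, the optimizer class, and the optimizer's keyword arguments, and the wrapper instantiates the optimizer over sharded_model.parameters() itself. A minimal usage sketch, mirroring the updated tests below (the model, shard strategy, and hyperparameter values are illustrative, not part of this diff):

    # Old interface: a pre-built optimizer instance was passed in
    # sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
    #                                    zero_model, shard_strategy, cpu_offload=True)

    # New interface: pass the optimizer class; trailing kwargs such as lr are
    # forwarded to it via **defaults when ShardedOptimizerV2 constructs it
    sharded_optim = ShardedOptimizerV2(zero_model,           # a ShardedModelV2 instance
                                       torch.optim.Adam,     # optimizer_class: Type[Optimizer]
                                       shard_strategy,
                                       cpu_offload=True,
                                       initial_scale=2**5,
                                       lr=1e-3)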
tests/components_to_test/bert.py

@@ -74,8 +74,5 @@ def get_training_components():
                               sequence_length=sequence_length,
                               is_distrbuted=True)

-    def get_optim(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)

     criterion = None
-    return bert_model_builder, trainloader, testloader, get_optim, criterion
+    return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion
tests/components_to_test/nested_model.py

@@ -49,8 +49,5 @@ def get_training_components():
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()

-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)

     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
tests/components_to_test/repeated_computed_layer.py

@@ -43,8 +43,5 @@ def get_training_components():
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()

-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)

     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
tests/components_to_test/resnet.py

@@ -29,8 +29,5 @@ def get_resnet_training_components():
     trainloader = get_cifar10_dataloader(train=True)
     testloader = get_cifar10_dataloader(train=False)

-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)

     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
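Taken together, the four components_to_test modules now return the optimizer class itself instead of a per-model builder closure, so the caller decides when and how to instantiate it. A small sketch of the updated contract as the tests below consume it (the function name follows the hunk headers above; the learning rate is just the value the tests use):

    # get_training_components() now yields an optimizer class, not a builder function
    model_builder, trainloader, testloader, optimizer_class, criterion = get_training_components()
    model = model_builder()
    optimizer = optimizer_class(model.parameters(), lr=1e-3)   # instantiated by the caller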
tests/test_engine/test_engine.py

@@ -19,11 +19,11 @@ def run_train():
     # FIXME: test bert
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
+        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
         model = model_builder(checkpoint=False)
         engine, train_dataloader, *args = colossalai.initialize(model=model,
-                                                                 optimizer=optimizer_builder(model),
+                                                                 optimizer=optimizer_class(model.parameters(), lr=1e-3),
                                                                  criterion=criterion,
                                                                  train_dataloader=train_dataloader)

@@ -84,7 +84,7 @@ def run_engine(rank, world_size, port):
 @pytest.mark.dist
 def test_engine():
-    world_size = 4
+    world_size = 2
     run_func = partial(run_engine, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
tests/test_trainer/test_trainer_with_non_pipe_schedule.py

@@ -25,9 +25,9 @@ def run_trainer_no_pipeline(rank, world_size, port):
     test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']
     for name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(name)
-        model_builder, train_dataloader, test_dataloader, optimizer_builder, criterion = get_components_func()
+        model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
         model = model_builder()
-        optimizer = optimizer_builder(model)
+        optimizer = optimizer_class(model.parameters(), lr=1e-3)
         engine, train_dataloader, *_ = colossalai.initialize(model=model,
                                                              optimizer=optimizer,
                                                              criterion=criterion,
tests/test_zero_data_parallel/test_sharded_optim_v2.py

@@ -44,19 +44,21 @@ def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
     shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
+        model, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
         zero_model = ShardedModelV2(copy.deepcopy(model),
                                     shard_strategy,
                                     offload_config=dict(device='cpu') if cpu_offload else None)
         if dist.get_world_size() > 1:
             model = DDP(model)
-        optim = Adam(model.parameters(), lr=1e-3)
-        sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
-                                           zero_model,
+        lr = 1e-3
+        optim = optimizer_class(model.parameters(), lr=lr)
+        sharded_optim = ShardedOptimizerV2(zero_model,
+                                           optimizer_class,
                                            shard_strategy,
                                            cpu_offload=cpu_offload,
-                                           initial_scale=2**5)
+                                           initial_scale=2**5,
+                                           lr=lr)
         for i, (data, label) in enumerate(train_dataloader):
             if i > 2:
                 break
tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py

@@ -59,11 +59,12 @@ def run_dist(rank, world_size, port, shard_strategy):
     if dist.get_world_size() > 1:
         model = DDP(model)
     optim = Adam(model.parameters(), lr=1e-3)
-    sharded_optim = ShardedOptimizerV2(CPUAdam(zero_model.parameters(), lr=1e-3),
-                                       zero_model,
+    sharded_optim = ShardedOptimizerV2(zero_model,
+                                       CPUAdam,
                                        shard_strategy,
                                        initial_scale=2**5,
-                                       cpu_offload=True)
+                                       cpu_offload=True,
+                                       lr=1e-3)
     for i, (data, label) in enumerate(train_dataloader):
         if i > 2:
             break