OpenDAS / ColossalAI / Commits / 370f567e

[zero] new interface for ShardedOptimv2 (#406)

Unverified commit 370f567e, authored Mar 14, 2022 by Jiarui Fang and committed by GitHub on Mar 14, 2022. Parent: a9c27be4.
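In short, this commit reworks the construction interface of ShardedOptimizerV2: instead of receiving a ready-built optimizer instance as its first argument, the wrapper now takes the sharded model first and an optimizer class, instantiates that class over sharded_model.parameters() itself, and forwards any trailing keyword arguments (such as lr) to the optimizer's constructor. A before/after sketch of the call site, with zero_model and shard_strategy assumed to be set up as in the tests further down:

```python
from torch.optim import Adam

# Old interface: the caller built the inner optimizer by hand.
# sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
#                                    zero_model,
#                                    shard_strategy,
#                                    initial_scale=2**5)

# New interface: pass the sharded model first and the optimizer *class*;
# trailing kwargs such as lr are forwarded to Adam's constructor.
sharded_optim = ShardedOptimizerV2(zero_model,
                                   Adam,
                                   shard_strategy,
                                   initial_scale=2**5,
                                   lr=1e-3)
```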
Showing 9 changed files, with 51 additions and 35 deletions (+51 −35):

- colossalai/zero/sharded_optim/sharded_optim_v2.py (+31 −6)
- tests/components_to_test/bert.py (+1 −4)
- tests/components_to_test/nested_model.py (+1 −4)
- tests/components_to_test/repeated_computed_layer.py (+1 −4)
- tests/components_to_test/resnet.py (+1 −4)
- tests/test_engine/test_engine.py (+3 −3)
- tests/test_trainer/test_trainer_with_non_pipe_schedule.py (+2 −2)
- tests/test_zero_data_parallel/test_sharded_optim_v2.py (+7 −5)
- tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py (+4 −3)
colossalai/zero/sharded_optim/sharded_optim_v2.py

```diff
 from enum import Enum
-from typing import Dict, Optional
+from typing import Callable, Dict, Optional, Union
 import torch
 import torch.distributed as dist
@@ -15,7 +15,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
+from typing import Type, Any
 from ._utils import has_inf_or_nan
@@ -27,8 +27,8 @@ class OptimState(Enum):
 class ShardedOptimizerV2(ColossalaiOptimizer):

     def __init__(self,
-                 optimizer: Optimizer,
                  sharded_model: ShardedModelV2,
+                 optimizer_class: Type[Optimizer],
                  shard_strategy: BaseShardStrategy,
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
@@ -39,9 +39,34 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  hysteresis: float = 2,
                  max_scale: int = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
-                 mp_process_group: Optional[ProcessGroup] = None) -> None:
+                 mp_process_group: Optional[ProcessGroup] = None,
+                 **defaults: Any) -> None:
+        """
+        :param sharded_model: A sharded model initialized by class ShardedModelV2
+        :type sharded_model: ShardedModelV2
+        :param optimizer_class: A type of Optimizer
+        :type optimizer_class: Type[Optimizer]
+        :param shard_strategy: The strategy used to shard the model and optimizer parameters.
+        :type shard_strategy: BaseShardStrategy
+        :param cpu_offload: whether to offload the optimizer states to CPU.
+        :type cpu_offload: bool
+        :param defaults: any trailing keyword arguments, which are forwarded to the local optimizer.
+        :type defaults: dict
+        """
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
-        super().__init__(optimizer)
+        self._optim_defaults = defaults
+        # initialize the M, V as zeros tensors and initialize param fp32 from sharded_model.parameters()
+        self.optimizer = optimizer_class(sharded_model.parameters(), **self._optim_defaults)
+        super().__init__(self.optimizer)
         self.shard_strategy = shard_strategy
         self.model: ShardedModelV2 = sharded_model
         if cpu_offload and not sharded_model.cpu_offload:
@@ -65,7 +90,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Store fp32 param shards
         self.master_params: Dict[Parameter, Tensor] = {}
-        for group in optimizer.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
                 is_param_sharded = p.col_attr.data.is_sharded
```
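The pattern the new __init__ uses (capture **defaults, instantiate the class over the model's parameters, then hand the resulting optimizer to the parent class) is plain Python and easy to test in isolation. A minimal self-contained sketch; _OptimWrapper is a hypothetical stand-in for ShardedOptimizerV2, not ColossalAI code:

```python
from typing import Any, Type

import torch
from torch.optim import Optimizer


class _OptimWrapper:
    """Hypothetical stand-in for ShardedOptimizerV2's construction pattern."""

    def __init__(self, model: torch.nn.Module, optimizer_class: Type[Optimizer], **defaults: Any) -> None:
        # Keep the trailing kwargs, then build the inner optimizer from them,
        # mirroring `optimizer_class(sharded_model.parameters(), **self._optim_defaults)`.
        self._optim_defaults = defaults
        self.optimizer = optimizer_class(model.parameters(), **self._optim_defaults)


model = torch.nn.Linear(4, 2)
wrapper = _OptimWrapper(model, torch.optim.Adam, lr=1e-3)
print(wrapper.optimizer.defaults["lr"])  # 0.001
```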
tests/components_to_test/bert.py

```diff
@@ -74,8 +74,5 @@ def get_training_components():
                                           sequence_length=sequence_length,
                                           is_distrbuted=True)

-    def get_optim(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
-
     criterion = None
-    return bert_model_builder, trainloader, testloader, get_optim, criterion
+    return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion
```
tests/components_to_test/nested_model.py

```diff
@@ -49,8 +49,5 @@ def get_training_components():
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()

-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
-
     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
```
tests/components_to_test/repeated_computed_layer.py

```diff
@@ -43,8 +43,5 @@ def get_training_components():
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()

-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
-
     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
```
tests/components_to_test/resnet.py

```diff
@@ -29,8 +29,5 @@ def get_resnet_training_components():
     trainloader = get_cifar10_dataloader(train=True)
     testloader = get_cifar10_dataloader(train=False)

-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
-
     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
```
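All four test-component modules above make the same change: the per-model optim_builder closure, which hard-coded lr=0.001, is deleted, and the fourth element of the returned tuple becomes the optimizer class torch.optim.Adam itself, so each caller now picks the parameters and learning rate. A runnable toy mirroring the new tuple contract; this get_training_components is an illustrative stand-in, not the repo's:

```python
import torch


def get_training_components():
    """Toy stand-in that mirrors the tuple shape the test modules now return."""
    def model_builder():
        return torch.nn.Linear(8, 2)

    trainloader = testloader = None  # dataloaders elided in this sketch
    criterion = torch.nn.CrossEntropyLoss()
    # Fourth element is now the optimizer *class*, not a builder closure.
    return model_builder, trainloader, testloader, torch.optim.Adam, criterion


model_builder, trainloader, testloader, optimizer_class, criterion = get_training_components()
model = model_builder()
# Hyperparameters move to the call site, as in the updated tests below.
optimizer = optimizer_class(model.parameters(), lr=1e-3)
```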
tests/test_engine/test_engine.py

```diff
@@ -19,11 +19,11 @@ def run_train():
     # FIXME: test bert
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
+        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
         model = model_builder(checkpoint=False)
         engine, train_dataloader, *args = colossalai.initialize(model=model,
-                                                                optimizer=optimizer_builder(model),
+                                                                optimizer=optimizer_class(model.parameters(), lr=1e-3),
                                                                 criterion=criterion,
                                                                 train_dataloader=train_dataloader)
@@ -84,7 +84,7 @@ def run_engine(rank, world_size, port):
 @pytest.mark.dist
 def test_engine():
-    world_size = 4
+    world_size = 2
     run_func = partial(run_engine, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
```
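The second hunk simply shrinks the test from four ranks to two. The driver pattern itself is a common one: bind the fixed arguments with functools.partial, then fan out with torch.multiprocessing.spawn, which passes the rank as the first positional argument. A self-contained toy, with _worker and the fixed port 29500 as illustrative stand-ins for run_engine and free_port():

```python
from functools import partial

import torch.multiprocessing as mp


def _worker(rank: int, world_size: int, port: int) -> None:
    # spawn() supplies `rank`; the remaining kwargs were bound via partial.
    print(f"rank {rank} of {world_size}, rendezvous port {port}")


if __name__ == "__main__":
    world_size = 2  # reduced from 4 in this commit
    run_func = partial(_worker, world_size=world_size, port=29500)
    mp.spawn(run_func, nprocs=world_size)
```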
tests/test_trainer/test_trainer_with_non_pipe_schedule.py

```diff
@@ -25,9 +25,9 @@ def run_trainer_no_pipeline(rank, world_size, port):
     test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']
     for name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(name)
-        model_builder, train_dataloader, test_dataloader, optimizer_builder, criterion = get_components_func()
+        model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
         model = model_builder()
-        optimizer = optimizer_builder(model)
+        optimizer = optimizer_class(model.parameters(), lr=1e-3)
         engine, train_dataloader, *_ = colossalai.initialize(model=model,
                                                              optimizer=optimizer,
                                                              criterion=criterion,
```
tests/test_zero_data_parallel/test_sharded_optim_v2.py

```diff
@@ -44,19 +44,21 @@ def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
     shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
+        model, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
         zero_model = ShardedModelV2(copy.deepcopy(model),
                                     shard_strategy,
                                     offload_config=dict(device='cpu') if cpu_offload else None)
         if dist.get_world_size() > 1:
             model = DDP(model)
-        optim = Adam(model.parameters(), lr=1e-3)
-        sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
-                                           zero_model,
+        lr = 1e-3
+        optim = optimizer_class(model.parameters(), lr=lr)
+        sharded_optim = ShardedOptimizerV2(zero_model,
+                                           optimizer_class,
                                            shard_strategy,
                                            cpu_offload=cpu_offload,
-                                           initial_scale=2**5)
+                                           initial_scale=2**5,
+                                           lr=lr)
         for i, (data, label) in enumerate(train_dataloader):
             if i > 2:
                 break
```
tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py

```diff
@@ -59,11 +59,12 @@ def run_dist(rank, world_size, port, shard_strategy):
     if dist.get_world_size() > 1:
         model = DDP(model)
     optim = Adam(model.parameters(), lr=1e-3)
-    sharded_optim = ShardedOptimizerV2(CPUAdam(zero_model.parameters(), lr=1e-3),
-                                       zero_model,
+    sharded_optim = ShardedOptimizerV2(zero_model,
+                                       CPUAdam,
                                        shard_strategy,
                                        initial_scale=2**5,
-                                       cpu_offload=True)
+                                       cpu_offload=True,
+                                       lr=1e-3)
     for i, (data, label) in enumerate(train_dataloader):
         if i > 2:
             break
```
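A consequence of the new signature worth spelling out: every keyword not consumed by ShardedOptimizerV2's named parameters (cpu_offload, initial_scale, hysteresis, max_scale, the process groups, among others) falls through to **defaults and is handed to the optimizer class, so the lr above reaches CPUAdam unchanged. Any other constructor argument of the chosen optimizer rides along the same way; a sketch reusing zero_model and shard_strategy from the test above, with weight_decay (a standard Adam-style argument) added purely for illustration:

```python
sharded_optim = ShardedOptimizerV2(zero_model,
                                   CPUAdam,
                                   shard_strategy,
                                   cpu_offload=True,
                                   initial_scale=2**5,  # consumed by the wrapper
                                   lr=1e-3,             # falls through to CPUAdam
                                   weight_decay=1e-2)   # falls through to CPUAdam
```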