OpenDAS / ColossalAI · Commits

Unverified commit fae6c92e, authored Sep 05, 2023 by Hongxin Liu, committed via GitHub on Sep 05, 2023.

Merge branch 'main' into feature/shardformer

Parents: bd186784, ac178ca5
Changes: 113 files in the full commit. Showing 20 changed files on this page, with 129 additions and 70 deletions (+129 / -70).
colossalai/auto_parallel/tensor_shard/node_handler/registry.py        +1   -1
colossalai/booster/plugin/low_level_zero_plugin.py                    +83  -46
colossalai/context/parallel_context.py                                +1   -1
colossalai/context/process_group_initializer/initializer_1d.py        +2   -1
colossalai/context/process_group_initializer/initializer_2d.py        +1   -1
colossalai/context/process_group_initializer/initializer_2p5d.py      +2   -1
colossalai/context/process_group_initializer/initializer_3d.py        +1   -1
colossalai/context/process_group_initializer/initializer_data.py      +1   -1
colossalai/context/process_group_initializer/initializer_model.py     +4   -2
colossalai/context/process_group_initializer/initializer_pipeline.py  +1   -1
colossalai/context/process_group_initializer/initializer_sequence.py  +1   -1
colossalai/context/process_group_initializer/initializer_tensor.py    +3   -2
colossalai/initialize.py                                              +4   -4
colossalai/interface/__init__.py                                      +2   -2
colossalai/interface/model.py                                         +11  -0
colossalai/legacy/__init__.py                                         +0   -0
colossalai/legacy/builder/__init__.py                                 +0   -0
colossalai/legacy/builder/builder.py                                  +2   -2
colossalai/legacy/engine/__init__.py                                  +0   -0
colossalai/legacy/engine/_base_engine.py                              +9   -3
colossalai/auto_parallel/tensor_shard/node_handler/registry.py

 class Registry:
-    # TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here
+    # TODO: refactor the registry classes used in colossalai.legacy.registry, colossalai.fx and here

     def __init__(self, name):
         self.name = name
...
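The TODO above refers to the several hand-rolled registries in the code base (colossalai.legacy.registry, colossalai.fx, and this module). They all rely on the same decorator-based registration pattern, which also drives the @DIST_GROUP_INITIALIZER.register_module usages further down in this commit. The following is a minimal, self-contained sketch of that pattern; the method names and the MY_REGISTRY instance are illustrative, not ColossalAI's actual implementation.

class Registry:
    """Minimal decorator-based registry: maps a class name to the class object."""

    def __init__(self, name):
        self.name = name
        self._store = {}

    def register_module(self, module):
        # used as a bare decorator: @MY_REGISTRY.register_module
        self._store[module.__name__] = module
        return module

    def get(self, name):
        return self._store[name]


# hypothetical registry instance, mirroring how DIST_GROUP_INITIALIZER is used below
MY_REGISTRY = Registry('my_registry')


@MY_REGISTRY.register_module
class Initializer_Example:
    pass


# lookup by name later, e.g. when building initializers from a config string
assert MY_REGISTRY.get('Initializer_Example') is Initializer_Example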
colossalai/booster/plugin/low_level_zero_plugin.py

@@ -3,6 +3,7 @@ import os
 import warnings
 from functools import partial
 from pathlib import Path
+from types import MethodType
 from typing import Callable, Iterator, List, Optional, Tuple, Union

 import torch
...
@@ -25,9 +26,9 @@ from colossalai.checkpoint_io.utils import (
     sharded_optimizer_loading_epilogue,
     unwrap_optimizer,
 )
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
-from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import LowLevelZeroOptimizer

 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
...
@@ -44,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']


+class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
+
+    def __init__(self, module: nn.Module, precision: str) -> None:
+        super().__init__(module)
+        self.dtype = None
+        if precision == 'fp16':
+            self.dtype = torch.float16
+        elif precision == 'bf16':
+            self.dtype = torch.bfloat16
+        if self.dtype is not None:
+            module = module.to(self.dtype)
+        module = module.to(get_current_device())
+        self.module = module
+        self.convert_fn = None
+        if self.dtype is not None:
+            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+
+    def forward(self, *args, **kwargs):
+        if self.convert_fn is not None:
+            args = tree_map(self.convert_fn, args)
+            kwargs = tree_map(self.convert_fn, kwargs)
+        return super().forward(*args, **kwargs)
+
+    def unwrap(self):
+        # TODO(ver217): this is a workaround for loading model
+        return self
+
+
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
...
@@ -165,30 +194,36 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         sharded_optimizer_loading_epilogue(optimizer)

+    def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
+                             use_safetensors: bool):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)

-class LowLevelZeroModel(ModelWrapper):
-
-    def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
-        super().__init__(module)
-        self.dtype = None
-        if precision == 'fp16':
-            self.dtype = torch.float16
-        elif precision == 'bf16':
-            self.dtype = torch.bfloat16
-        module = zero_model_wrapper(module, zero_stage=stage)
-        if self.dtype is not None:
-            module = module.to(self.dtype)
-        module = module.to(get_current_device())
-        self.module = module
-        self.convert_fn = None
-        if self.dtype is not None:
-            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
-
-    def forward(self, *args, **kwargs):
-        if self.convert_fn is not None:
-            args = tree_map(self.convert_fn, args)
-            kwargs = tree_map(self.convert_fn, kwargs)
-        return super().forward(*args, **kwargs)
+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = True,
+                           prefix: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
+                                   use_safetensors)
+
+    def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_unsharded_model(model.module, checkpoint, strict)
+        model.update_master_params()
+
+    def load_sharded_model(self,
+                           model: LowLevelZeroModel,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False,
+                           load_sub_module: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        model.update_master_params()


 class LowLevelZeroPlugin(DPPluginBase):
...
@@ -248,22 +283,24 @@ class LowLevelZeroPlugin(DPPluginBase):
         super().__init__()
         assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
         assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
         assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
         self.stage = stage
         self.precision = precision
-        self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
-                                      communication_dtype=communication_dtype,
-                                      overlap_communication=overlap_communication,
-                                      cpu_offload=cpu_offload)
-        self.optim_kwargs = dict(initial_scale=initial_scale,
-                                 growth_factor=growth_factor,
-                                 backoff_factor=backoff_factor,
-                                 growth_interval=growth_interval,
-                                 hysteresis=hysteresis,
-                                 min_scale=min_scale,
-                                 max_scale=max_scale,
-                                 max_norm=max_norm,
-                                 norm_type=norm_type)
+        self.zero_optim_kwargs = dict(initial_scale=initial_scale,
+                                      growth_factor=growth_factor,
+                                      backoff_factor=backoff_factor,
+                                      growth_interval=growth_interval,
+                                      hysteresis=hysteresis,
+                                      min_scale=min_scale,
+                                      max_scale=max_scale,
+                                      clip_grad_norm=max_norm,
+                                      reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
+                                      communication_dtype=communication_dtype,
+                                      overlap_communication=overlap_communication,
+                                      cpu_offload=cpu_offload,
+                                      partition_grad=(stage == 2),
+                                      )
         self.verbose = verbose

         # set class name with stage, for better error message
...
@@ -294,15 +331,15 @@ class LowLevelZeroPlugin(DPPluginBase):
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:

         if not isinstance(model, ModelWrapper):
-            model = LowLevelZeroModel(model, self.stage, self.precision)
+            model = LowLevelZeroModel(model, self.precision)

         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = zero_optim_wrapper(model.unwrap(),
-                                           optimizer,
-                                           optim_config=self.zero_optim_config,
-                                           **self.optim_kwargs,
-                                           verbose=self.verbose)
+            optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
+                                                                     **self.zero_optim_kwargs,
+                                                                     verbose=self.verbose)
+            # inject update_master_params
+            model.update_master_params = MethodType(optimizer.update_master_params, model)

         return model, optimizer, criterion, dataloader, lr_scheduler
...
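Net effect of this file's changes: LowLevelZeroModel no longer goes through zero_model_wrapper per stage, the optimizer is constructed directly as a LowLevelZeroOptimizer with partition_grad=(stage == 2) folded into a single zero_optim_kwargs dict, and update_master_params is injected onto the wrapped model so checkpoint loading can resync the fp32 master weights. A minimal usage sketch follows, assuming the standard colossalai Booster entry point and a torchrun-style launch as of this commit; the toy model, data, and hyperparameters are illustrative only.

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# assumes the process group is launched via torchrun; config={} is a placeholder
colossalai.launch_from_torch(config={})

model = nn.Linear(32, 8)    # toy model, illustrative only
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# stage=2 now maps to partition_grad=True inside LowLevelZeroOptimizer;
# precision='fp16' makes LowLevelZeroModel cast the module and its inputs to half precision.
plugin = LowLevelZeroPlugin(stage=2, precision='fp16')
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

x = torch.randn(4, 32).cuda()
y = torch.randn(4, 8).cuda()
loss = criterion(model(x), y)
booster.backward(loss, optimizer)
optimizer.step()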
colossalai/context/parallel_context.py

@@ -15,8 +15,8 @@ from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
 from colossalai.context.config import Config
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
 from colossalai.logging import get_dist_logger
-from colossalai.registry import DIST_GROUP_INITIALIZER

 from .parallel_mode import ParallelMode
 from .random import add_seed, get_seeds, set_mode
...
colossalai/context/process_group_initializer/initializer_1d.py

@@ -2,8 +2,9 @@
 # -*- encoding: utf-8 -*-

 import torch.distributed as dist

 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
...
colossalai/context/process_group_initializer/initializer_2d.py

@@ -3,7 +3,7 @@ import math
 import torch.distributed as dist

 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
...
colossalai/context/process_group_initializer/initializer_2p5d.py

@@ -4,9 +4,10 @@
 import math

 import torch.distributed as dist

 from colossalai.context import Config
 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
...
colossalai/context/process_group_initializer/initializer_3d.py

@@ -6,7 +6,7 @@ import math
 import torch.distributed as dist

 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
...
colossalai/context/process_group_initializer/initializer_data.py

@@ -3,7 +3,7 @@
 from torch import distributed as dist

-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
...
colossalai/context/process_group_initializer/initializer_model.py

@@ -2,9 +2,11 @@
 # -*- encoding: utf-8 -*-

 import torch.distributed as dist

-from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
+
 from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer


 @DIST_GROUP_INITIALIZER.register_module
...
colossalai/context/process_group_initializer/initializer_pipeline.py

@@ -3,7 +3,7 @@
 from torch import distributed as dist

-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
...
colossalai/context/process_group_initializer/initializer_sequence.py

@@ -2,7 +2,7 @@
 # -*- encoding: utf-8 -*-
 import torch.distributed as dist

-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .initializer_tensor import Initializer_Tensor
...
colossalai/context/process_group_initializer/initializer_tensor.py

@@ -3,9 +3,10 @@
 import torch.distributed as dist

-from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer


 @DIST_GROUP_INITIALIZER.register_module
...
colossalai/initialize.py

@@ -17,13 +17,13 @@ from torch.utils.data import DataLoader
 from colossalai.amp import AMP_TYPE, convert_to_amp
 from colossalai.amp.naive_amp import NaiveAMPModel
-from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
-from colossalai.engine.gradient_accumulation import accumulate_gradient
-from colossalai.engine.schedule import (
+from colossalai.legacy.builder.builder import build_gradient_handler
+from colossalai.legacy.engine import Engine
+from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient
+from colossalai.legacy.engine.schedule import (
     InterleavedPipelineSchedule,
     NonPipelineSchedule,
     PipelineSchedule,
...
colossalai/interface/__init__.py

-from .model import ModelWrapper
+from .model import AMPModelMixin, ModelWrapper
 from .optimizer import OptimizerWrapper

-__all__ = ['OptimizerWrapper', 'ModelWrapper']
+__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin']
colossalai/interface/model.py

@@ -23,3 +23,14 @@ class ModelWrapper(nn.Module):
     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
+
+
+class AMPModelMixin:
+    """This mixin class defines the interface for AMP training.
+    """
+
+    def update_master_params(self):
+        """
+        Update the master parameters for AMP training.
+        """
+        pass
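AMPModelMixin only declares the interface; the low-level ZeRO plugin fulfills it by binding the optimizer's update_master_params onto the wrapped model at boost time (the MethodType injection shown in the low_level_zero_plugin.py diff above). A standalone sketch of that binding pattern follows, using generic stand-in classes rather than the real ColossalAI types.

from types import MethodType


class Model:
    """Stand-in for a wrapped model exposing the AMP mixin interface."""

    def update_master_params(self):
        # default no-op, as in AMPModelMixin
        pass


class Optimizer:
    """Stand-in for an optimizer that owns the fp32 master copy of the params."""

    def update_master_params(self, model):
        print(f"syncing master params from {model!r}")


model = Model()
optimizer = Optimizer()

# Bind the optimizer's method to the model instance, so callers that only hold the
# model (e.g. checkpoint IO after load_unsharded_model) can trigger the sync.
model.update_master_params = MethodType(optimizer.update_master_params, model)
model.update_master_params()    # calls optimizer.update_master_params(model)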
colossalai/legacy/__init__.py  (new file, mode 0 → 100644; empty)
colossalai/builder/__init__.py → colossalai/legacy/builder/__init__.py  (file moved)
colossalai/builder/builder.py → colossalai/legacy/builder/builder.py

@@ -3,7 +3,7 @@
 import inspect

-from colossalai.registry import *
+from colossalai.legacy.registry import *


 def build_from_config(module, config: dict):
...
@@ -71,7 +71,7 @@ def build_gradient_handler(config, model, optimizer):
         optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing parameters for the gradient handler

     Returns:
-        An object of :class:`colossalai.engine.BaseGradientHandler`
+        An object of :class:`colossalai.legacy.engine.BaseGradientHandler`
     """
     config_ = config.copy()
     config_['model'] = model
...
colossalai/engine/__init__.py → colossalai/legacy/engine/__init__.py  (file moved)
colossalai/engine/_base_engine.py → colossalai/legacy/engine/_base_engine.py

@@ -8,11 +8,17 @@ from torch import Tensor
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss

-from colossalai.engine.gradient_handler import BaseGradientHandler
-from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule
+from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
+from colossalai.legacy.engine.schedule import (
+    BaseSchedule,
+    InterleavedPipelineSchedule,
+    NonPipelineSchedule,
+    PipelineSchedule,
+)
 from colossalai.logging import get_dist_logger
-from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively


 class Engine:
     """Basic engine class for training and evaluation. It runs a specific process method
...