OpenDAS / ColossalAI / Commits

Commit 32291dd7 (unverified), authored May 26, 2022 by Ziyue Jiang, committed by GitHub on May 26, 2022.

[Tensor] add module handler for linear (#1021)

* add module spec for linear
* polish
* polish
* polish
Parent: ee50497d

Showing 7 changed files, with 341 additions and 2 deletions.
* colossalai/tensor/__init__.py (+5, -1)
* colossalai/tensor/module_utils.py (+92, -0)
* colossalai/tensor/modules/__init__.py (+2, -0)
* colossalai/tensor/modules/colo_module.py (+51, -0)
* colossalai/tensor/modules/linear.py (+39, -0)
* colossalai/utils/model/colo_init_context.py (+15, -1)
* tests/test_tensor/test_module_spec.py (+137, -0)
colossalai/tensor/__init__.py

@@ -8,8 +8,12 @@ from ._ops import *

The hunk after the change:

```python
from .optim.colo_optimizer import ColoOptimizer
from . import distspec
from .dist_spec_mgr import DistSpecManager
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
from .modules import ColoLinear

__all__ = [
    'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
    'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'distspec', 'DistSpecManager',
    'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module', 'ColoLinear'
]
```
colossalai/tensor/module_utils.py (new file, mode 100644)

```python
from typing import Dict

import torch

from colossalai.tensor import ColoParameter, ParallelAction, TensorSpec
from .modules import ColoModule

_COLOSSAL_MODULES: Dict[type, ColoModule] = {}


def register_colo_module(module_type: type, colo_module: ColoModule):
    global _COLOSSAL_MODULES
    _COLOSSAL_MODULES[module_type] = colo_module


def is_colo_module(module: torch.nn.Module):
    global _COLOSSAL_MODULES
    return type(module) in _COLOSSAL_MODULES


def get_colo_module(module: torch.nn.Module):
    global _COLOSSAL_MODULES
    if is_colo_module(module):
        colo_module = _COLOSSAL_MODULES[type(module)]
        colo_module.register()
        return colo_module
    else:
        return None


def check_colo_module(module: torch.nn.Module, recursive=True):
    if is_colo_module(module):
        colo_module = get_colo_module(module)
        param_names = colo_module.get_param_names()
        compute_pattern = None

        # All params of a module must agree on a single compute pattern.
        for param_name in param_names:
            param = module.get_parameter(param_name)
            if not isinstance(param, ColoParameter):
                raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
            if param.has_spec():
                cur_compute_pattern = param.spec.parallel_action.compute_pattern
                if compute_pattern is None:
                    compute_pattern = cur_compute_pattern
                elif cur_compute_pattern != compute_pattern:
                    raise Exception(f'Invalid ColoParameter spec: Params in {module} have different compute_pattern.')

        if compute_pattern is not None:
            if not colo_module.has_compute_pattern(compute_pattern):
                raise Exception(f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')

            # The params' dist specs must match one registered layout exactly.
            match_specs = False
            allowed_specs = colo_module.get_dist_specs(compute_pattern)
            for _, param_specs in allowed_specs.items():
                cur_match = True
                for param_name, dist_spec in param_specs.items():
                    param = module.get_parameter(param_name)
                    if param.has_spec():
                        if dist_spec != param.spec.dist_spec:
                            cur_match = False
                            break
                    elif dist_spec is not None:
                        cur_match = False
                        break
                if cur_match:
                    match_specs = True
                    break
            if not match_specs:
                raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.')

    if recursive:
        for submodule in module.children():
            check_colo_module(submodule, recursive=True)


def init_colo_module(module: torch.nn.Module, parallel_action: ParallelAction, recursive=True, label='default'):
    compute_pattern = parallel_action.compute_pattern
    if is_colo_module(module):
        # For each shard param, set its DistSpec and ParallelAction.
        colo_module = get_colo_module(module)
        if not colo_module.has_compute_pattern_with_label(compute_pattern, label=label):
            raise NotImplementedError
        for param_name, dist_spec in colo_module.get_dist_specs_with_label(compute_pattern, label=label).items():
            if dist_spec is None:
                continue
            param = module.get_parameter(param_name)
            if isinstance(param, ColoParameter):
                spec = TensorSpec(dist_spec, parallel_action)
                param.set_spec(spec)
        check_colo_module(module, recursive=False)
    if recursive:
        for submodule in module.children():
            init_colo_module(submodule, parallel_action, recursive=True, label=label)
```
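For orientation, here is a minimal sketch of how these helpers compose, condensed from the test added in this commit. It assumes `colossalai.launch` has already initialized a 1D tensor-parallel group (as `run_dist` does in the test file below):

```python
import torch
from colossalai.tensor import ComputePattern, ParallelAction, init_colo_module, check_colo_module
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

# Build the module under ColoInitContext so its parameters become
# ColoParameters and torch.nn.Linear is mapped to the ColoLinear handler.
with ColoInitContext(device=get_current_device()):
    model = torch.nn.Linear(4, 8)

# Attach the 1D tensor-parallel spec to every registered submodule, using
# the column-sharded layout that ColoLinear registers under label='col'.
parallel_action = ParallelAction(ComputePattern.TP1D)
init_colo_module(model, parallel_action, recursive=True, label='col')

# Re-validate that every parameter carries a consistent, allowed spec
# (init_colo_module already performs this per module, non-recursively).
check_colo_module(model, recursive=True)
```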
colossalai/tensor/modules/__init__.py (new file, mode 100644)

```python
from .colo_module import ColoModule
from .linear import ColoLinear
```
colossalai/tensor/modules/colo_module.py (new file, mode 100644)

```python
from typing import Dict, List

from colossalai.tensor import ComputePattern
from colossalai.tensor.distspec import _DistSpec


class ColoModule(object):

    def __init__(self):
        self._shard_params: List[str] = []
        # Layout of _allowed_patterns:
        # {ComputePattern.TP1D: {
        #     'default': {'weight': distspec.shard(...), 'bias': distspec.shard(...)},
        #     'row': ...,
        #     'col': ...,
        # }}
        self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}

    def _register_shard_params(self, params: List[str]):
        self._shard_params = params

    def _register_allowed_patterns(self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], label='default'):
        # Compare sorted copies; list.sort() sorts in place and returns None.
        assert sorted(dist_specs.keys()) == sorted(self._shard_params), 'Every registered param should have a dist_spec.'
        if compute_pattern not in self._allowed_patterns:
            self._allowed_patterns[compute_pattern] = {}
        self._allowed_patterns[compute_pattern][label] = dist_specs

    def _set_default(self, compute_pattern: ComputePattern, target_label):
        self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_label]

    def has_compute_pattern(self, compute_pattern: ComputePattern):
        return compute_pattern in self._allowed_patterns

    def get_dist_specs(self, compute_pattern: ComputePattern):
        assert self.has_compute_pattern(compute_pattern)
        return self._allowed_patterns[compute_pattern]

    def has_compute_pattern_with_label(self, compute_pattern: ComputePattern, label='default'):
        return compute_pattern in self._allowed_patterns and label in self._allowed_patterns[compute_pattern]

    def get_dist_specs_with_label(self, compute_pattern: ComputePattern, label='default'):
        assert self.has_compute_pattern_with_label(compute_pattern, label)
        return self._allowed_patterns[compute_pattern][label]

    def get_param_names(self):
        return self._shard_params

    def register(self):
        raise NotImplementedError
```
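A small sketch of the lookup surface, using the `ColoLinear` subclass defined in the next file. This is illustrative usage only; it assumes an initialized 1D parallel context, since `ColoLinear.register()` resolves the `PARALLEL_1D` process group:

```python
from colossalai.tensor import ComputePattern
from colossalai.tensor.modules import ColoLinear

colo_module = ColoLinear()
colo_module.register()    # lazily populates the TP1D patterns

assert colo_module.get_param_names() == ['weight', 'bias']
assert colo_module.has_compute_pattern(ComputePattern.TP1D)

# Per-parameter dist specs for the column-parallel layout:
# {'weight': <shard along dim 0>, 'bias': <shard along dim 0>}
specs = colo_module.get_dist_specs_with_label(ComputePattern.TP1D, label='col')
```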
colossalai/tensor/modules/linear.py (new file, mode 100644)

```python
from .colo_module import ColoModule
from colossalai.tensor import ComputePattern, distspec
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode


class ColoLinear(ColoModule):

    def __init__(self):
        super(ColoLinear, self).__init__()
        self._register_shard_params(['weight', 'bias'])
        self._register = False

    def register(self):
        if not self._register:
            self._set_TP1D()
            self._register = True

    def _set_TP1D(self):
        # TP1D Row Linear
        _compute_pattern = ComputePattern.TP1D
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
                                         [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
                'bias': None
            },
            label='row',
        )

        # TP1D Col Linear
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
                                         [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
                'bias': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
                                       [gpc.get_world_size(ParallelMode.PARALLEL_1D)])
            },
            label='col',
        )

        self._set_default(compute_pattern=_compute_pattern, target_label='row')
```
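To read the two layouts: `torch.nn.Linear(in_features, out_features)` stores `weight` with shape `(out_features, in_features)`. The 'row' layout shards the weight along its last dimension (input features) and leaves the bias unsharded, while the 'col' layout shards both weight and bias along dimension 0 (output features); 'row' is then installed under the 'default' label. A shape sketch for an assumed 4-way 1D group:

```python
# Assumed: 4-way PARALLEL_1D group, torch.nn.Linear(4, 8)
#   full weight: (8, 4)    full bias: (8,)
#
# label='row' -> weight shard per rank: (8, 1); bias spec: None (replicated)
# label='col' -> weight shard per rank: (2, 4); bias shard per rank: (2,)
```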
colossalai/utils/model/colo_init_context.py

The import block after the change:

```python
from .utils import InsertPostInitMethodToModuleSubClasses
import torch
from colossalai.tensor import ColoTensor, ColoParameter, register_colo_module, init_colo_module, \
    ColoLinear
import types

from torch import nn
```

@@ -101,6 +102,17 @@ def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.n

```python
    else:
        object.__setattr__(self, name, value)


def _get_parameter_with_colotensor(self, target: str) -> Union[torch.nn.Parameter, ColoTensor]:
    module_path, _, param_name = target.rpartition(".")

    mod: torch.nn.Module = self.get_submodule(module_path)

    if not hasattr(mod, param_name):
        raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")

    param = getattr(mod, param_name)
    return param


def ColoModulize(module):
    """
```

@@ -124,6 +136,8 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):

```python
        torch.nn.Module.__setattr__ = _setattr_with_colotensor
        torch.nn.Module.register_parameter = _register_parameter_with_colotensor
        torch.nn.Module.get_parameter = _get_parameter_with_colotensor
        register_colo_module(torch.nn.Linear, ColoLinear())

    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """
```
tests/test_tensor/test_module_spec.py (new file, mode 100644)

```python
from copy import copy
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F

import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import ColoTensor, distspec
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, \
    register_colo_module, init_colo_module, ColoLinear
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from _utils import tensor_equal, tensor_shard_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs


def run_simplenet_with_spec(label):
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    if rank == 0:
        model_seq = model_builder(checkpoint=True)
        model_seq = model_seq.cuda()

        # Make the two models start from the same init params.
        for p1, p2 in zip(model.parameters(), model_seq.parameters()):
            p2.data.copy_(p1.data)

    parallel_action = ParallelAction(ComputePattern.TP1D)
    init_colo_module(model, parallel_action, recursive=True, label=label)
    model = model.cuda()

    for i, (data, label) in enumerate(train_dataloader):
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
        torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))

        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        # For reference
        if rank == 0:
            if criterion:
                output_seq = model_seq(data)
                loss_seq = criterion(output_seq, label)
            else:
                output_seq = model_seq(data, label)
                loss_seq = output_seq

        if rank == 0:
            with torch.no_grad():
                assert torch.allclose(loss, loss_seq, rtol=1e-2)

        loss.backward()

        if rank == 0:
            loss_seq.backward()

            with torch.no_grad():
                # Check that each sharded param matches its slice of the
                # corresponding unsharded reference param.
                for p1, p2 in zip(model.parameters(), model_seq.parameters()):
                    if p1.size() == p2.size():
                        assert torch.allclose(p1, p2)
                    else:
                        if p1.size(-1) < p2.size(-1):    # col
                            world_size = p2.size(-1) // p1.size(-1)
                            split_p2 = torch.chunk(p2, world_size, dim=-1)[0]
                        elif p1.size(0) < p2.size(0):    # row
                            world_size = p2.size(0) // p1.size(0)
                            split_p2 = torch.chunk(p2, world_size, dim=0)[0]
                        assert torch.allclose(p1, split_p2)

        if i > 3:
            break


def run_linear_with_spec(label):
    with ColoInitContext(device=get_current_device()):
        model = torch.nn.Linear(4, 8)

    model_handy = copy(model)

    parallel_action = ParallelAction(ComputePattern.TP1D)
    init_colo_module(model, parallel_action, recursive=True, label=label)

    x = torch.rand(2, 4).cuda()
    out = model(x)
    colo_out = model_handy(x)
    assert tensor_equal(out, colo_out)
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad)
    assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad)


def run_dist(rank, world_size, port, func):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    func('col')
    func('row')
    func('default')


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_module_linear_1d(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), func=run_linear_with_spec)
    mp.spawn(run_func, nprocs=world_size)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_module_simplenet(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), func=run_simplenet_with_spec)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_module_simplenet(4)
```
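Both tests are marked `dist` and drive `run_dist` through `mp.spawn` with the NCCL backend on localhost, so running them at `world_size=4` requires four CUDA devices; each spawned process exercises the 'col', 'row', and 'default' labels in turn.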