OpenDAS / ColossalAI

Commit a0064407 (unverified)
authored Jun 03, 2022 by Jiarui Fang, committed by GitHub on Jun 03, 2022
parent 3d10be33

reorgnize colotensor directory (#1062)

* reorgnize colotensor directory
* polish code

Changes: 25
Showing 5 changed files with 49 additions and 27 deletions
Files changed:

  colossalai/tensor/utils.py                     +1   -3
  colossalai/utils/model/colo_init_context.py    +5   -3
  tests/test_tensor/test_hybrid_device.py        +13  -5
  tests/test_tensor/test_model.py                +2   -1
  tests/test_tensor/test_module_spec.py          +28  -15
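Read together, the five diffs split what used to be imported wholesale from colossalai.tensor: the module-level helpers (register_colo_module, init_colo_module, check_colo_module, and the Colo* layer wrappers) are now imported from colossalai.nn, ColoOptimizer from colossalai.nn.optimizer, and library-internal code pulls ColoTensor from its concrete submodule colossalai.tensor.colo_tensor. The following is a minimal sketch of that import split, inferred from the diffs below rather than from the full reorganized package:

# Sketch of the import layout implied by this commit (inferred from the
# diffs below; not an exhaustive map of the package).

# Before: tensor types and module helpers all came from colossalai.tensor, e.g.
#   from colossalai.tensor import ColoTensor, ColoParameter, \
#       register_colo_module, init_colo_module, check_colo_module, ColoOptimizer

# After: tensor types stay in colossalai.tensor ...
from colossalai.tensor import ColoTensor, ColoParameter, TensorSpec, ComputePattern, ParallelAction

# ... while module-level helpers and the optimizer live under colossalai.nn
from colossalai.nn import register_colo_module, init_colo_module, check_colo_module
from colossalai.nn.optimizer import ColoOptimizer

# Library-internal code (colossalai/tensor/utils.py) references the concrete
# submodule directly:
from colossalai.tensor.colo_tensor import ColoTensor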
colossalai/tensor/utils.py

 import torch
+from colossalai.tensor.colo_tensor import ColoTensor
 from typing import Iterator, Tuple, Union
 import torch.nn as nn
-from colossalai.tensor import ColoTensor

 # The function is credited to PyTorch Team
 ...
colossalai/utils/model/colo_init_context.py

 from .utils import InsertPostInitMethodToModuleSubClasses
 import torch
-from colossalai.tensor import ColoTensor, ColoParameter, register_colo_module, init_colo_module, \
+from colossalai.tensor import ColoTensor, ColoParameter
+from colossalai.nn import register_colo_module, init_colo_module, \
     ColoLinear, ColoEmbedding
+import types
 from torch import nn
-from typing import Iterator, Tuple, Union, Optional
+from typing import Iterator, Tuple, Union

 # find named_params includes replica
 ...
@@ -24,6 +25,7 @@ def _named_params_with_replica(
             name = mod_prefix + ('.' if mod_prefix else '') + name
             yield name, val

 def ColoModulize(module):
     """
     Replacing the parameters() and named_parameters() with our customized ones
     """
 ...
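The only function-body context visible in this file is the prefix-joining line inside _named_params_with_replica, which builds the fully qualified parameter name before yielding it. Below is a small, self-contained sketch of that joining rule; the helper name _qualified_names and the toy Sequential model are illustrative only and not part of the commit:

import torch.nn as nn

def _qualified_names(module: nn.Module):
    # Illustrative stand-in for _named_params_with_replica: walk every
    # submodule and join the module prefix with the parameter name using
    # the same rule shown in the diff above.
    for mod_prefix, mod in module.named_modules():
        for name, val in mod._parameters.items():
            if val is None:
                continue
            name = mod_prefix + ('.' if mod_prefix else '') + name
            yield name, val

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for qualified, param in _qualified_names(net):
    print(qualified, tuple(param.shape))
# Prints: 0.weight (8, 4), 0.bias (8,), 2.weight (2, 8), 2.bias (2,)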
tests/test_tensor/test_hybrid_device.py

 from colossalai.utils import free_port, ColoInitContext, get_current_device
 from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, init_colo_module
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
 from functools import partial
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
+from colossalai.nn import init_colo_module
 from colossalai.nn.parallel import ColoDDP
 import colossalai
 ...
@@ -11,12 +14,14 @@ import torch
 import torch.multiprocessing as mp
 import pytest

 class Net(torch.nn.Module):

     def __init__(self):
         super(Net, self).__init__()
         self.embed = torch.nn.Embedding(20, 4)
         self.proj = torch.nn.Linear(4, 8)

     def forward(self, x):
         # move input to cpu and restore output
         current_dev = x.device
 ...
@@ -27,6 +32,7 @@ class Net(torch.nn.Module):
         x = self.proj(x)
         return x

 def run_hybrid_device(use_ddp):
     with ColoInitContext(device=get_current_device()):
         model = Net()
 ...
@@ -36,7 +42,6 @@ def run_hybrid_device(use_ddp):
         model = ColoDDP(model)
         real_model = model.module
     print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}')
     #print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
     parallel_action = ParallelAction(ComputePattern.TP1D)
 ...
@@ -49,11 +54,12 @@ def run_hybrid_device(use_ddp):
     print(f'embedding weight size: {real_model.embed.weight.size()} | new device: {real_model.embed.weight.device}')
     #print(f'linear weight size: {real_model.proj.weight.size()} | new device: {real_model.proj.weight.device}')
     data = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
     out = model(data)
     out.sum().backward()

 def run_dist(rank, world_size, port, use_ddp):
     if use_ddp and world_size == 1:
         return
 ...
@@ -62,6 +68,7 @@ def run_dist(rank, world_size, port, use_ddp):
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_hybrid_device(use_ddp)

 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @pytest.mark.parametrize('use_ddp', [False, True])
 ...
@@ -71,5 +78,6 @@ def _test_hybrid_device(world_size, use_ddp):
     run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
     mp.spawn(run_func, nprocs=world_size)

 if __name__ == '__main__':
     _test_hybrid_device(1, False)
\ No newline at end of file
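The test files in this commit launch their distributed checks the same way: bind everything except the rank into run_dist with functools.partial, then start one process per rank with torch.multiprocessing.spawn, which supplies the process index as the first positional argument. The following is a stripped-down, CPU-only sketch of that launch pattern; the worker body here only prints its rank, whereas the real tests call colossalai.launch(...) followed by the test routine and obtain the port from colossalai's free_port():

from functools import partial
import torch.multiprocessing as mp

def run_dist(rank, world_size, port, use_ddp):
    # Placeholder worker body: the real test initializes the process group via
    # colossalai.launch(config=..., rank=rank, world_size=world_size,
    # host='localhost', port=port, backend='nccl') and then runs the check.
    print(f'rank {rank}/{world_size} on port {port}, use_ddp={use_ddp}')

def launch_test(world_size, use_ddp, port=29500):
    # mp.spawn calls run_func(i) for i in range(nprocs); partial supplies the
    # remaining keyword arguments, mirroring the pattern used in the tests.
    run_func = partial(run_dist, world_size=world_size, port=port, use_ddp=use_ddp)
    mp.spawn(run_func, nprocs=world_size)

if __name__ == '__main__':
    launch_test(world_size=2, use_ddp=False)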
tests/test_tensor/test_model.py

 ...
@@ -10,9 +10,10 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils import ColoInitContext
 from colossalai.tensor import distspec, named_params_with_colotensor, TensorSpec, ComputePattern, \
-    ParallelAction, ColoTensor, ColoOptimizer, DistSpecManager
+    ParallelAction, ColoTensor, DistSpecManager
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.nn.optimizer import ColoOptimizer
 from functools import partial
 from _utils import set_seed
 ...
tests/test_tensor/test_module_spec.py

 from copy import copy
 from colossalai.utils.cuda import get_current_device
 import pytest
 from colossalai.utils.model.colo_init_context import ColoInitContext
 import torch
 from colossalai.context.parallel_mode import ParallelMode
-from colossalai.tensor import ColoTensor, distspec
+from colossalai.tensor import distspec
 from functools import partial
 import colossalai
 import torch.multiprocessing as mp
 import torch.nn.functional as F
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
+from colossalai.nn import init_colo_module, check_colo_module
 from _utils import tensor_equal, tensor_shard_equal, set_seed
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, register_colo_module, init_colo_module, check_colo_module
 from tests.components_to_test.registry import non_distributed_component_funcs

 def run_model_with_spec(mode, model_name):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
 ...
@@ -27,7 +31,7 @@ def run_model_with_spec(mode, model_name):
     set_seed(1)
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=False)

     if rank == 0:
         model_seq = model_builder(checkpoint=False)
         model_seq = model_seq.cuda()
 ...
@@ -103,15 +107,16 @@ def run_model_with_spec(mode, model_name):
         if i > 3:
             break

 def run_linear_with_spec(mode):
     with ColoInitContext(device=get_current_device()):
         model = torch.nn.Linear(4, 8)

     model_handy = copy(model)
     parallel_action = ParallelAction(ComputePattern.TP1D)
     init_colo_module(model, parallel_action, recursive=True, mode=mode)

     x = torch.rand(2, 4).cuda()
     out = model(x)
     colo_out = model_handy(x)
 ...
@@ -122,6 +127,7 @@ def run_linear_with_spec(mode):
     assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad)
     assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad)

 def run_check_shared_param():
     from transformers import BertForMaskedLM, BertConfig
     hidden_dim = 8
 ...
@@ -157,12 +163,14 @@ def run_check_shared_param():
     except Exception as e:
         assert 'incorrectly sharded' in str(e)

 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_linear_with_spec('col')
     run_linear_with_spec('row')

 def run_dist_model(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 ...
@@ -170,11 +178,13 @@ def run_dist_model(rank, world_size, port):
         run_model_with_spec('col', model_name)
         run_model_with_spec('row', model_name)

 def run_dist_check(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_check_shared_param()

 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 ...
@@ -182,6 +192,7 @@ def test_module_linear_1d(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 ...
@@ -189,6 +200,7 @@ def test_module_model(world_size):
     run_func = partial(run_dist_model, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
 @rerun_if_address_is_in_use()
 ...
@@ -196,5 +208,6 @@ def test_module_check(world_size):
     run_func = partial(run_dist_check, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

 if __name__ == '__main__':
     test_module_check(2)
\ No newline at end of file