OpenDAS / ColossalAI · Commits · 2827f418
"...deps/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "2153ee81759d45e27ff6846c0195b8e8029c2529"
Unverified commit 2827f418, authored Dec 20, 2022 by Jiarui Fang; committed by GitHub, Dec 20, 2022

[Gemini] GeminiDPP convert to PyTorch Module. (#2151)

Parent: bdef9dfd
Showing 3 changed files with 76 additions and 1 deletion.
+76 -1 overall:
colossalai/nn/parallel/utils.py (+28 -0)
colossalai/tensor/colo_tensor.py (+0 -1)
tests/test_gemini/update/test_convert_torch_module.py (+48 -0)
colossalai/nn/parallel/utils.py
@@ -2,6 +2,7 @@ import torch
 import torch.distributed as dist

 from colossalai.gemini.chunk import Chunk
+from colossalai.tensor import ColoTensor
 from colossalai.utils import get_current_device

@@ -19,3 +20,30 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
     dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)

     return total_temp
+
+
+def _add_param(model, name, param):
+    name_list = name.split('.')
+    module = model._modules[name_list[0]]
+    for i in range(1, len(name_list) - 1):
+        module = module._modules[name_list[i]]
+    module._parameters[name_list[-1]] = param
+
+
+def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
+    """convert_to_torch_module
+
+    Args:
+        gemini_ddp_model (GeminiDDP): a gemini ddp model
+
+    Returns:
+        torch.nn.Module: a torch model contains the params of gemini_ddp_model
+    """
+    module = gemini_ddp_model.module
+    for n, p in module.named_parameters():
+        if isinstance(p, ColoTensor):
+            p.to_replicate_()
+            _add_param(module, n, p.data)
+    return module
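A hedged usage sketch of the new helper: the intent of convert_to_torch_module appears to be turning a GeminiDDP-wrapped model back into a plain torch.nn.Module whose parameters are ordinary torch tensors, for example so it can be checkpointed with torch.save. The sketch below mirrors the test added in this commit; build_my_model, the port number, and the torch.save step are illustrative assumptions, not part of the commit.

import torch

import colossalai
from colossalai.nn.parallel import GeminiDDP
from colossalai.nn.parallel.utils import convert_to_torch_module
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext


def build_my_model() -> torch.nn.Module:
    # Hypothetical model factory; any torch.nn.Module would do here.
    return torch.nn.Linear(16, 4)


if __name__ == '__main__':
    # Single-process launch, following the same call used in the new test (rank 0, world size 1).
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=29500, backend='nccl')

    # Parameters are created as ColoTensors inside ColoInitContext.
    with ColoInitContext(device='cpu'):
        model = build_my_model()

    model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True)

    # ... training with the Gemini-wrapped model would happen here ...

    # Convert back: each ColoTensor parameter is replicated and replaced by its plain data.
    plain_model = convert_to_torch_module(model)
    torch.save(plain_model.state_dict(), 'model.pth')    # illustrative checkpointing step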
colossalai/tensor/colo_tensor.py
@@ -103,7 +103,6 @@ class ColoTensor(torch.Tensor):
         self.process_group = spec.pg
         self._type = TensorType.NONMODEL
-        self._graph_node = None

     def has_compute_spec(self) -> bool:
         return self.compute_spec is not None
tests/test_gemini/update/test_convert_torch_module.py (new file, mode 100644)
from functools import partial

import pytest
import torch.multiprocessing as mp

import colossalai
from colossalai.nn.parallel.utils import convert_to_torch_module
from colossalai.tensor import ColoTensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs


@parameterize('model_name', ['resnet18', 'bert'])
def run_convert_torch_module(model_name: str):
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, _, _, _, _ = get_components_func()

    with ColoInitContext(device='cpu'):
        model = model_builder(checkpoint=False)

    from colossalai.nn.parallel import GeminiDDP
    model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True)

    pytorch_model = convert_to_torch_module(model)

    for n, p in pytorch_model.named_parameters():
        assert not isinstance(p, ColoTensor)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_convert_torch_module()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_convert_torch_module(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_convert_torch_module(2)
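One detail of this test worth noting: run_convert_torch_module() is invoked with no arguments inside run_dist, which works because colossalai.testing.parameterize expands the wrapped function over the listed values. The toy sketch below is purely illustrative of that behaviour and is not the real colossalai.testing implementation.

import functools


def parameterize_sketch(arg_name, values):
    # Toy stand-in for a parameterize-style decorator: calls the wrapped
    # function once per value, filling in the named argument each time.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for value in values:
                func(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator


@parameterize_sketch('model_name', ['resnet18', 'bert'])
def demo(model_name: str):
    print(f'would build, wrap and convert: {model_name}')


if __name__ == '__main__':
    demo()    # runs twice, once per model name, with no explicit argument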