OpenDAS / ColossalAI · Commits

Commit cd5a0d56 (Unverified)
Authored Nov 08, 2022 by Jiarui Fang; committed via GitHub on Nov 08, 2022
Parent: 99870726

[Gemini] make gemini usage simple (#1821)

Showing 4 changed files with 49 additions and 21 deletions (+49, -21)
Changed files:
- colossalai/nn/parallel/__init__.py (+2, -1)
- colossalai/nn/parallel/data_parallel.py (+4, -13)
- colossalai/nn/parallel/gemini_parallel.py (new file, +39, -0)
- examples/language/opt/run_clm.py (+4, -7)
colossalai/nn/parallel/__init__.py

```diff
 from .data_parallel import ColoDDP, ZeroDDP
+from .gemini_parallel import GeminiDDP

-__all__ = ['ColoDDP', 'ZeroDDP']
+__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP']
```
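With this export in place, the new wrapper is reachable from the package root. A quick import-surface check (a sketch; assumes a Colossal-AI build that includes this commit):

```python
# Sketch: verify the updated public surface of colossalai.nn.parallel.
from colossalai.nn.parallel import ColoDDP, ZeroDDP, GeminiDDP

# GeminiDDP extends ZeroDDP (see gemini_parallel.py below).
assert issubclass(GeminiDDP, ZeroDDP)
```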
colossalai/nn/parallel/data_parallel.py

```diff
@@ -188,25 +188,16 @@ class ColoDDP(torch.nn.Module):
 class ZeroDDP(ColoDDP):
-    """ZeRO-DP for ColoTensor. Nested ZeroDDP is not supported now.
-    We can configure chunk and gemini via ChunkManager and GeminiManager respectively.
-
-    Example:
-        >>> model = torch.nn.Linear(20, 1)
-        >>> placement_policy = 'cuda'
-        >>> chunk_size = ChunkManager.search_chunk_size(model, search_range, n_grids) if use_chunk else None
-        >>> chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero, init_device=GeminiManager.get_default_device(placement_policy))
-        >>> gemini_manager = GeminiManager(placement_policy, chunk_manager)
-        >>> model = ZeroDDP(model, gemini_manager)
-        >>> logits = model(x)
-        >>> loss = criterion(logits, labels)
-        >>> model.backward(loss)
+    """ZeRO DDP for ColoTensor.
+    Warning: Nested ZeroDDP is not supported now.
+    It is designed to be used with ChunkManager and GeminiManager.
+    For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.

     Args:
         module (torch.nn.Module): Module to apply ZeRO-DP.
         gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous memory space.
             For more details, see the API reference of ``GeminiManager``.
         pin_memory (bool): Chunks on CPU memory use pinned memory.
         force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.
     """
```
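For context, the example deleted from this docstring showed the manual wiring that users previously had to write themselves. A minimal sketch of that manual path, reconstructed from the removed example (the ChunkManager/GeminiManager signatures follow the old docstring and vary across Colossal-AI versions; the search_chunk_size arguments are placeholder values):

```python
# Sketch of the manual ZeroDDP wiring from the removed docstring example.
# Signatures follow that example; placeholder values are marked below.
import torch
import torch.nn.functional as F
from colossalai.gemini import ChunkManager, GeminiManager
from colossalai.nn.parallel import ZeroDDP

model = torch.nn.Linear(20, 1)
placement_policy = 'cuda'
# search_range and n_grids below are placeholder values for illustration.
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 1024)
chunk_manager = ChunkManager(chunk_size,
                             enable_distributed_storage=True,
                             init_device=GeminiManager.get_default_device(placement_policy))
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)

x, labels = torch.randn(4, 20), torch.randn(4, 1)
loss = F.mse_loss(model(x), labels)
model.backward(loss)    # ZeroDDP requires model.backward(loss), not loss.backward()
```

The new GeminiDDP wrapper in the next file folds these steps into a single constructor call.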
colossalai/nn/parallel/gemini_parallel.py (new file, mode 100644)

```python
import torch

from colossalai.gemini.chunk import init_chunk_manager
from colossalai.gemini.gemini_mgr import GeminiManager

from .data_parallel import ZeroDDP


class GeminiDDP(ZeroDDP):

    def __init__(self,
                 module: torch.nn.Module,
                 device: torch.device,
                 placement_policy: str = "cpu",
                 pin_memory: bool = False,
                 force_outputs_fp32: bool = False,
                 search_range_mb: int = 32) -> None:
        """
        A torch.nn.Module wrapper using ZeroDDP and Gemini.
        ZeRO is for parallelism. Gemini is for memory management.

        Example:
            The model must be initialized under the context of ColoInitContext.
            >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
            >>> logits = model(x)
            >>> loss = criterion(logits, labels)
            >>> model.backward(loss)

        Args:
            module (torch.nn.Module): the model to be wrapped.
            device (torch.device): the device where the model is placed.
            placement_policy (str, optional): "cpu", "cuda" or "auto". Defaults to "cpu".
            pin_memory (bool, optional): use pinned memory on the CPU. Defaults to False.
            force_outputs_fp32 (bool, optional): if True, force outputs to be fp32. Defaults to False.
            search_range_mb (int, optional): chunk-size search range in MB. Defaults to 32.
        """
        chunk_manager = init_chunk_manager(model=module, init_device=device, search_range_mb=search_range_mb)
        gemini_manager = GeminiManager(placement_policy, chunk_manager, module)
        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
```
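To show what the simplified entry point looks like in practice, a hedged usage sketch (the ColoInitContext import path is an assumption based on Colossal-AI examples of this era; the model and data names are placeholders):

```python
# Usage sketch for the new GeminiDDP wrapper.
# Assumption: ColoInitContext lives at colossalai.utils.model.colo_init_context
# in this version of the codebase.
import torch
import torch.nn.functional as F
import colossalai
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

colossalai.launch_from_torch(config={})

# Parameters must be created as ColoTensors for chunk-based management.
with ColoInitContext(device=get_current_device()):
    model = torch.nn.Linear(20, 1)

# One call replaces the init_chunk_manager / GeminiManager / ZeroDDP wiring.
model = GeminiDDP(model, device=get_current_device(), placement_policy="cuda", pin_memory=True)

x = torch.randn(4, 20, device=get_current_device())
labels = torch.randn(4, 1, device=get_current_device())
loss = F.mse_loss(model(x), labels)
model.backward(loss)    # gradients are driven through ZeroDDP's hooks
```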
examples/language/opt/run_clm.py

```diff
@@ -24,7 +24,6 @@ https://huggingface.co/models?filter=text-generation
 import math
 import os
-import random
 import time
 from itertools import chain
@@ -43,7 +42,6 @@ import colossalai
 import transformers
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.gemini import ChunkManager, GeminiManager
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.nn.parallel import ZeroDDP
@@ -380,11 +378,8 @@ def main():
     cai_version = colossalai.__version__
     logger.info(f'using Colossal-AI version {cai_version}')
     if version.parse(cai_version) > version.parse("0.1.10"):
-        from colossalai.gemini import GeminiManager
-        from colossalai.gemini.chunk import init_chunk_manager
-        chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=32)
-        gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
-        model = ZeroDDP(model, gemini_manager, pin_memory=True)
+        from colossalai.nn.parallel import GeminiDDP
+        model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
     elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
         from colossalai.gemini import ChunkManager, GeminiManager
         pg = ProcessGroup()
@@ -393,6 +388,8 @@ def main():
                                      pg,
                                      enable_distributed_storage=True,
                                      init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
         gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
         model = ZeroDDP(model, gemini_manager)

+
+    logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
```
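The net effect of this hunk is that, on Colossal-AI versions newer than 0.1.10, the example wraps the model in one call instead of five lines of manager setup. A condensed sketch of the version-gated pattern (wrap_with_gemini is a hypothetical helper, not part of the example script):

```python
# Hypothetical helper condensing the version dispatch used in run_clm.py.
import colossalai
from packaging import version
from colossalai.utils import get_current_device

def wrap_with_gemini(model, placement_policy: str = "auto"):
    cai_version = colossalai.__version__
    if version.parse(cai_version) > version.parse("0.1.10"):
        # New path introduced by this commit: a single wrapper call.
        from colossalai.nn.parallel import GeminiDDP
        return GeminiDDP(model,
                         device=get_current_device(),
                         placement_policy=placement_policy,
                         pin_memory=True)
    # Versions 0.1.9-0.1.10 keep the manual ChunkManager/GeminiManager
    # wiring shown in the elif branch above.
    raise NotImplementedError(f"manual wiring needed for Colossal-AI {cai_version}")
```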