OpenDAS / ColossalAI

Commit 1aaeb596 (unverified)
[example] gpt, shard init on all processes (#2366)

Authored Jan 06, 2023 by Jiarui Fang, committed by GitHub on Jan 06, 2023
Parent: 1f8ab6f1
Showing 2 changed files with 18 additions and 12 deletions (+18 -12):

  colossalai/tensor/colo_tensor.py                  (+4 -4)
  examples/language/gpt/gemini/train_gpt_demo.py    (+14 -8)
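In short: before this change, the GPT Gemini example built every parameter on the tensor-parallel (TP) process group at init time. This commit shard-initializes parameters across all processes instead (a ProcessGroup whose tp_degree equals the world size), and tensor_parallelize later converts each parameter back to a replica and re-attaches it to the actual TP group before applying the TP sharding pattern. To make that re-grouping legal, ColoTensor.set_process_group is relaxed to accept a target group whenever the tensor's current group is pure TP or pure DP and its dist spec is Replica.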
colossalai/tensor/colo_tensor.py

@@ -117,7 +117,7 @@ class ColoTensor(torch.Tensor):
     def set_process_group(self, pg: ProcessGroup):
         """set_process_group
         change the pg of the ColoTensor. Note that the valid use cases is limited.
-        Only existing pg is DP and dist spec is REPLICaTE is valid.
+        It works for the target pg is DP and TP only and current dist spec of the Tensor is Replica.

         Args:
             pg (ProcessGroup): target pg

@@ -127,10 +127,10 @@ class ColoTensor(torch.Tensor):
         # if the new pg is the same as the old pg, just returns
         if self.process_group == pg:
             return
-        assert self.process_group.tp_world_size() == 1, \
-            "Can not set_process_group on a ColoTensor whose process_group has tp world group"
+        assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \
+            "Can not set_process_group on a ColoTensor whose process_group is both tp > 1 and world group > 1"
         assert self.dist_spec.placement.value == 'r', \
-            "Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE"
+            "Can not set_process_group on a ColoTensor whose dist spec is not Replica"

         self.process_group = pg
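For orientation, here is a minimal sketch (not part of the commit) of the re-grouping step that the relaxed assertions enable. It assumes the ColossalAI-era API used in this diff (ProcessGroup, ReplicaSpec, set_dist_spec, set_process_group); regroup_to_tp and its arguments are hypothetical names.

from colossalai.tensor import ColoParameter, ProcessGroup, ReplicaSpec

def regroup_to_tp(param: ColoParameter, tp_degree: int) -> ProcessGroup:
    """Hypothetical helper: move a shard-initialized parameter onto a TP group.

    set_process_group now asserts only that the current group is pure TP
    (dp_world_size() == 1) or pure DP (tp_world_size() == 1), and that the
    current dist spec is Replica (placement value 'r').
    """
    tp_pg = ProcessGroup(tp_degree=tp_degree)
    param.set_dist_spec(ReplicaSpec())    # make the placement 'r' first
    param.set_process_group(tp_pg)        # then swap in the target group
    return tp_pg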
examples/language/gpt/gemini/train_gpt_demo.py

@@ -148,10 +148,16 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
     """
     for mn, module in model.named_modules():
         for pn, param in module.named_parameters(recurse=False):
-            # NOTE() a param maybe shared by tow modules
+            # NOTE() a param maybe shared by two modules
             if hasattr(param, 'visited'):
                 continue
+
+            # if shard init, then convert param to replica and use the dp-only ProcessGroup
+            param: ColoParameter = param
             param.set_dist_spec(ReplicaSpec())
+            param.set_process_group(pg)
+
+            # shard it w.r.t tp pattern
             if 'mlp.c_fc' in mn:
                 if 'weight' in pn or 'bias' in pn:
                     split_param_col_tp1d(param, pg)    # colmn slice

@@ -170,7 +176,6 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
                     split_param_col_tp1d(param, pg)    # colmn slice
             else:
                 param.set_dist_spec(ReplicaSpec())
-
             param.visited = True

@@ -248,27 +253,28 @@ def main():
     torch.manual_seed(123)
     if args.distplan == "colossalai":
         # all param must use the same process group.
-        default_pg = ProcessGroup(tp_degree=args.tp_degree)
-        default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
+        world_size = torch.distributed.get_world_size()
+        shard_pg = ProcessGroup(tp_degree=world_size)
+        default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

         # build GPT model
         if version.parse(CAI_VERSION) > version.parse("0.1.10"):
             with ColoInitContext(device=get_current_device(),
                                  dtype=torch.half,
                                  default_dist_spec=default_dist_spec,
-                                 default_pg=default_pg):
+                                 default_pg=shard_pg):
                 model = model_builder(args.model_type)(checkpoint=True)
         else:
             with ColoInitContext(device=get_current_device()):
                 model = model_builder(args.model_type)(checkpoint=True)

-        pg = default_pg
+        tp_pg = ProcessGroup(tp_degree=args.tp_degree)
         # Tensor Parallelism (TP)
-        tensor_parallelize(model, pg)
+        tensor_parallelize(model, tp_pg)

         # build a Gemini model and a highly optimized cpu optimizer
         # Gemini + ZeRO DP, Note it must be used after TP
-        model, optimizer = build_gemini(model, pg, args.placement)
+        model, optimizer = build_gemini(model, tp_pg, args.placement)

         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     else:
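Putting both files together, the example's initialization path after this commit reduces to the following condensed sketch. It is an illustration, not verbatim demo code; args, model_builder, tensor_parallelize, and build_gemini are the names from train_gpt_demo.py, and the import paths are assumed from the ColossalAI version the demo targets.

import torch
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

world_size = torch.distributed.get_world_size()

# 1) Shard-init on ALL processes: parameters are split across the whole
#    world at construction time, so no rank materializes a full fp16 model.
shard_pg = ProcessGroup(tp_degree=world_size)
default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
with ColoInitContext(device=get_current_device(),
                     dtype=torch.half,
                     default_dist_spec=default_dist_spec,
                     default_pg=shard_pg):
    model = model_builder(args.model_type)(checkpoint=True)

# 2) Re-group onto the real TP group: tensor_parallelize turns each param
#    back into a replica, attaches it to tp_pg, then applies the TP split.
tp_pg = ProcessGroup(tp_degree=args.tp_degree)
tensor_parallelize(model, tp_pg)

# 3) Gemini + ZeRO DP wrapping must come after TP.
model, optimizer = build_gemini(model, tp_pg, args.placement)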