OpenDAS / ColossalAI / Commits / f7fd592b

Unverified commit f7fd592b, authored Jan 05, 2023 by ZijianYY, committed by GitHub on Jan 05, 2023

[examples]adding tp to PaLM (#2319)

parent 9c9246c0
Showing 1 changed file with 43 additions and 1 deletion

examples/language/palm/train.py  (+43, -1)
...
@@ -104,6 +104,48 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
         raise NotImplemented(f"CAI version {cai_version} is not supported")
     return model


+# Parameter Sharding Strategies for Tensor Parallelism
+def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
+    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    param.set_tensor_spec(*spec)
+
+
+def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
+    split_param_single_dim_tp1d(0, param, pg)
+
+
+def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
+    split_param_single_dim_tp1d(-1, param, pg)
+
+
+# Tensor Parallel
+def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
+    """tensor_parallelize
+    Sharding the Model Parameters.
+
+    Args:
+        model (torch.nn.Module): a torch module to be sharded
+    """
+    for mn, module in model.named_modules():
+        for pn, param in module.named_parameters(recurse=False):
+            if hasattr(param, 'visited'):
+                continue
+            param.set_dist_spec(ReplicaSpec())
+            if 'net.0' in mn:
+                split_param_col_tp1d(param, pg)    # colmn slice
+            elif 'to_q' in mn:
+                split_param_col_tp1d(param, pg)    # colmn slice
+            elif 'to_kv' in mn:
+                split_param_row_tp1d(param, pg)    # row slice
+            elif 'to_out' in mn:
+                split_param_row_tp1d(param, pg)    # row slice
+            elif '1.1' in mn:
+                split_param_col_tp1d(param, pg)    # colmn slice
+            elif '1.2' in mn:
+                split_param_row_tp1d(param, pg)    # row slice
+            else:
+                param.set_dist_spec(ReplicaSpec())
+
+            param.visited = True
+
+
 args = parse_args()
 if args.distplan not in ["colossalai", "pytorch"]:
...
@@ -150,7 +192,7 @@ if args.distplan == "colossalai":
     model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)

     pg = default_pg
-    # tensor_parallelize(model, pg)
+    tensor_parallelize(model, pg)
     model = gemini_zero_dpp(model, pg, args.placement)

     #optimizer
...
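One detail of the new tensor_parallelize pass worth calling out is the 'visited' flag: it prevents a parameter from being sharded twice when the same Parameter object is reachable from more than one module (for example, tied weights). The self-contained PyTorch sketch below is not from the example script; TiedToy is a made-up toy model used only to show the guard in action.

import torch
import torch.nn as nn

# Toy weight-tied model: the same Parameter appears under two module names,
# so iterating named_parameters(recurse=False) per module would see it twice.
class TiedToy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(8, 4)
        self.head = nn.Linear(4, 8, bias=False)
        self.head.weight = self.embed.weight   # weight tying

model = TiedToy()
sharded = 0
for mn, module in model.named_modules():
    for pn, param in module.named_parameters(recurse=False):
        if hasattr(param, 'visited'):   # same guard as tensor_parallelize above
            continue
        sharded += 1                    # a real pass would call split_param_*_tp1d here
        param.visited = True
print(sharded)   # 1 -- the tied weight would be sharded only once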