OpenDAS / ColossalAI / Commits

Commit e64a05b3
Authored Jan 16, 2023 by jiaruifang

    polish code

Parent: 9cba38b4

Showing 1 changed file with 22 additions and 16 deletions:
examples/language/palm/train.py (+22, -16), view file @ e64a05b3
```diff
 import gzip
 import random
-from time import time
 from functools import partial
+from time import time
 
 import numpy as np
 import torch
-import torch.optim as optim
 import torch.nn as nn
+import torch.optim as optim
 import tqdm
 from packaging import version
 from palm_pytorch import PaLM
 from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
-from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
 
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel import GeminiDDP, ZeroDDP
+from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import MultiTimer, get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
```
```diff
@@ -69,6 +69,7 @@ def parse_args():
     args = parser.parse_args()
     return args
 
+
 # helpers
 def cycle(loader):
     while True:
```
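The body of `cycle` is collapsed in this hunk. A minimal sketch of the standard infinite-loader idiom, assuming (not quoting) the collapsed lines:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def cycle(loader):
    # Restart from the top whenever the loader is exhausted, so the
    # training loop can call next() indefinitely.
    while True:
        for data in loader:
            yield data

# Tiny usage demo with a stand-in dataset:
loader = cycle(DataLoader(TensorDataset(torch.arange(10)), batch_size=4))
batch = next(loader)   # can be called as many times as the loop needs
```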
```diff
@@ -79,12 +80,15 @@ def cycle(loader):
 def decode_token(token):
     return str(chr(max(32, token)))
 
+
 def get_tflops(model_numel, batch_size, seq_len, step_time):
     return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
 
+
 def decode_tokens(tokens):
     return "".join(list(map(decode_token, tokens)))
 
+
 def get_model_size(model: nn.Module):
     total_numel = 0
     for module in model.modules():
```
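`get_tflops` estimates achieved throughput from the model size and the tokens processed per step. The constant 8 is consistent with the usual estimate of 2 FLOPs per parameter per token for the forward pass, 4 for the backward pass, and another 2 for recomputing the forward under activation checkpointing (an assumption about the constant's intent, not stated in the commit); the `1e-12` guards against division by zero. A worked example with hypothetical numbers:

```python
def get_tflops(model_numel, batch_size, seq_len, step_time):
    # 8 is consistent with 2 (forward) + 4 (backward) + 2 (recomputed forward)
    # FLOPs per parameter per token under activation checkpointing.
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

# Hypothetical 8B-parameter model, batch 8, sequence length 1024, 2.5 s/step:
print(f"{get_tflops(8e9, 8, 1024, 2.5):.1f} TFLOPS")  # -> 209.7
```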
```diff
@@ -92,6 +96,7 @@ def get_model_size(model: nn.Module):
             total_numel += p.numel()
     return total_numel
 
+
 # Gemini + ZeRO DDP
 def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
     cai_version = colossalai.__version__

@@ -115,6 +120,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
         raise NotImplemented(f"CAI version {cai_version} is not supported")
     return model
 
+
 ## Parameter Sharding Strategies for Tensor Parallelism
 def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
     spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
```
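`gemini_zero_dpp` dispatches on `colossalai.__version__`; the middle branches are collapsed in this view, which also explains why the top-level `GeminiDDP` import could be dropped: the class is only needed inside the version-gated branch. Note that the visible `raise NotImplemented(...)` actually raises a `TypeError` at runtime, because `NotImplemented` is a value rather than an exception class; `NotImplementedError` is almost certainly what was intended. A hedged sketch of the dispatch shape (the parameter spelling `placememt_policy` mirrors the source; the `GeminiDDP` keyword arguments are representative of the 0.1.x-era API, not quoted from the collapsed hunk):

```python
import colossalai
from packaging import version
from colossalai.utils import get_current_device

def gemini_zero_dpp(model, pg, placememt_policy: str = "auto"):
    cai_version = colossalai.__version__
    if version.parse(cai_version) > version.parse("0.1.10"):
        # Lazy import keeps the module importable on older ColossalAI versions.
        from colossalai.nn.parallel import GeminiDDP
        # Keyword arguments are assumptions, shown for shape only.
        model = GeminiDDP(model,
                          device=get_current_device(),
                          placement_policy=placememt_policy,
                          pin_memory=True)
    else:
        # NotImplementedError, unlike NotImplemented, is an exception class.
        raise NotImplementedError(f"CAI version {cai_version} is not supported")
    return model
```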
```diff
@@ -128,6 +134,7 @@ def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
 def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
     split_param_single_dim_tp1d(-1, param, pg)
 
+
 # Tensor Parallel
 def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
     """tensor_parallelize
```
```diff
@@ -159,7 +166,7 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
 args = parse_args()
 if args.distplan not in ["colossalai", "pytorch"]:
     raise TypeError(f"{args.distplan} is error")
 disable_existing_loggers()
 colossalai.launch_from_torch(config={})
 logger = get_dist_logger()

@@ -216,7 +223,7 @@ else:
     model.cuda()
     optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
 
 # model is shared after TP
 numel = get_model_size(model)
 get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
```
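`partial` pre-binds the `get_tflops` arguments that stay constant for the whole run (model size, batch size, sequence length), so the training loop only has to supply each step's measured duration. A self-contained illustration with hypothetical values:

```python
from functools import partial
from time import time

def get_tflops(model_numel, batch_size, seq_len, step_time):
    # Same helper as in the hunk above.
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

# Bind the run-constant arguments once (hypothetical sizes) ...
get_tflops_func = partial(get_tflops, 8e9, 8, 1024)

# ... then each step only passes its own timing:
start = time()
sum(i * i for i in range(10**6))      # stand-in for the fwd/bwd/step work
step_tflops = get_tflops_func(time() - start)
```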
```diff
@@ -251,7 +258,7 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
         )
         if i >= WARMUP_BATCHES:
             tflops_list.append(step_tflops)
     else:
         for __ in range(GRADIENT_ACCUMULATE_EVERY):
             loss = model(next(train_loader))
```
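The inner loop accumulates `GRADIENT_ACCUMULATE_EVERY` micro-batch gradients before the single `optim.step()` visible in the next hunk. A minimal runnable sketch of the pattern with toy stand-ins (the real script uses PaLM, a cycled `DataLoader`, and Gemini/Adam optimizers; whether it divides the loss by the accumulation factor is not visible here, so the scaling below is an assumption):

```python
import torch
import torch.nn as nn

# Toy stand-ins so the sketch runs on its own.
model = nn.Linear(16, 1)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

def fake_loader():                       # endless fake batches
    while True:
        yield torch.randn(8, 16)
loader = fake_loader()

GRADIENT_ACCUMULATE_EVERY = 4
for __ in range(GRADIENT_ACCUMULATE_EVERY):
    loss = model(next(loader)).pow(2).mean()        # forward on one micro-batch
    (loss / GRADIENT_ACCUMULATE_EVERY).backward()   # accumulate averaged grads (assumed scaling)
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()        # one optimizer update for the whole accumulated group
optim.zero_grad()
```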
```diff
@@ -261,18 +268,17 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
     optim.step()
     optim.zero_grad()
 
 tflops_list.sort()
 median_index = ((NUM_BATCHES - WARMUP_BATCHES) >> 1) + WARMUP_BATCHES
 logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
 
 # TODO
 # if i % VALIDATE_EVERY == 0:
 #     model.eval()
 #     with torch.no_grad():
 #         loss = model(next(val_loader))
 #         print(f"validation loss: {loss.item()}")
 # if i % GENERATE_EVERY == 0:
 #     model.eval()
```
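One caveat on the median bookkeeping above: since `tflops_list` only receives entries once `i >= WARMUP_BATCHES`, it holds `NUM_BATCHES - WARMUP_BATCHES` values, so adding `WARMUP_BATCHES` back after sorting appears to select a value above the true median and can even index out of range for large warm-ups. A standalone sketch of the intended statistic, with fake measurements:

```python
# Fake post-warm-up throughput measurements (hypothetical values).
NUM_BATCHES, WARMUP_BATCHES = 100, 1
tflops_list = [50.0 + (i % 7) for i in range(NUM_BATCHES - WARMUP_BATCHES)]

tflops_list.sort()
median_index = len(tflops_list) >> 1    # >> 1 is integer halving of the list length
print(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
```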
```diff
@@ -282,4 +288,4 @@ logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
 #     sample = model.generate(inp[None, ...], GENERATE_LENGTH)
 #     output_str = decode_tokens(sample[0])
 #     print(output_str)
\ No newline at end of file
```