Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
e64a05b3
Commit
e64a05b3
authored
Jan 16, 2023
by
jiaruifang
Browse files
polish code
parent
9cba38b4
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
16 deletions
+22
-16
examples/language/palm/train.py
examples/language/palm/train.py
+22
-16
No files found.
examples/language/palm/train.py
View file @
e64a05b3
import
gzip
import
random
from
time
import
time
from
functools
import
partial
from
time
import
time
import
numpy
as
np
import
torch
import
torch.optim
as
optim
import
torch.nn
as
nn
import
torch.optim
as
optim
import
tqdm
from
packaging
import
version
from
palm_pytorch
import
PaLM
from
palm_pytorch.autoregressive_wrapper
import
AutoregressiveWrapper
from
torch.nn
import
functional
as
F
from
torch.utils.data
import
DataLoader
,
Dataset
import
colossalai
from
colossalai.logging
import
disable_existing_loggers
,
get_dist_logger
from
colossalai.nn.optimizer.gemini_optimizer
import
GeminiAdamOptimizer
from
colossalai.nn.parallel
import
GeminiDDP
,
ZeroDDP
from
colossalai.nn.parallel
import
ZeroDDP
from
colossalai.tensor
import
ColoParameter
,
ComputePattern
,
ComputeSpec
,
ProcessGroup
,
ReplicaSpec
,
ShardSpec
from
colossalai.utils
import
MultiTimer
,
get_current_device
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
...
...
@@ -69,6 +69,7 @@ def parse_args():
args
=
parser
.
parse_args
()
return
args
# helpers
def
cycle
(
loader
):
while
True
:
...
...
@@ -79,12 +80,15 @@ def cycle(loader):
def decode_token(token):
    """Map a single token id to its printable character.

    Args:
        token: integer token id, interpreted as a Unicode code point.

    Returns:
        A one-character string for ``token``; ids below 32 (the ASCII
        control range) are clamped to 32 and therefore render as a space.
    """
    # chr() already returns a str, so the original str(...) wrapper was
    # redundant and has been dropped.
    return chr(max(32, token))
def get_tflops(model_numel, batch_size, seq_len, step_time):
    """Return the achieved throughput of one training step, in TFLOPS.

    Args:
        model_numel: total number of model parameters.
        batch_size: number of samples processed in the step.
        seq_len: sequence length (tokens per sample).
        step_time: wall-clock duration of the step, in seconds.

    Returns:
        Estimated tera-FLOPs per second for the step.
    """
    # NOTE(review): the 8x multiplier presumably covers forward + backward
    # (+ activation recomputation) passes — confirm against the training
    # configuration before relying on absolute numbers.
    total_flop = 8 * model_numel * batch_size * seq_len
    # The tiny epsilon guards against division by zero on a zero-length step.
    return total_flop / 1e12 / (step_time + 1e-12)
def decode_tokens(tokens):
    """Decode a sequence of token ids into a printable string.

    Args:
        tokens: iterable of integer token ids.

    Returns:
        The concatenation of ``decode_token`` applied to each id.
    """
    # str.join accepts any iterable, so the intermediate list() the
    # original built was an unnecessary allocation.
    return "".join(map(decode_token, tokens))
def
get_model_size
(
model
:
nn
.
Module
):
total_numel
=
0
for
module
in
model
.
modules
():
...
...
@@ -92,6 +96,7 @@ def get_model_size(model: nn.Module):
total_numel
+=
p
.
numel
()
return
total_numel
# Gemini + ZeRO DDP
def
gemini_zero_dpp
(
model
:
torch
.
nn
.
Module
,
pg
:
ProcessGroup
,
placememt_policy
:
str
=
"auto"
):
cai_version
=
colossalai
.
__version__
...
...
@@ -115,6 +120,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
raise
NotImplemented
(
f
"CAI version
{
cai_version
}
is not supported"
)
return
model
## Parameter Sharding Strategies for Tensor Parallelism
def
split_param_single_dim_tp1d
(
dim
:
int
,
param
:
ColoParameter
,
pg
:
ProcessGroup
):
spec
=
(
ShardSpec
([
dim
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
...
...
@@ -128,6 +134,7 @@ def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    """Shard ``param`` column-wise (along its last dimension) for 1D tensor
    parallelism.

    Thin convenience wrapper over ``split_param_single_dim_tp1d`` that fixes
    the split dimension to ``-1``.

    Args:
        param: the parameter to shard in place.
        pg: process group describing the tensor-parallel topology.
    """
    split_param_single_dim_tp1d(dim=-1, param=param, pg=pg)
# Tensor Parallel
def
tensor_parallelize
(
model
:
torch
.
nn
.
Module
,
pg
:
ProcessGroup
):
"""tensor_parallelize
...
...
@@ -216,7 +223,7 @@ else:
model
.
cuda
()
optim
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
LEARNING_RATE
)
# model is shared after TP
# model is shared after TP
numel
=
get_model_size
(
model
)
get_tflops_func
=
partial
(
get_tflops
,
numel
,
args
.
batch_size
,
SEQ_LEN
)
...
...
@@ -266,13 +273,12 @@ tflops_list.sort()
median_index
=
((
NUM_BATCHES
-
WARMUP_BATCHES
)
>>
1
)
+
WARMUP_BATCHES
logger
.
info
(
f
"Median TFLOPS is
{
tflops_list
[
median_index
]:.
3
f
}
"
)
# TODO
# if i % VALIDATE_EVERY == 0:
# model.eval()
# with torch.no_grad():
# loss = model(next(val_loader))
# print(f"validation loss: {loss.item()}")
# TODO
# if i % VALIDATE_EVERY == 0:
# model.eval()
# with torch.no_grad():
# loss = model(next(val_loader))
# print(f"validation loss: {loss.item()}")
# if i % GENERATE_EVERY == 0:
# model.eval()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment