Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
fe0f7970
"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "e094933da1d0a574eda105ab6ec0f171d8ddaebb"
Unverified
Commit
fe0f7970
authored
Jan 10, 2023
by
ZijianYY
Committed by
GitHub
Jan 10, 2023
Browse files
[examples] adding tflops to PaLM (#2365)
parent
93f62dd1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
41 additions
and
8 deletions
+41
-8
examples/language/palm/train.py
examples/language/palm/train.py
+41
-8
No files found.
examples/language/palm/train.py
View file @
fe0f7970
import
gzip
import
gzip
import
random
import
random
from
time
import
time
from
functools
import
partial
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.optim
as
optim
import
torch.optim
as
optim
import
torch.nn
as
nn
import
tqdm
import
tqdm
from
packaging
import
version
from
packaging
import
version
from
palm_pytorch
import
PaLM
from
palm_pytorch
import
PaLM
...
@@ -21,7 +23,8 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
...
@@ -21,7 +23,8 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
# constants
# constants
# Training hyper-parameters for the PaLM example.
NUM_BATCHES = 100               # total optimizer steps to run
WARMUP_BATCHES = 1              # leading steps excluded from TFLOPS statistics
GRADIENT_ACCUMULATE_EVERY = 1   # micro-batches per optimizer step (pytorch plan)
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
...
@@ -76,10 +79,18 @@ def cycle(loader):
...
@@ -76,10 +79,18 @@ def cycle(loader):
def decode_token(token):
    """Render one byte-level token id as a character.

    Ids below 32 (non-printable ASCII) are clamped to 32, i.e. a space.
    """
    code_point = token if token >= 32 else 32
    return chr(code_point)
def get_tflops(model_numel, batch_size, seq_len, step_time):
    """Estimate achieved TFLOPS for a single training step.

    Uses the approximation of 8 FLOPs per parameter per token
    (presumably forward + backward with activation recomputation —
    confirm against the caller's training setup).  The 1e-12 epsilon
    keeps the division safe when step_time is zero.
    """
    flops_per_step = 8 * model_numel * batch_size * seq_len
    return flops_per_step / 1e12 / (step_time + 1e-12)
def decode_tokens(tokens):
    """Decode an iterable of byte-level token ids into a string.

    Each id is mapped through decode_token, which renders ids below 32
    as spaces.
    """
    # str.join consumes any iterable directly; the original wrapped map()
    # in list(), building a throwaway intermediate list (ruff C417).
    return "".join(map(decode_token, tokens))
def get_model_size(model: nn.Module) -> int:
    """Return the total number of parameter elements in *model*.

    Walks every submodule and sums the element counts of its directly
    owned parameters; recurse=False prevents double counting through
    the module tree.
    """
    return sum(p.numel() for m in model.modules() for p in m.parameters(recurse=False))
# Gemini + ZeRO DDP
# Gemini + ZeRO DDP
def
gemini_zero_dpp
(
model
:
torch
.
nn
.
Module
,
pg
:
ProcessGroup
,
placememt_policy
:
str
=
"auto"
):
def
gemini_zero_dpp
(
model
:
torch
.
nn
.
Module
,
pg
:
ProcessGroup
,
placememt_policy
:
str
=
"auto"
):
...
@@ -143,7 +154,6 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
...
@@ -143,7 +154,6 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
split_param_row_tp1d
(
param
,
pg
)
# row slice
split_param_row_tp1d
(
param
,
pg
)
# row slice
else
:
else
:
param
.
set_dist_spec
(
ReplicaSpec
())
param
.
set_dist_spec
(
ReplicaSpec
())
param
.
visited
=
True
param
.
visited
=
True
...
@@ -152,6 +162,7 @@ if args.distplan not in ["colossalai", "pytorch"]:
...
@@ -152,6 +162,7 @@ if args.distplan not in ["colossalai", "pytorch"]:
# Reject unsupported distribution plans.  NOTE(review): per the diff hunk
# header this raise sits inside `if args.distplan not in ["colossalai",
# "pytorch"]:` — the guarding `if` is elided from this view.
raise TypeError(f"{args.distplan} is error")

# Silence any loggers configured before launch, then initialize the
# ColossalAI distributed runtime from torchrun-style environment variables
# (no extra config needed for this example).
disable_existing_loggers()
colossalai.launch_from_torch(config={})
logger = get_dist_logger()

# Load the training corpus: the first 95M bytes of the gzip'd enwik8
# dump, interpreted as raw uint8 byte-level tokens.
with gzip.open("./data/enwik8.gz") as file:
    # np.fromstring is deprecated (and removed in NumPy >= 1.24) for
    # binary input; np.frombuffer is the supported replacement.  The
    # explicit .copy() keeps the array writable, matching fromstring's
    # copy semantics in case downstream code mutates X.
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
...
@@ -188,7 +199,7 @@ if args.distplan == "colossalai":
...
@@ -188,7 +199,7 @@ if args.distplan == "colossalai":
# Create model parameters on CPU under ColoInitContext so they pick up
# the default distribution spec and process group for tensor parallelism.
ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)

with ctx:
    # This commit scales the example up from a toy config
    # (num_tokens=256, dim=512, depth=8) to a large model.
    model = PaLM(num_tokens=50304, dim=4096, depth=64)
    # AutoregressiveWrapper adds the language-modeling training interface
    # (presumably input/target shifting — confirm in palm_pytorch).
    model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)

# NOTE(review): indentation reconstructed from a diff view; pg binding
# follows the context block at the enclosing branch's level.
pg = default_pg
...
@@ -205,25 +216,42 @@ else:
...
@@ -205,25 +216,42 @@ else:
# (else branch, plain "pytorch" plan per the hunk header): move the model
# to GPU and optimize with vanilla Adam.
model.cuda()
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# model is shared after TP
# Count parameters once (presumably the local shard size after tensor
# parallelism — confirm) and pre-bind the static arguments of get_tflops
# so the training loop only supplies the measured step time.
numel = get_model_size(model)
get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
# training
model.train()
tflops_list = []  # per-step TFLOPS samples collected after warmup

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    if args.distplan == "colossalai":
        optimizer.zero_grad()

        # Time each phase of the step (forward / backward / optimizer)
        # separately so the log can break down where time is spent.
        start = time()
        loss = model(next(train_loader))
        fwd_end = time()
        fwd_time = fwd_end - start
        # loss.backward()
        # Gemini/ZeRO requires backward through the wrapped optimizer so
        # it can manage gradient sharding, not a plain loss.backward().
        optimizer.backward(loss)
        bwd_end = time()
        bwd_time = bwd_end - fwd_end
        # print(f"training loss: {loss.item()}")
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        # optim.step()
        # optim.zero_grad()
        optimizer.step()
        optim_time = time() - bwd_end
        step_time = time() - start
        step_tflops = get_tflops_func(step_time)

        logger.info(
            f"[{i + 1}/{NUM_BATCHES}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
            ranks=[0],
        )
        # Exclude warmup iterations from the throughput statistics.
        if i >= WARMUP_BATCHES:
            tflops_list.append(step_tflops)
    else:
        # Plain PyTorch path: accumulate gradients over several micro-batches.
        for __ in range(GRADIENT_ACCUMULATE_EVERY):
            loss = model(next(train_loader))
...
@@ -233,6 +261,11 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
...
@@ -233,6 +261,11 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
# (pytorch plan, inside the else branch per the hunk header): after
# gradient accumulation, clip, step and reset gradients.
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()
# Report the median TFLOPS over the measured (post-warmup) steps.
tflops_list.sort()
# tflops_list holds only the NUM_BATCHES - WARMUP_BATCHES post-warmup
# samples, so the median index of the *sorted* list must not be offset
# by WARMUP_BATCHES.  The original `((NUM_BATCHES - WARMUP_BATCHES) >> 1)
# + WARMUP_BATCHES` picked an above-median element and could index past
# the end of the list whenever WARMUP_BATCHES exceeds half the samples.
median_index = (NUM_BATCHES - WARMUP_BATCHES) >> 1
logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
# TODO
# TODO
# if i % VALIDATE_EVERY == 0:
# if i % VALIDATE_EVERY == 0:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment