OpenDAS / ColossalAI / Commits

Commit 2cdecc9f (unverified)
[example] make palm + GeminiDPP work (#2227)

Authored Dec 29, 2022 by Jiarui Fang; committed via GitHub on Dec 29, 2022.
Parent: 63cc7717

Showing 3 changed files with 41 additions and 58 deletions.
examples/language/palm/palm_pytorch/palm_pytorch.py   +3  -7
examples/language/palm/run.sh                         +1  -1
examples/language/palm/train.py                      +37 -50
examples/language/palm/palm_pytorch/palm_pytorch.py @ 2cdecc9f

 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from torch import einsum, nn, matmul
+from torch import einsum, matmul, nn

 # normalization
 # they use layernorm without bias, something that pytorch does not offer
 ...

@@ -86,8 +86,6 @@ def FeedForward(dim, mult=4):
 # attention

 class Attention(nn.Module):

     def __init__(self, dim, dim_head=64, heads=8):
 ...

@@ -142,8 +140,6 @@ class Attention(nn.Module):
         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))

         # split heads
         # they use multi-query single-key-value attention, yet another Noam Shazeer paper
         # they found no performance loss past a certain scale, and more efficient decoding obviously
 ...

@@ -165,7 +161,7 @@ class Attention(nn.Module):
         # similarity
         # sim = einsum("b h i d, b j d -> b h i j", q, k)
         sim = matmul(q.reshape(b, h * i, d), k.transpose(1, 2))
         sim = sim.reshape(b, h, i, j)

         # causal mask
 ...

@@ -183,7 +179,7 @@ class Attention(nn.Module):
         # aggregate values
         # out = einsum("b h i j, b j d -> b h i d", attn, v)
         out = matmul(attn.reshape(b_, h_ * i_, j_), v)
         out = out.reshape(b_, h_, i_, d_)

         # merge heads
 ...
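The substantive edits in this file replace the einsum attention contractions with explicit matmul plus reshape, keeping the einsum originals behind as comments, apparently because ColossalAI's Gemini path handled plain matmul more robustly than einsum at the time; the import line is merely re-sorted. A quick standalone check that the two formulations of the similarity computation agree. The shapes and names follow the diff; this is an illustration, not repository code:

    import torch
    from torch import einsum, matmul

    # b=batch, h=heads, i=query length, j=key length, d=head dim (names as in the diff)
    b, h, i, j, d = 2, 8, 16, 16, 64
    q = torch.randn(b, h, i, d)    # multi-query attention: per-head queries...
    k = torch.randn(b, j, d)       # ...but one shared key head

    sim_einsum = einsum("b h i d, b j d -> b h i j", q, k)            # commented-out original
    sim_matmul = matmul(q.reshape(b, h * i, d), k.transpose(1, 2))    # committed version
    sim_matmul = sim_matmul.reshape(b, h, i, j)

    assert torch.allclose(sim_einsum, sim_matmul, atol=1e-5)

The value-aggregation hunk is the mirror image: attn of shape (b, h, i, j) is flattened to (b, h*i, j), multiplied with the shared v of shape (b, j, d), and reshaped back to (b, h, i, d).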
examples/language/palm/run.sh @ 2cdecc9f

-env OMP_NUM_THREADS=12 torchrun --nproc_per_node 8 --master_port 29501 train.py --config palm_config.py
+env OMP_NUM_THREADS=12 torchrun --nproc_per_node 4 --master_port 29501 train.py --config palm_config.py
\ No newline at end of file
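The only change drops the launch from 8 ranks to 4. torchrun's --nproc_per_node must not exceed the number of visible GPUs, so the smaller default makes the smoke test runnable on a 4-GPU machine; adjust it to your own GPU count when reproducing.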
examples/language/palm/train.py @ 2cdecc9f

 ...
@@ -5,38 +5,36 @@ import numpy as np
 import torch
 import torch.optim as optim
 import tqdm
+from packaging import version
 from palm_pytorch import PaLM
 from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
-from packaging import version

 import colossalai
-from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+from colossalai.nn.parallel import GeminiDDP, ZeroDDP
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import MultiTimer, get_current_device
-from colossalai.nn.parallel import ZeroDDP
-from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel import GeminiDDP
-from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.utils.model.colo_init_context import ColoInitContext

 # constants

-NUM_BATCHES = int(1e5)
+NUM_BATCHES = int(20)
 BATCH_SIZE = 4
-GRADIENT_ACCUMULATE_EVERY = 4
+GRADIENT_ACCUMULATE_EVERY = 1
 LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 512
 SEQ_LEN = 1024
-TPDEGREE = 2
+TPDEGREE = 1
 USE_SHARD_INIT = False
 placement = 'cpu'

 # helpers

 def cycle(loader):
     while True:
         for data in loader:
 ...
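Taken together, the constant changes turn the script from a long training run into a quick functional test: 20 batches instead of 1e5, no gradient accumulation, and tensor parallel degree 1, which matches the commit's goal of simply getting PaLM + Gemini to run end to end. The import block is only re-sorted, with the two colossalai.nn.parallel imports merged into one line.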
@@ -50,6 +48,7 @@ def decode_token(token):
 def decode_tokens(tokens):
     return "".join(list(map(decode_token, tokens)))

 # Gemini + ZeRO DDP

 def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
     cai_version = colossalai.__version__
 ...

@@ -72,7 +71,8 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
     else:
         raise NotImplemented(f"CAI version {cai_version} is not supported")
     return model

 # instantiate GPT-like decoder model

 parser = colossalai.get_default_parser()
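Most of gemini_zero_dpp's body is elided in this view. From the visible pieces (the packaging.version import, the cai_version variable, the GeminiDDP/ZeroDDP imports, and the GeminiDDP call left as a comment later in the file) it evidently dispatches on the installed ColossalAI version and wraps the model for ZeRO data parallelism with Gemini memory management. A hedged sketch of that shape; the 0.1.10 cutoff is a guess rather than the committed value, and the misspelled placememt_policy parameter is the file's own spelling:

    import colossalai
    import torch
    from packaging import version

    from colossalai.nn.parallel import GeminiDDP
    from colossalai.tensor import ProcessGroup
    from colossalai.utils import get_current_device

    def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
        cai_version = colossalai.__version__
        if version.parse(cai_version) > version.parse("0.1.10"):
            # Kwargs mirror the commented-out GeminiDDP call later in the file.
            model = GeminiDDP(model,
                              device=get_current_device(),
                              placement_policy=placememt_policy,
                              pin_memory=True,
                              search_range_mb=32)
        else:
            # Verbatim from the visible diff line. Note that NotImplemented is a
            # constant, not an exception class, so this actually raises TypeError
            # at runtime; NotImplementedError would be the conventional choice.
            raise NotImplemented(f"CAI version {cai_version} is not supported")
        return model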
 ...
@@ -80,24 +80,15 @@ args = parser.parse_args()
 disable_existing_loggers()
 colossalai.launch_from_torch(config=args.config, seed=42)

 # instantiate GPT-like decoder model
 default_pg = ProcessGroup(tp_degree=TPDEGREE)
 default_dist_spec = ShardSpec([-1], [TPDEGREE]) if USE_SHARD_INIT else None
 ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)

-with ctx:
-    model = PaLM(num_tokens=256, dim=512, depth=8)
-
-model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
-model.cuda()
-
-# prepare enwik8 data
+# model = PaLM(num_tokens=256, dim=512, depth=8)
+with ctx:
+    model = PaLM(num_tokens=256, dim=512, depth=8)
+    # model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
+    model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
+# model.cuda()
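The point of this hunk: the model, and now the AutoregressiveWrapper around it, is constructed inside ColoInitContext, and the eager model.cuda() is gone. Inside the context every parameter is created as a ColoParameter on the context's device ('cpu' here), so Gemini can decide placement chunk by chunk later instead of materializing the whole model on one GPU up front. A standalone illustration of the pattern with a toy module; this is not repository code, and it assumes ColoInitContext accepts the device argument alone (the example also passes default_dist_spec and default_pg):

    import torch.nn as nn

    from colossalai.utils.model.colo_init_context import ColoInitContext

    with ColoInitContext(device='cpu'):
        toy = nn.Linear(512, 512)    # stand-in for PaLM

    # Parameters are ColoParameters resident on CPU until Gemini places them.
    print(type(toy.weight).__name__, toy.weight.device)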
 with gzip.open("./data/enwik8.gz") as file:
     X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
 ...
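One nit the commit leaves untouched in this context block: np.fromstring has been deprecated for binary input since NumPy 1.14 and emits a DeprecationWarning. The drop-in replacement here would be:

    import gzip

    import numpy as np

    with gzip.open("./data/enwik8.gz") as file:
        # frombuffer is the non-deprecated equivalent; .copy() makes the array
        # writable, which torch.from_numpy otherwise complains about.
        X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()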
 ...
@@ -129,46 +120,42 @@ val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
 #tensor_parallelize(model, pg)
 pg = default_pg
+# model = GeminiDDP(model,
+#                   device=get_current_device(),
+#                   placement_policy="auto",
+#                   pin_memory=True,
+#                   search_range_mb=32)
 model = gemini_zero_dpp(model, pg, placement)

 #optimizer
 optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5)
-#optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

 # training

+model.train()
+
 for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
-    model.train()
-
-    for __ in range(GRADIENT_ACCUMULATE_EVERY):
-        loss = model(next(train_loader))
-        loss.backward()
+    optimizer.zero_grad()
+    loss = model(next(train_loader))
+    # loss.backward()
+    optimizer.backward(loss)

     print(f"training loss: {loss.item()}")
     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
     # optim.step()
     # optim.zero_grad()
     optimizer.step()
-    optimizer.zero_grad()

-    if i % VALIDATE_EVERY == 0:
-        model.eval()
-        with torch.no_grad():
-            loss = model(next(val_loader))
-            print(f"validation loss: {loss.item()}")
+    # TODO
+    # if i % VALIDATE_EVERY == 0:
+    #     model.eval()
+    #     with torch.no_grad():
+    #         loss = model(next(val_loader))
+    #         print(f"validation loss: {loss.item()}")

-    if i % GENERATE_EVERY == 0:
-        model.eval()
-        inp = random.choice(val_dataset)[:-1]
-        prime = decode_tokens(inp)
-        print(f"%s \n\n %s", (prime, "*" * 100))
-
-        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
-        output_str = decode_tokens(sample[0])
-        print(output_str)
+    # if i % GENERATE_EVERY == 0:
+    #     model.eval()
+    #     inp = random.choice(val_dataset)[:-1]
+    #     prime = decode_tokens(inp)
+    #     print(f"%s \n\n %s", (prime, "*" * 100))

+    #     sample = model.generate(inp[None, ...], GENERATE_LENGTH)
+    #     output_str = decode_tokens(sample[0])
+    #     print(output_str)
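The rewritten loop is the core of the commit. Gradient accumulation goes away (GRADIENT_ACCUMULATE_EVERY is now 1), zero_grad moves to the top of the step, and loss.backward() becomes optimizer.backward(loss): ColossalAI's ZeRO/Gemini optimizer has to drive the backward pass itself so it can apply loss scaling and its chunk-based gradient handling. Validation and generation are parked behind comments with a TODO, presumably because eval and generate on the Gemini-wrapped model did not work yet. The resulting step pattern, distilled with the names as in the file:

    model.train()
    for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
        optimizer.zero_grad()
        loss = model(next(train_loader))
        optimizer.backward(loss)    # ZeRO-aware replacement for loss.backward()
        print(f"training loss: {loss.item()}")
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()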