OpenDAS / ColossalAI, commit 63cc7717 (unverified)
Authored Dec 29, 2022 by ZijianYY; committed via GitHub on Dec 29, 2022

Commit message: [example] Palm adding gemini, still has bugs (#2221)
Parent: 7010e181

Showing 4 changed files with 82 additions and 8 deletions:
- examples/language/palm/palm_config.py (+6, −0)
- examples/language/palm/palm_pytorch/palm_pytorch.py (+3, −1)
- examples/language/palm/run.sh (+1, −0)
- examples/language/palm/train.py (+72, −7)
examples/language/palm/palm_config.py (new file, mode 100644; no trailing newline)

```python
SEQ_LENGTH = 1024
BATCH_SIZE = 4
NUM_EPOCHS = 4
TPDEGREE = 2
USE_SHARD_INIT = False
placement = 'cpu'
```
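Note that train.py (below) re-declares TPDEGREE, USE_SHARD_INIT, and placement as module-level constants rather than reading them from this file, so edits here may not take effect. For reference, a hedged sketch of how values from the `--config` file are usually read back in colossalai 0.1.x/0.2.x; the `gpc` global-context API is an assumption about the version in use:

```python
# Hypothetical read-back of palm_config.py values after launch
# (requires the torchrun environment variables to be set).
import colossalai
from colossalai.core import global_context as gpc

colossalai.launch_from_torch(config="palm_config.py", seed=42)
print(gpc.config.BATCH_SIZE)    # -> 4
print(gpc.config.SEQ_LENGTH)    # -> 1024
```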
examples/language/palm/palm_pytorch/palm_pytorch.py

```diff
@@ -47,7 +47,9 @@ class RotaryEmbedding(nn.Module):
     def forward(self, max_seq_len, *, device):
         seq = torch.arange(max_seq_len, device=device)
         #freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
-        freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
+        #freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
+        i, j = len(seq.type_as(self.inv_freq)), len(self.inv_freq)
+        freqs = matmul(seq.type_as(self.inv_freq).reshape(i, 1), self.inv_freq.reshape(1, j))
         return torch.cat((freqs, freqs), dim=-1)
```
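The new lines spell the same outer product as an explicit (n, 1) × (1, m) matrix multiplication. A standalone check of that equivalence (not part of the commit):

```python
import torch

a = torch.randn(8)    # stands in for seq.type_as(self.inv_freq)
b = torch.randn(4)    # stands in for self.inv_freq

# torch.outer(a, b)[i, j] == a[i] * b[j]; the reshape/matmul form is identical.
assert torch.allclose(torch.outer(a, b),
                      torch.matmul(a.reshape(-1, 1), b.reshape(1, -1)))
```

One caveat: the added line calls a bare `matmul`, which only resolves if the module already does `from torch import matmul` alongside its existing `einsum` import; if not, it raises a NameError at runtime, which may be among the bugs the commit title concedes.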
examples/language/palm/run.sh (new file, mode 100644; no trailing newline)

```bash
env OMP_NUM_THREADS=12 torchrun --nproc_per_node 8 --master_port 29501 train.py --config palm_config.py
```
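The launcher pins each rank to 12 OpenMP threads and starts eight processes, one per GPU. A hedged variant for a smaller machine (the 2-GPU count matches TPDEGREE = 2 in palm_config.py):

```bash
# Same script on 2 GPUs; only --nproc_per_node changes.
env OMP_NUM_THREADS=12 torchrun --nproc_per_node 2 --master_port 29501 train.py --config palm_config.py
```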
examples/language/palm/train.py

```diff
@@ -9,6 +9,16 @@ from palm_pytorch import PaLM
 from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
+from packaging import version
+
+import colossalai
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
+from colossalai.utils import MultiTimer, get_current_device
+from colossalai.nn.parallel import ZeroDDP
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+from colossalai.nn.parallel import GeminiDDP
+from colossalai.logging import disable_existing_loggers, get_dist_logger

 # constants
```
```diff
@@ -20,6 +30,9 @@ VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 512
 SEQ_LEN = 1024
+TPDEGREE = 2
+USE_SHARD_INIT = False
+placement = 'cpu'

 # helpers
```
```diff
@@ -37,16 +50,55 @@ def decode_token(token):
 def decode_tokens(tokens):
     return "".join(list(map(decode_token, tokens)))

+
+# Gemini + ZeRO DDP
+def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+    cai_version = colossalai.__version__
+    if version.parse(cai_version) > version.parse("0.1.10"):
+        from colossalai.nn.parallel import GeminiDDP
+        model = GeminiDDP(model,
+                          device=get_current_device(),
+                          placement_policy=placememt_policy,
+                          pin_memory=True,
+                          search_range_mb=32)
+    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
+        from colossalai.gemini import ChunkManager, GeminiManager
+        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
+        gemini_manager = GeminiManager(placememt_policy, chunk_manager)
+        chunk_manager = ChunkManager(chunk_size,
+                                     pg,
+                                     enable_distributed_storage=True,
+                                     init_device=GeminiManager.get_default_device(placememt_policy))
+        model = ZeroDDP(model, gemini_manager)
+    else:
+        raise NotImplemented(f"CAI version {cai_version} is not supported")
+    return model
```
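Two of the bugs the commit title admits to are visible in the 0.1.9–0.1.10 branch: `gemini_manager` is constructed from `chunk_manager` one line before `chunk_manager` is defined (a NameError at runtime), and `raise NotImplemented(...)` raises the `NotImplemented` constant instead of `NotImplementedError`. A corrected ordering of that branch, keeping the commit's own API calls, might look like this (a sketch only):

```python
# Sketch of the 0.1.9-0.1.10 branch with definitions in dependency order.
from colossalai.gemini import ChunkManager, GeminiManager

chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
# Build the chunk manager first, then hand it to the Gemini manager.
chunk_manager = ChunkManager(chunk_size,
                             pg,
                             enable_distributed_storage=True,
                             init_device=GeminiManager.get_default_device(placememt_policy))
gemini_manager = GeminiManager(placememt_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)
```

The fallback should likewise read `raise NotImplementedError(f"CAI version {cai_version} is not supported")`, and the misspelled `placememt_policy` parameter is harmless but worth renaming.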
The same hunk continues with the ColossalAI launcher and the model now built under `ColoInitContext`:

```diff
+# instantiate GPT-like decoder model
+parser = colossalai.get_default_parser()
+args = parser.parse_args()
+disable_existing_loggers()
+colossalai.launch_from_torch(config=args.config, seed=42)
+
 # instantiate GPT-like decoder model
-model = PaLM(num_tokens=256, dim=512, depth=8)
-
-model = AutoregressiveWrapper(model, max_seq_len=2048)
-model.cuda()
+default_pg = ProcessGroup(tp_degree=TPDEGREE)
+default_dist_spec = ShardSpec([-1], [TPDEGREE]) if USE_SHARD_INIT else None
+ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
+with ctx:
+    model = PaLM(num_tokens=256, dim=512, depth=8)
+model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
+model.cuda()

 # prepare enwik8 data
+# model = PaLM(num_tokens=256, dim=512, depth=8)
+# model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
+# model.cuda()

 with gzip.open("./data/enwik8.gz") as file:
     X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
     trX, vaX = np.split(X, [int(90e6)])
```
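One incidental issue in the context lines: `np.fromstring` has been deprecated since NumPy 1.14 and rejects binary input in newer releases. The drop-in replacement for reading raw bytes is `np.frombuffer`:

```python
# Same enwik8 load without the deprecated call; frombuffer returns a
# read-only view over the bytes, which is fine for np.split.
import gzip
import numpy as np

with gzip.open("./data/enwik8.gz") as file:
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
```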
@@ -74,9 +126,20 @@ val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
...
@@ -74,9 +126,20 @@ val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader
=
cycle
(
DataLoader
(
train_dataset
,
batch_size
=
BATCH_SIZE
))
train_loader
=
cycle
(
DataLoader
(
train_dataset
,
batch_size
=
BATCH_SIZE
))
val_loader
=
cycle
(
DataLoader
(
val_dataset
,
batch_size
=
BATCH_SIZE
))
val_loader
=
cycle
(
DataLoader
(
val_dataset
,
batch_size
=
BATCH_SIZE
))
# optimizer
#tensor_parallelize(model, pg)
pg
=
default_pg
# model = GeminiDDP(model,
# device=get_current_device(),
# placement_policy="auto",
# pin_memory=True,
# search_range_mb=32)
model
=
gemini_zero_dpp
(
model
,
pg
,
placement
)
#optimizer
optim
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
LEARNING_RATE
)
optimizer
=
GeminiAdamOptimizer
(
model
,
lr
=
1e-7
,
initial_scale
=
2
**
5
)
#optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# training
# training
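Note that `GeminiAdamOptimizer` is handed the wrapped `model` itself rather than `model.parameters()`, since it manages ZeRO's sharded optimizer states, and `initial_scale=2**5` seeds the fp16 loss scaler at 32. The new call also hard-codes `lr=1e-7` instead of reusing the script's `LEARNING_RATE` constant; a hedged sketch that keeps the original rate:

```python
# Reuse the script's LEARNING_RATE rather than a hard-coded 1e-7.
optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5)
```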
```diff
@@ -89,8 +152,10 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
     print(f"training loss: {loss.item()}")
     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

-    optim.step()
-    optim.zero_grad()
+    # optim.step()
+    # optim.zero_grad()
+    optimizer.step()
+    optimizer.zero_grad()

     if i % VALIDATE_EVERY == 0:
         model.eval()
```
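The diff leaves the backward pass untouched, which is a likely source of the bugs the title mentions: with Gemini/ZeRO-wrapped models, ColossalAI's examples route the backward pass through the wrapped optimizer so it can manage gradient chunks and loss scaling. A hedged sketch of the usual pattern, assuming the script's existing `GRADIENT_ACCUMULATE_EVERY` inner loop:

```python
# Typical ColossalAI Gemini training step: backward goes through the
# optimizer, not loss.backward(), so ZeRO can handle scaling and chunks.
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    model.train()
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        optimizer.backward(loss)   # instead of loss.backward()
    print(f"training loss: {loss.item()}")
    optimizer.step()
    optimizer.zero_grad()
```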