Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
e58cc441
Commit
e58cc441
authored
Jan 18, 2023
by
jiaruifang
Browse files
polish code and fix dataloader bugs
parent
a4b75b78
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
65 deletions
+35
-65
examples/language/gpt/titans/dataset/webtext.py
examples/language/gpt/titans/dataset/webtext.py
+23
-19
examples/language/gpt/titans/run.sh
examples/language/gpt/titans/run.sh
+2
-1
examples/language/gpt/titans/train_gpt.py
examples/language/gpt/titans/train_gpt.py
+10
-45
No files found.
examples/language/gpt/titans/dataset/webtext.py
View file @
e58cc441
import
json
import
json
import
os
import
os
from
typing
import
Optional
import
torch
import
torch
from
torch.utils.data
import
Dataset
from
torch.utils.data
import
Dataset
...
@@ -11,26 +12,29 @@ from colossalai.registry import DATASETS
...
@@ -11,26 +12,29 @@ from colossalai.registry import DATASETS
@DATASETS.register_module
class WebtextDataset(Dataset):
    """Webtext corpus tokenized with GPT-2 BPE for language-model pretraining.

    Each line of the json file at ``path`` is expected to be a json object
    with a ``'text'`` field.  Encoded tensors are cached next to the raw
    file so tokenization only runs once per ``(file, seq_len)`` pair.
    When ``path`` is ``None``, a random dummy dataset is generated instead,
    which lets the training script run without any data on disk.
    """

    def __init__(self, path: Optional[str] = None, seq_len: int = 1024) -> None:
        """Load (or build and cache) the encoded dataset.

        Args:
            path: json file with one ``{'text': ...}`` object per line, or
                ``None`` to generate random dummy token sequences.
            seq_len: maximum sequence length used for truncation/padding.
        """
        super().__init__()
        if path is not None:
            root = os.path.dirname(path)
            encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
            if os.path.isfile(encoded_data_cache_path):
                seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
                # Only reuse the cache if it was built for the same seq_len.
                if seq_len_ == seq_len:
                    self.data = data
                    self.attention_mask = attention_mask
                    return
            # Iterate the file directly instead of readlines() + append loop:
            # same order, no intermediate list of raw lines.
            with open(path) as f:
                raw_data = [json.loads(line)['text'] for line in f]
            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            # GPT-2 has no pad token; reuse unk so batched padding works.
            tokenizer.pad_token = tokenizer.unk_token
            encoded_data = tokenizer(raw_data,
                                     padding=True,
                                     truncation=True,
                                     max_length=seq_len,
                                     return_tensors='pt')
            self.data = encoded_data['input_ids']
            self.attention_mask = encoded_data['attention_mask']
            # Persist seq_len alongside the tensors for the cache check above.
            torch.save((seq_len, self.data, self.attention_mask), encoded_data_cache_path)
        else:
            # Dummy mode: 10240 random sequences over the GPT-2 vocab (50257).
            self.data = torch.randint(0, 50257, (10240, seq_len))
            self.attention_mask = torch.ones_like(self.data)

    def __len__(self):
        return len(self.data)
...
...
examples/language/gpt/titans/run.sh
View file @
e58cc441
# Location of the raw webtext corpus (one json document per line).
export DATA=/data/scratch/gpt_data/small-gpt-dataset.json

# Set to --use_dummy_dataset to train on randomly generated tokens
# instead of reading $DATA from disk.
DUMMY_DATA=--use_dummy_dataset

colossalai run --nproc_per_node=2 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch $DUMMY_DATA
examples/language/gpt/titans/train_gpt.py
View file @
e58cc441
...
@@ -3,6 +3,7 @@ import os
...
@@ -3,6 +3,7 @@ import os
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
dataset.webtext
import
WebtextDataset
from
titans.model.gpt
import
GPTLMLoss
from
titans.model.gpt
import
GPTLMLoss
import
colossalai
import
colossalai
...
@@ -39,52 +40,16 @@ def main():
...
@@ -39,52 +40,16 @@ def main():
colossalai
.
launch_from_slurm
(
config
=
args
.
config
,
host
=
args
.
host
,
port
=
29500
,
seed
=
42
)
colossalai
.
launch_from_slurm
(
config
=
args
.
config
,
host
=
args
.
host
,
port
=
29500
,
seed
=
42
)
logger
=
get_dist_logger
()
logger
=
get_dist_logger
()
if
not
args
.
use_dummy_dataset
:
data_path
=
None
if
args
.
use_dummy_dataset
else
os
.
environ
[
'DATA'
]
data_path
=
os
.
environ
[
'DATA'
]
logger
.
info
(
f
'Build data loader from path
{
data_path
}
'
,
ranks
=
[
0
])
logger
.
info
(
f
'Build data loader from path
{
data_path
}
'
,
ranks
=
[
0
])
from
dataset.webtext
import
WebtextDataset
train_ds
=
WebtextDataset
(
os
.
environ
[
'DATA'
],
seq_len
=
gpc
.
config
.
SEQ_LEN
)
train_dataloader
=
utils
.
get_dataloader
(
train_ds
,
seed
=
42
,
batch_size
=
gpc
.
config
.
BATCH_SIZE
,
pin_memory
=
True
,
shuffle
=
True
,
drop_last
=
True
)
else
:
# build a dummy train_dataloader
logger
.
info
(
'Build data loader using dummy data'
,
ranks
=
[
0
])
def
get_data
(
batch_size
,
seq_len
,
vocab_size
):
input_ids
=
torch
.
randint
(
0
,
vocab_size
,
(
batch_size
,
seq_len
),
device
=
torch
.
cuda
.
current_device
())
attention_mask
=
torch
.
ones_like
(
input_ids
)
return
input_ids
,
attention_mask
# 10 iterations
input_ids
,
attn_mask
=
get_data
(
gpc
.
config
.
BATCH_SIZE
*
10
,
gpc
.
config
.
SEQ_LEN
,
VOCAB_SIZE
)
from
torch.utils.data
import
DataLoader
,
Dataset
class
TextSamplerDataset
(
Dataset
):
def
__init__
(
self
,
data
,
seq_len
):
super
().
__init__
()
self
.
data
=
data
self
.
seq_len
=
seq_len
def
__getitem__
(
self
,
index
):
rand_start
=
torch
.
randint
(
0
,
self
.
data
.
size
(
0
)
-
self
.
seq_len
,
(
1
,))
full_seq
=
self
.
data
[
rand_start
:
rand_start
+
self
.
seq_len
+
1
].
long
()
return
full_seq
.
cuda
()
def
__len__
(
self
):
return
self
.
data
.
size
(
0
)
//
self
.
seq_len
def
cycle
(
loader
):
while
True
:
for
data
in
loader
:
yield
data
train_dataset
=
TextSamplerDataset
(
input_ids
,
gpc
.
config
.
SEQ_LEN
)
train_ds
=
WebtextDataset
(
path
=
data_path
,
seq_len
=
gpc
.
config
.
SEQ_LEN
)
train_dataloader
=
DataLoader
(
train_dataset
,
batch_size
=
gpc
.
config
.
BATCH_SIZE
)
train_dataloader
=
utils
.
get_dataloader
(
train_ds
,
seed
=
42
,
batch_size
=
gpc
.
config
.
BATCH_SIZE
,
pin_memory
=
True
,
shuffle
=
True
,
drop_last
=
True
)
logger
.
info
(
'Build model'
,
ranks
=
[
0
])
logger
.
info
(
'Build model'
,
ranks
=
[
0
])
use_pipeline
=
is_using_pp
()
use_pipeline
=
is_using_pp
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment