OpenDAS / ColossalAI / Commits

Unverified commit 74aa7d96, authored Sep 24, 2023 by Tong Li, committed by GitHub on Sep 24, 2023.

    initial commit: add colossal llama 2 (#4784)

Parent: 4146f1c0

Showing 19 changed files with 2162 additions and 2 deletions (+2162 / -2)
applications/Colossal-LLaMA-2/README.md                                                   +377  -0
applications/Colossal-LLaMA-2/colossal_llama2/__init__.py                                 +2    -0
applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py                         +2    -0
applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py                           +219  -0
applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py    +183  -0
applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py                         +111  -0
applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py                 +98   -0
applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py                           +2    -0
applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py                            +88   -0
applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py              +216  -0
applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py                              +18   -0
applications/Colossal-LLaMA-2/docs/example.md                                             +245  -0
applications/Colossal-LLaMA-2/hostfile.example                                            +2    -0
applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py                                 +153  -0
applications/Colossal-LLaMA-2/requirements.txt                                            +15   -0
applications/Colossal-LLaMA-2/train.example.sh                                            +44   -0
applications/Colossal-LLaMA-2/train.py                                                    +383  -0
applications/Colossal-LLaMA-2/version.txt                                                 +1    -0
applications/README.md                                                                    +3    -2
applications/Colossal-LLaMA-2/README.md  (new file, 0 → 100644)
(Diff collapsed; file contents not shown.)
applications/Colossal-LLaMA-2/colossal_llama2/__init__.py  (new file, 0 → 100644)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py  (new file, 0 → 100644)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py  (new file, 0 → 100644)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import os
import random
from dataclasses import dataclass
from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable

import torch
from datasets import dataset_dict, load_from_disk
from datasets import Dataset as HFDataset
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
from transformers.tokenization_utils import PreTrainedTokenizer
import torch.nn.functional as F

DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
PathType = Union[str, os.PathLike]


def load_tokenized_dataset(
    dataset_paths: Union[PathType, List[PathType]], mode: str = "train"
) -> Optional[DatasetType]:
    """
    Load pre-tokenized dataset.
    Each instance of dataset is a dictionary with
    `{'input_ids': List[int], 'labels': List[int], sequence: str}` format.
    """
    mode_map = {"train": "train", "dev": "validation", "test": "test"}
    assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"

    if isinstance(dataset_paths, (str, os.PathLike)):
        dataset_paths = [dataset_paths]

    datasets = []  # `List[datasets.dataset_dict.Dataset]`
    for ds_path in dataset_paths:
        ds_path = os.path.abspath(ds_path)
        assert os.path.exists(ds_path), f"Not existed file path {ds_path}"
        ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False)
        if isinstance(ds_dict, HFDataset):
            datasets.append(ds_dict)
        else:
            if mode_map[mode] in ds_dict:
                datasets.append(ds_dict[mode_map[mode]])
    if len(datasets) == 0:
        return None
    if len(datasets) == 1:
        return datasets.pop()
    return ConcatDataset(datasets=datasets)


@dataclass
class DataCollatorForSupervisedDataset(object):
    """
    Collate instances for supervised dataset.
    Each instance is a tokenized dictionary with fields
    `input_ids`(List[int]), `labels`(List[int]) and `sequence`(str).
    """

    tokenizer: PreTrainedTokenizer
    max_length: int = 4096
    ignore_index: int = -100

    def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        """
        Args:
            instances (`Sequence[Dict[str, List[int]]]`):
                Mini-batch samples, each sample is stored in an individual dictionary.

        Returns:
            (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
                `input_ids`: `torch.Tensor` of shape (bsz, max_len);
                `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
                `labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`.
        """
        assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
            f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
            f"but now `{self.tokenizer.pad_token_id}`"
        )

        # `List[torch.Tensor]`
        batch_input_ids = [
            torch.LongTensor(instance["input_ids"][: self.max_length])
            if len(instance["input_ids"]) > self.max_length
            else torch.LongTensor(instance["input_ids"])
            for instance in instances
        ]
        batch_labels = [
            torch.LongTensor(instance["labels"][: self.max_length])
            if len(instance["labels"]) > self.max_length
            else torch.LongTensor(instance["labels"])
            for instance in instances
        ]

        if self.tokenizer.padding_side == "right":
            input_ids = torch.nn.utils.rnn.pad_sequence(
                sequences=batch_input_ids,
                batch_first=True,
                padding_value=self.tokenizer.pad_token_id,
            )  # (bsz, max_len)
            labels = torch.nn.utils.rnn.pad_sequence(
                sequences=batch_labels,
                batch_first=True,
                padding_value=self.ignore_index,
            )  # (bsz, max_len)
            # pad to max
            to_pad = self.max_length - input_ids.size(1)
            input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
            labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
        elif self.tokenizer.padding_side == "left":
            reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
            reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
                sequences=reversed_input_ids,
                batch_first=True,
                padding_value=self.tokenizer.pad_token_id,
            )  # (bsz, max_len)
            input_ids = torch.flip(reversed_input_ids, dims=(1,))  # (bsz, max_len)
            reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels]
            reversed_labels = torch.nn.utils.rnn.pad_sequence(
                sequences=reversed_labels,
                batch_first=True,
                padding_value=self.ignore_index,
            )  # (bsz, max_len)
            labels = torch.flip(reversed_labels, dims=(1,))  # (bsz, max_len)
        else:
            raise RuntimeError(
                f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, "
                f"but now `{self.tokenizer.padding_side}`"
            )

        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)  # `torch.BoolTensor`, (bsz, max_len)

        return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels)


class StatefulDistributedSampler(DistributedSampler):
    """
    Stateful distributed sampler for multi-stage training.
    """

    def __init__(
        self,
        dataset: DatasetType,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        super().__init__(
            dataset=dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
        )
        self.start_index = 0

    def __iter__(self) -> Iterator:
        iterator = super().__iter__()
        indices = list(iterator)
        indices = indices[self.start_index :]
        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples - self.start_index

    def set_start_index(self, start_index: int) -> None:
        self.start_index = start_index


def setup_distributed_dataloader(
    dataset: DatasetType,
    batch_size: int = 1,
    shuffle: bool = False,
    seed: int = 1024,
    drop_last: bool = False,
    pin_memory: bool = False,
    num_workers: int = 0,
    collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
    process_group: Optional[ProcessGroup] = None,
    **kwargs,
) -> DataLoader:
    """
    Setup dataloader for distributed training.
    """
    _kwargs = kwargs.copy()
    process_group = process_group or _get_default_group()
    sampler = StatefulDistributedSampler(
        dataset=dataset,
        num_replicas=process_group.size(),
        rank=process_group.rank(),
        shuffle=shuffle,
        seed=seed,
        drop_last=drop_last,
    )

    # Deterministic dataloader
    def seed_worker(worker_id: int) -> None:
        worker_seed = seed
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=drop_last,
        worker_init_fn=seed_worker,
        **_kwargs,
    )
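A minimal usage sketch (not part of the commit) of how these loader utilities wire together, assuming a pre-tokenized dataset already saved to disk by `prepare_pretrain_dataset.py`; the dataset path and tokenizer repo below are illustrative placeholders.

# Sketch only: requires torch.distributed to be initialized (e.g. via colossalai.launch_from_torch).
from transformers import LlamaTokenizer
from colossal_llama2.dataset.loader import (
    load_tokenized_dataset,
    setup_distributed_dataloader,
    DataCollatorForSupervisedDataset,
)

tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")  # placeholder tokenizer
tokenizer.pad_token = tokenizer.unk_token  # collator requires a valid non-negative pad_token_id

dataset = load_tokenized_dataset(dataset_paths=["./arrow_output/part-00000"], mode="train")  # placeholder path
collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=4096)
dataloader = setup_distributed_dataloader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=collator)
for batch in dataloader:
    # Each batch is padded/truncated to max_length: (bsz, 4096) tensors.
    print(batch["input_ids"].shape, batch["attention_mask"].shape, batch["labels"].shape)
    break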
applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py  (new file, 0 → 100644)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Splicing multiple pre-tokenized sequence data points
"""

import random
import warnings
from copy import deepcopy
from datasets import dataset_dict
from typing import Any, Callable, Dict, Iterable, List, Union, Tuple

from torch.utils.data import ConcatDataset, Dataset, IterableDataset
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer

IGNORE_INDEX = -100

DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]


def supervised_tokenize(
    data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096
) -> Dict[str, Union[int, str, List[int]]]:
    """
    A tokenization function to tokenize an original pretraining data point as following:
        {"source": "", "target": "Beijing, the capital of the People's Republic of China, ...", "category": "geography"}
    """
    assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, (
        "Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, "
        "add <bos> and <eos> manually later"
    )
    if ignore_index is None:
        ignore_index = IGNORE_INDEX

    source_text = data_point["source"]  # `str`
    target_text = data_point["target"]  # `str`
    is_null_source = len(source_text) == 0

    source_text = tokenizer.bos_token + source_text
    target_text += tokenizer.eos_token
    sequence_text = source_text + target_text

    tokenized = tokenizer([source_text, sequence_text])["input_ids"]
    sequence_input_ids = tokenized[1]
    sequence_labels = deepcopy(sequence_input_ids)

    source_length = len(tokenized[0])
    if not is_null_source:
        sequence_labels[:source_length] = [ignore_index for _ in range(source_length)]

    # sequence truncation.
    if len(sequence_input_ids) > max_length:
        sequence_input_ids = sequence_input_ids[:max_length]
        sequence_labels = sequence_labels[:max_length]

    return dict(
        input_ids=sequence_input_ids,
        labels=sequence_labels,
        seq_length=len(sequence_input_ids),
        seq_category=data_point["category"],
    )


class ClosedToConstantLengthSplicedDataset(IterableDataset):
    """
    Define an iterable dataset that returns a (close to) constant length data point spliced from multiple
    original independent (pre-tokenized) data points.
    """

    def __init__(
        self,
        dataset: DSType,
        tokenizer: PreTrainedTokenizer,
        max_length: int = 4096,
        num_packed_sequences: int = 8,
        fetch_sequence_func: Callable[[Any], Tuple[List[int], List[int]]] = None,
        input_ids_field: str = "input_ids",
        labels_field: str = "labels",
        infinite: bool = False,
        shuffle: bool = True,
        error_strict: bool = False,
    ) -> None:
        self.tokenizer = tokenizer
        self.dataset = dataset
        self.max_length = max_length
        self.infinite = infinite
        self.max_buffer_size = max_length * num_packed_sequences  # e.g., 4096 * 16
        self.shuffle = shuffle

        # Callable[[Dict[str, Any]], Tuple[List[int], List[int]]],
        # A function that fetch sequence input_ids and labels from the original data point
        if fetch_sequence_func is None:
            self.fetch_sequence_func = lambda data_point: (data_point[input_ids_field], data_point[labels_field])
        else:
            self.fetch_sequence_func = fetch_sequence_func
        self.input_ids_field = input_ids_field
        self.labels_field = labels_field

        self.error_strict = error_strict
        self.current_size = 0  # `int`, current packed data size.

    def __len__(self) -> int:
        return len(self.dataset)

    def __iter__(self) -> Iterable[Dict[str, List[int]]]:
        iterator = iter(self.dataset)
        more_data_points = True
        while more_data_points is True:
            buffer, buffer_len = [], 0
            while True:
                # ending condition.
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    # `Tuple[List[int], List[int]]`
                    seq_input_ids, seq_labels = self.fetch_sequence_func(next(iterator))
                    buffer.append({self.input_ids_field: seq_input_ids, self.labels_field: seq_labels})
                    buffer_len += len(buffer[-1][self.input_ids_field])
                except StopIteration:
                    if self.infinite is True:
                        iterator = iter(self.dataset)
                        warnings.warn("The dataset reached end and the iterator is reset to the start.")
                    else:
                        more_data_points = False
                        break
            examples = []  # `List[Dict[str, List[int]]]`, save buffered spliced data points.
            spliced_input_ids, spliced_labels = [], []  # `List[int]`, `List[int]`
            for i, data_point in enumerate(buffer):
                # TODO(2023-09-18) check errors for each unspliced tokenized data point
                seq_input_ids = data_point[self.input_ids_field]
                seq_labels = data_point[self.labels_field]
                # Handle special case:
                # If the length of an original data point (i.e., input_ids length of a data point before splicing)
                # exceeds `max_length`, truncate it.
                if len(seq_input_ids) > self.max_length:
                    truncated_seq_input_ids = seq_input_ids[: self.max_length]
                    truncated_label_ids = seq_labels[: self.max_length]
                    if set(truncated_label_ids) == {IGNORE_INDEX}:
                        if self.error_strict is True:
                            raise ValueError(
                                f"Find an out-of-bounds length({len(seq_input_ids)}) data point "
                                f"with all label values as {IGNORE_INDEX}."
                            )
                        else:
                            warnings.warn(f"Filter an error truncated data point (labels all {IGNORE_INDEX})")
                            continue  # Skip the current error data point.
                    spliced_data_point = {
                        self.input_ids_field: truncated_seq_input_ids,
                        self.labels_field: truncated_label_ids,
                    }
                    examples.append(spliced_data_point)
                    warnings.warn("Find a data point to be truncated.")
                    continue

                # Pre action judgment.
                if len(spliced_input_ids) + len(seq_input_ids) > self.max_length:
                    spliced_data_point = {
                        self.input_ids_field: spliced_input_ids,
                        self.labels_field: spliced_labels,
                    }  # `Dict[str, List[int]]`
                    # Update.
                    spliced_input_ids, spliced_labels = [], []
                    spliced_input_ids.extend(seq_input_ids)
                    spliced_labels.extend(seq_labels)
                    examples.append(spliced_data_point)
                else:
                    spliced_input_ids.extend(seq_input_ids)
                    spliced_labels.extend(seq_labels)
            # For residual spliced data point at the end of the data set
            if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
                examples.append({self.input_ids_field: spliced_input_ids, self.labels_field: spliced_labels})
            if self.shuffle:
                random.shuffle(examples)
            for spliced_data_point in examples:
                # TODO(2023-09-18): check errors for each spliced tokenized data point.
                self.current_size += 1
                yield spliced_data_point
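A small sketch (not part of the commit) of how `supervised_tokenize` behaves on one raw record, assuming a LLaMA tokenizer with `add_bos_token`/`add_eos_token` disabled as the assertion requires; the tokenizer repo is an illustrative placeholder.

# Sketch only: tokenize one {"source", "target", "category"} record.
from transformers import LlamaTokenizer
from colossal_llama2.dataset.spliced_and_tokenized_dataset import supervised_tokenize

tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")  # placeholder tokenizer
tokenizer.add_bos_token = False  # required by the assertion in supervised_tokenize
tokenizer.add_eos_token = False

point = {"source": "", "target": "Beijing is the capital of China.", "category": "geography"}
out = supervised_tokenize(point, tokenizer=tokenizer, max_length=4096)
# With an empty source (pre-training case) labels equal input_ids;
# otherwise the source span is masked with IGNORE_INDEX (-100).
print(out["seq_length"], out["seq_category"])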
applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py  (new file, 0 → 100644)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Initialize new model with updated tokenizer by calculating the mean values from original model
"""
import argparse

import numpy as np
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

from colossalai.logging import get_dist_logger

logger = get_dist_logger()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--source_model_and_tokenizer_path",
        type=str,
        required=True,
        default=None,
        help="Source path of model & tokenizer",
    )
    parser.add_argument("--target_tokenizer_path", type=str, required=True, default=None, help="Target tokenizer path")
    parser.add_argument("--target_model_path", type=str, required=True, default=None, help="Target model path")
    args = parser.parse_args()

    source_tokenizer = LlamaTokenizer.from_pretrained(args.source_model_and_tokenizer_path)
    source_tokenizer.add_bos_token = False
    source_tokenizer.add_eos_token = False
    if source_tokenizer.pad_token is None:
        source_tokenizer.pad_token = source_tokenizer.unk_token
    source_vocab = source_tokenizer.get_vocab()

    target_tokenizer = LlamaTokenizer.from_pretrained(args.target_tokenizer_path)
    target_tokenizer.add_bos_token = False
    target_tokenizer.add_eos_token = False
    if target_tokenizer.pad_token is None:
        target_tokenizer.pad_token = target_tokenizer.unk_token
    target_vocab = target_tokenizer.get_vocab()
    target_inverted_vocab = {v: k for k, v in target_vocab.items()}

    assert len(target_vocab) > len(
        source_vocab
    ), f"Target vocab size({len(target_vocab)}) must be greater than source vocab size({len(source_vocab)})"

    gpu_device = torch.device("cuda:0")
    cpu_device = torch.device("cpu")

    source_model = LlamaForCausalLM.from_pretrained(args.source_model_and_tokenizer_path)
    source_model.eval()
    source_model = source_model.to(gpu_device)

    source_input_embeddings = source_model.get_input_embeddings()
    assert isinstance(source_input_embeddings, torch.nn.Embedding)
    assert source_input_embeddings.weight.shape[0] == len(source_vocab)
    source_input_embeddings.eval()

    source_output_embeddings = source_model.get_output_embeddings()
    assert isinstance(source_output_embeddings, torch.nn.Linear)
    assert source_output_embeddings.bias is None
    assert source_output_embeddings.weight.shape[0] == len(source_vocab)
    source_output_embeddings.eval()

    input_embeddings = source_input_embeddings.weight.cpu().detach().numpy()
    output_embeddings = source_output_embeddings.weight.cpu().detach().numpy()
    for i in range(len(source_vocab), len(target_vocab)):
        if i % 500 == 0:
            logger.info(f"processing {i}/{len(target_vocab)} target tokens")
        target_token = target_inverted_vocab[i]
        target_to_source_token_ids = torch.LongTensor(source_tokenizer([target_token])["input_ids"][0])
        target_to_source_token_ids = target_to_source_token_ids.to(gpu_device)

        target_to_source_input_embedding = (
            source_input_embeddings.weight[target_to_source_token_ids]
            .mean(dim=0)
            .unsqueeze(dim=0)
            .cpu()
            .detach()
            .numpy()
        )
        target_to_source_output_embedding = (
            source_output_embeddings.weight[target_to_source_token_ids]
            .mean(dim=0)
            .unsqueeze(dim=0)
            .cpu()
            .detach()
            .numpy()
        )

        input_embeddings = np.concatenate((input_embeddings, target_to_source_input_embedding), axis=0)
        output_embeddings = np.concatenate((output_embeddings, target_to_source_output_embedding), axis=0)

    source_model = source_model.to(cpu_device)
    assert isinstance(source_model, LlamaForCausalLM)

    # expand
    source_model.resize_token_embeddings(new_num_tokens=len(target_vocab))
    source_model.model.embed_tokens.weight.data = torch.Tensor(input_embeddings)
    source_model.lm_head.weight.data = torch.Tensor(output_embeddings)

    source_model = source_model.half()
    source_model.save_pretrained(save_directory=args.target_model_path)


if __name__ == "__main__":
    main()
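The script above is meant to be run once, after the expanded tokenizer has been built with `init_tokenizer.py`. A possible invocation (the paths are placeholders, only the flag names come from the argparse definition above):

# Hypothetical example; run on a machine with a CUDA device (the script pins cuda:0).
python colossal_llama2/model/init_model.py \
    --source_model_and_tokenizer_path /path/to/llama-2-7b-hf \
    --target_tokenizer_path /path/to/expanded_tokenizer \
    --target_model_path /path/to/expanded_model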
applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py  (new file, 0 → 100644)

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
Initialize new tokenizer for continual pre-training
"""
import argparse
import os
import json
from typing import List, Union

from transformers.models.llama.tokenization_llama import LlamaTokenizer
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model

from colossalai.logging import get_dist_logger

os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

logger = get_dist_logger()


def expand_vocab_tokenizer(
    source_tokenizer_dir: Union[str, os.PathLike], target_tokenizer_dir: Union[str, os.PathLike], new_tokens: List[str]
) -> None:
    """Expand tokenizer for continue pre-training."""
    if os.path.exists(target_tokenizer_dir):
        raise RuntimeError(f"Find existed directory {target_tokenizer_dir}")

    source_tokenizer = LlamaTokenizer.from_pretrained(source_tokenizer_dir)
    logger.info(source_tokenizer)
    source_sp_processor = source_tokenizer.sp_model
    source_spm = sp_pb2_model.ModelProto()
    source_spm.ParseFromString(source_sp_processor.serialized_model_proto())

    logger.info(f"Source tokenizer size: {len(source_sp_processor)}")

    # Add new tokens to source tokenizer.
    source_spm_tokens = set([p.piece for p in source_spm.pieces])
    for piece in new_tokens:
        assert isinstance(piece, str), f"Invalid token({piece}) type {type(piece)}"
        if piece in source_spm_tokens:
            # Skip existed token.
            continue
        new_p = sp_pb2_model.ModelProto().SentencePiece()
        new_p.piece = piece
        new_p.score = 0
        source_spm.pieces.append(new_p)
    logger.info(f"Expand vocab from {len(source_spm_tokens)} to {len(source_spm.pieces)}")

    # Save
    os.makedirs(target_tokenizer_dir)
    target_tokenizer_model_path = os.path.join(target_tokenizer_dir, "tokenizer.model")
    with open(file=target_tokenizer_model_path, mode="wb") as fp:
        fp.write(source_spm.SerializeToString())

    target_tokenizer = LlamaTokenizer(vocab_file=target_tokenizer_model_path)
    target_tokenizer.save_pretrained(save_directory=target_tokenizer_dir)
    logger.info(f"Successfully save expand tokenizer to {target_tokenizer_dir}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--source_tokenizer_dir", type=str, required=True, default=None, help="Source tokenizer directory"
    )
    parser.add_argument(
        "--target_tokenizer_dir", type=str, required=True, default=None, help="Target tokenizer directory"
    )
    parser.add_argument(
        "--expand_tokens_file",
        type=str,
        required=True,
        default=None,
        help="Path of the file containing tokens to be extended",
    )
    args = parser.parse_args()

    expand_tokens = []
    with open(file=args.expand_tokens_file, mode="r", encoding="utf-8") as fp_reader:
        for line in fp_reader:
            item = json.loads(line)  # e.g., {"piece": "你好"}
            token = item["piece"]
            if token in expand_tokens:
                continue
            expand_tokens.append(token)
    expand_tokens.sort(key=lambda t: len(t), reverse=False)

    expand_vocab_tokenizer(
        source_tokenizer_dir=args.source_tokenizer_dir,
        target_tokenizer_dir=args.target_tokenizer_dir,
        new_tokens=expand_tokens,
    )


if __name__ == "__main__":
    main()
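A possible way to drive the tokenizer-expansion script, assuming a JSONL file in which each line is `{"piece": "<new token>"}` as the loop above expects; the directories are placeholders, only the flag names come from the argparse definition above.

# Hypothetical example: build a small expand-tokens file, then expand the vocab.
cat > expand_tokens.jsonl <<'EOF'
{"piece": "你好"}
{"piece": "人工智能"}
EOF
python colossal_llama2/tokenizer/init_tokenizer.py \
    --source_tokenizer_dir /path/to/llama-2-7b-hf \
    --target_tokenizer_dir /path/to/expanded_tokenizer \
    --expand_tokens_file expand_tokens.jsonl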
applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py  (new file, 0 → 100644)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py  (new file, 0 → 100644)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Helper functions for IO
"""
import json
import os
from typing import Any, Dict, Tuple, Union

import torch
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator


def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
    """
    Load file in JSON format
    """
    with open(file=file_path, mode="r", encoding="utf-8") as fp:
        return json.load(fp)


def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
    """
    Save as JSON format
    """
    with open(file=file_path, mode="w", encoding="utf-8") as fp:
        json.dump(data, fp=fp, ensure_ascii=False, indent=4)


def save_checkpoint(
    save_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    batch_size: int,
    coordinator: DistCoordinator,
) -> None:
    """
    Save model checkpoint, optimizer, LR scheduler and intermediate running states.
    """

    save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
    os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)

    booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
    booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
    booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
    running_states = {
        "epoch": epoch,
        "step": step,
        "sample_start_index": step * batch_size,
    }
    if coordinator.is_master():
        save_json(running_states, os.path.join(save_dir, "running_states.json"))


def load_checkpoint(
    load_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
) -> Tuple[int, int, int]:
    """
    Load model checkpoint, optimizer, LR scheduler and intermediate running states.
    """

    # Update booster params states.
    booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
    booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
    booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))

    running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
    return (
        running_states["epoch"],
        running_states["step"],
        running_states["sample_start_index"],
    )
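A small sketch (not part of the commit) of the JSON helpers, which are usable on their own; the checkpoint helpers additionally need a colossalai `Booster` plus model/optimizer/scheduler as set up in `train.py` below. The file name is illustrative.

# Sketch only.
from colossal_llama2.utils.ckpt_io import save_json, load_json

running_states = {"epoch": 0, "step": 400, "sample_start_index": 400 * 8}
save_json(running_states, "running_states.json")
print(load_json("running_states.json"))
# save_checkpoint(...) writes <save_dir>/epoch-<e>_step-<s>/{modeling, optimizer, lr_scheduler, running_states.json};
# load_checkpoint(...) reads the same layout back and returns (epoch, step, sample_start_index).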
applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py  (new file, 0 → 100644)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from types import MethodType
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import (
    LlamaRMSNorm,
    LlamaAttention,
    LlamaModel,
    LlamaForCausalLM,
    apply_rotary_pos_emb,
    repeat_kv,
)

from colossalai.logging import get_dist_logger

from einops import rearrange

from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import (
    flash_attn_func,
    flash_attn_varlen_kvpacked_func,
)
from flash_attn.ops.rms_norm import rms_norm

logger = get_dist_logger()


def _prepare_decoder_attention_mask(
    self: LlamaModel,
    attention_mask: torch.BoolTensor,
    input_shape: torch.Size,
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
) -> Optional[torch.Tensor]:
    """
    Decoder attention mask
    """
    if past_key_values_length > 0 and attention_mask is not None:
        attention_mask = torch.cat(
            tensors=(
                torch.full(
                    size=(input_shape[0], past_key_values_length),
                    fill_value=True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                ),
                attention_mask,
            ),
            dim=-1,
        )  # (bsz, past_key_values_length + q_len)
    if attention_mask is not None and torch.all(attention_mask):
        return None  # Faster
    return attention_mask


def attention_forward(
    self: LlamaAttention,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
    """
    if output_attentions:
        logger.warning(
            "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
            "return `None` instead."
        )

    bsz, q_len, _ = hidden_states.size()

    if self.config.pretraining_tp > 1:
        q_slicing, kv_slicing = (
            dim // self.config.pretraining_tp
            for dim in (
                self.num_heads * self.head_dim,
                self.num_key_value_heads * self.head_dim,
            )
        )  # `Tuple[int, int]`
        q_slices, k_slices, v_slices = (
            proj.weight.split(slicing, dim=0)
            for proj, slicing in (
                (self.q_proj, q_slicing),
                (self.k_proj, kv_slicing),
                (self.v_proj, kv_slicing),
            )
        )  # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
        q, k, v = (
            torch.cat(
                [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
                dim=-1,
            )
            for slices in (q_slices, k_slices, v_slices)
        )
        # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
        # (bsz, q_len, num_heads * head_dim),
        # (bsz, q_len, num_key_value_heads * head_dim),
        # (bsz, q_len, num_key_value_heads * head_dim)
    else:
        q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
        # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
        # (bsz, q_len, num_heads * head_dim),
        # (bsz, q_len, num_key_value_heads * head_dim),
        # (bsz, q_len, num_key_value_heads * head_dim)

    # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
    # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
    # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
    q, k, v = (
        states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
        for states, num_heads in (
            (q, self.num_heads),
            (k, self.num_key_value_heads),
            (v, self.num_key_value_heads),
        )
    )
    kv_len = k.shape[-2]  # initially, `kv_len` == `q_len`
    past_kv_len = 0
    if past_key_value is not None:
        # if `past_key_value` is not None, `kv_len` > `q_len`.
        past_kv_len = past_key_value[0].shape[-2]
        kv_len += past_kv_len

    # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
    cos, sin = self.rotary_emb(v, seq_len=kv_len)
    # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
    q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
    if past_key_value is not None:
        # reuse k, v, self_attention
        k = torch.cat([past_key_value[0], k], dim=2)
        v = torch.cat([past_key_value[1], v], dim=2)

    past_key_value = (k, v) if use_cache else None

    # repeat k/v heads if n_kv_heads < n_heads
    k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
    # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
    v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
    # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)

    key_padding_mask = attention_mask
    # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
    q, k, v = (states.transpose(1, 2) for states in (q, k, v))

    if past_kv_len > 0:
        q = torch.cat(
            tensors=(
                torch.full(
                    size=(bsz, past_kv_len, self.num_heads, self.head_dim),
                    fill_value=0.0,
                    dtype=q.dtype,
                    device=q.device,
                ),
                q,
            ),
            dim=1,
        )  # (bsz, past_kv_len + q_len, num_heads, head_dim)

    if key_padding_mask is None:
        # (bsz, past_kv_len + q_len, num_heads, head_dim)
        output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True)  # (bsz, )
        output = rearrange(output, pattern="... h d -> ... (h d)")  # (bsz, past_kv_len + q_len, num_heads * head_dim)
    else:
        q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
        kv, _, cu_kv_lens, max_kv_len = unpad_input(
            hidden_states=torch.stack(tensors=(k, v), dim=2),
            attention_mask=key_padding_mask,
        )
        output_unpad = flash_attn_varlen_kvpacked_func(
            q=q,
            kv=kv,
            cu_seqlens_q=cu_q_lens,
            cu_seqlens_k=cu_kv_lens,
            max_seqlen_q=max_q_len,
            max_seqlen_k=max_kv_len,
            dropout_p=0.0,
            softmax_scale=None,
            causal=True,
        )
        output = pad_input(
            hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
            indices=indices,
            batch=bsz,
            seqlen=past_kv_len + q_len,
        )  # (bsz, past_kv_len + q_len, num_heads * head_dim)

    if past_kv_len > 0:
        # Strip off the zero query outputs.
        output = output[:, past_kv_len:, ...]  # (bsz, q_len, num_heads * head_dim)
    output = self.o_proj(output)  # (bsz, q_len, hidden_size)
    return output, None, past_key_value


def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
    """
    Forward function for RMS Norm
    """
    return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)


def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
    for name, module in model.named_modules():
        if isinstance(module, LlamaAttention):
            module.forward = MethodType(attention_forward, module)
        if isinstance(module, LlamaModel):
            module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
        if isinstance(module, LlamaRMSNorm):
            module.forward = MethodType(rms_norm_forward, module)
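A minimal sketch (not part of the commit) of applying the patch, assuming `flash-attn` is installed and the model runs on CUDA in fp16/bf16 (flash-attention does not support fp32, as noted in `train.py`); the model path is a placeholder.

# Sketch only: monkey-patch a HF LLaMA model in place, then use it as usual.
import torch
from transformers import LlamaForCausalLM
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention

model = LlamaForCausalLM.from_pretrained("/path/to/llama-2-7b-hf", torch_dtype=torch.bfloat16).cuda()
replace_with_flash_attention(model=model)  # swaps attention forward, RMSNorm forward and mask preparation in place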
applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py  (new file, 0 → 100644)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from transformers.models.llama import LlamaForCausalLM


def freeze_non_embeds_parameters(model: LlamaForCausalLM) -> None:
    """Freeze all parameters except embeddings."""
    for name, params in model.named_parameters():
        if "embed_tokens" not in name and "lm_head" not in name:
            params.requires_grad = False
        else:
            params.requires_grad = True


def unfreeze_parameters(model: LlamaForCausalLM) -> None:
    """Make every parameter trainable again."""
    for name, params in model.named_parameters():
        params.requires_grad = True
applications/Colossal-LLaMA-2/docs/example.md  (new file, 0 → 100644)
(Diff collapsed; file contents not shown.)
applications/Colossal-LLaMA-2/hostfile.example  (new file, 0 → 100644)

hostname1
hostname2
\ No newline at end of file
applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py  (new file, 0 → 100644)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Prepare dataset for continual pre-training
"""

import argparse
import json
import math
import os
import time
from multiprocessing import cpu_count

from datasets import dataset_dict, load_dataset
from transformers.models.llama.tokenization_llama import LlamaTokenizer

from colossalai.logging import get_dist_logger

from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
    supervised_tokenize,
    ClosedToConstantLengthSplicedDataset,
)

logger = get_dist_logger()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_input_dirs",
        type=str,
        required=True,
        default=None,
        help="Comma(i.e., ',') separated list of all data directories containing `.jsonl` data files.",
    )
    parser.add_argument(
        "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer"
    )
    parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
    parser.add_argument(
        "--data_jsonl_output_dir",
        type=str,
        default="jsonl_output",
        help="Output directory of spliced dataset with jsonl format",
    )
    parser.add_argument(
        "--data_arrow_output_dir",
        type=str,
        default="arrow_output",
        help="Output directory of spliced dataset with arrow format",
    )
    parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
    parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
    args = parser.parse_args()

    if args.num_spliced_dataset_bins >= 100000:
        raise ValueError("Too many spliced divisions, must be smaller than 100000")

    assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}"
    assert not os.path.exists(
        args.data_jsonl_output_dir
    ), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}"
    assert not os.path.exists(
        args.data_arrow_output_dir
    ), f"Find existed arrow data output dir {args.data_arrow_output_dir}"
    os.makedirs(args.data_jsonl_output_dir)
    os.makedirs(args.data_arrow_output_dir)

    # Prepare to all input datasets
    input_data_paths = []
    input_data_dirs = args.data_input_dirs.split(",")
    for ds_dir in input_data_dirs:
        ds_dir = os.path.abspath(ds_dir)
        assert os.path.exists(ds_dir), f"Not find data dir {ds_dir}"
        ds_files = [name for name in os.listdir(ds_dir) if name.endswith(".jsonl")]
        ds_paths = [os.path.join(ds_dir, name) for name in ds_files]
        input_data_paths.extend(ds_paths)

    # Prepare to data splitting.
    train_splits = []
    split_interval = math.ceil(100 / args.num_spliced_dataset_bins)
    for i in range(0, 100, split_interval):
        start = i
        end = i + split_interval
        if end > 100:
            end = 100
        train_splits.append(f"train[{start}%:{end}%]")

    # Prepare to the tokenizer.
    tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir)
    tokenizer.add_bos_token = False
    tokenizer.add_eos_token = False
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token

    list_dataset = load_dataset(
        path="json",
        data_files=input_data_paths,
        cache_dir=os.path.join(args.data_cache_dir, "raw"),
        keep_in_memory=False,
        split=train_splits,
        num_proc=cpu_count(),
    )
    for index, dataset in enumerate(list_dataset):
        assert isinstance(dataset, dataset_dict.Dataset)
        logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
        dataset = dataset.map(
            function=supervised_tokenize,
            fn_kwargs={"tokenizer": tokenizer, "max_length": args.max_length},
            keep_in_memory=False,
            num_proc=min(len(dataset), cpu_count()),
        )
        dataset = dataset.remove_columns(column_names=["source", "target", "category"])
        dataset = dataset.sort(column_names=("seq_category", "seq_length"), reverse=False, keep_in_memory=False)
        dataset = dataset.remove_columns(column_names=["seq_category", "seq_length"])
        spliced_dataset = ClosedToConstantLengthSplicedDataset(
            dataset=dataset, tokenizer=tokenizer, max_length=args.max_length, error_strict=False
        )
        # Save each jsonl spliced dataset.
        output_index = "0" * (5 - len(str(index))) + str(index)
        output_name = f"part-{output_index}"
        output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl")
        st = time.time()
        with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer:
            spliced_count = 0
            for spliced_data_point in spliced_dataset:
                if spliced_count % 500 == 0:
                    logger.info(f"processing {spliced_count} spliced data points for {fp_writer.name}")
                spliced_count += 1
                fp_writer.write(json.dumps(spliced_data_point, ensure_ascii=False) + "\n")
        logger.info(
            f"Current file {fp_writer.name}; "
            f"Data size: {len(spliced_dataset)}; "
            f"Spliced data size: {spliced_dataset.current_size}; "
            f"Splicing compression rate: {round(spliced_dataset.current_size / len(spliced_dataset), 6)}; "
            f"Time cost: {round((time.time() - st) / 60, 6)} minutes."
        )

        # Save each arrow spliced dataset
        output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name)
        logger.info(f"Start to save {output_arrow_path}")
        spliced_dataset = load_dataset(
            path="json",
            data_files=[output_jsonl_path],
            cache_dir=os.path.join(args.data_cache_dir, "spliced_and_tokenized"),
            keep_in_memory=False,
            num_proc=cpu_count(),
            split="train",
        )
        spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))


if __name__ == '__main__':
    main()
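A possible end-to-end invocation, assuming the input directories contain `.jsonl` files whose records follow the `{"source", "target", "category"}` schema expected by `supervised_tokenize`; the paths are placeholders, only the flag names come from the argparse definition above.

# Hypothetical example; output dirs must not already exist.
python prepare_pretrain_dataset.py \
    --data_input_dirs /data/corpus_a,/data/corpus_b \
    --tokenizer_dir /path/to/expanded_tokenizer \
    --data_cache_dir cache \
    --data_jsonl_output_dir jsonl_output \
    --data_arrow_output_dir arrow_output \
    --max_length 4096 \
    --num_spliced_dataset_bins 10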
applications/Colossal-LLaMA-2/requirements.txt  (new file, 0 → 100644)

torch<2.0.0, >=1.12.1
packaging==23.1
colossalai==0.3.2
autoflake==2.2.1
black==23.9.1
transformers
tensorboard==2.14.0
six==1.16.0
datasets
ninja==1.11.1
flash-attn>=2.0.0,<=2.0.5
tqdm
sentencepiece==0.1.99
protobuf<=3.20.0
applications/Colossal-LLaMA-2/train.example.sh  (new file, 0 → 100644)

#!/bin/bash

# NCCL IB environment variables
export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_GID_INDEX=3
export NCCL_IB_TIMEOUT=23
export NCCL_IB_RETRY_CNT=7
export OMP_NUM_THREADS=8

PROJECT_NAME=""
PARENT_SAVE_DIR=""
PARENT_TENSORBOARD_DIR=""
PARENT_CONFIG_FILE=""
PRETRAINED_MODEL_PATH=""

declare -a dataset=(
    "PATH TO THE DATASET"
)

TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"

colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.py \
    --pretrained $PRETRAINED_MODEL_PATH \
    --dataset ${dataset[@]} \
    --plugin "zero2" \
    --save_interval 400 \
    --save_dir $SAVE_DIR \
    --tensorboard_dir $TENSORBOARD_DIR \
    --config_file $CONFIG_FILE \
    --num_epochs 1 \
    --micro_batch_size 8 \
    --lr 1e-4 \
    --mixed_precision "bf16" \
    --grad_clip 1.0 \
    --weight_decay 0.01 \
    --warmup_steps 100 \
    --use_grad_checkpoint \
    --use_flash_attn
applications/Colossal-LLaMA-2/train.py  (new file, 0 → 100644)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
"""

import json
import argparse
import os
import resource
from contextlib import nullcontext
from tqdm import tqdm

import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import (
    GeminiPlugin,
    LowLevelZeroPlugin,
    HybridParallelPlugin,
)
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

from colossal_llama2.dataset.loader import (
    load_tokenized_dataset,
    setup_distributed_dataloader,
    DataCollatorForSupervisedDataset,
    StatefulDistributedSampler,
)

from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
from colossal_llama2.utils.froze import freeze_non_embeds_parameters


def get_model_numel(model: torch.nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())


def format_numel_str(numel: int) -> str:
    B = 1024**3
    M = 1024**2
    K = 1024
    if numel >= B:
        return f"{numel / B:.2f} B"
    elif numel >= M:
        return f"{numel / M:.2f} M"
    elif numel >= K:
        return f"{numel / K:.2f} K"
    else:
        return f"{numel}"


def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
    tensor.div_(dist.get_world_size())
    return tensor


def main() -> None:
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained",
        type=str,
        default=None,
        help="Address of the pre-trained modeling",
    )
    parser.add_argument("--dataset", nargs="+", default=[])
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
        help="Choose which plugin to use",
    )
    parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
    parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
    parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
    parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
    parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="fp16",
        choices=["fp16", "bf16"],
        help="Mixed precision",
    )
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
    parser.add_argument(
        "--use_grad_checkpoint",
        action="store_true",
        default=False,
        help="Use gradient checkpointing",
    )
    parser.add_argument(
        "--use_flash_attn",
        action="store_true",
        default=False,
        help="Use flash-attention",
    )
    parser.add_argument(
        "--freeze_non_embeds_params",
        action="store_true",
        default=False,
        help="Freeze non embeddings parameters",
    )
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--zero", type=int, default=1)
    args = parser.parse_args()

    with open(args.config_file, "w") as f:
        json.dump(args.__dict__, f, indent=4)

    # ==============================
    # Initialize Distributed Training
    # ==============================
    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()

    # ==============================
    # Initialize Tensorboard
    # ==============================
    if coordinator.is_master():
        os.makedirs(args.tensorboard_dir, exist_ok=True)
        writer = SummaryWriter(args.tensorboard_dir)

    # ==============================
    # Initialize Booster
    # ==============================
    if args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="auto",
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "zero2_cpu":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            cpu_offload=True,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "3d":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=1,
            zero_stage=args.zero,
            max_norm=args.grad_clip,
            precision=args.mixed_precision,
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)

    # ======================================================
    # Initialize Tokenizer, Dataset, Collator and Dataloader
    # ======================================================
    tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_bos_token = False
    tokenizer.add_eos_token = False

    coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
    coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
    coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")

    coordinator.print_on_master(f"Load dataset: {args.dataset}")

    dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
    dataloader = setup_distributed_dataloader(
        dataset=dataset,
        batch_size=args.micro_batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=data_collator,
    )
    coordinator.print_on_master(
        f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
    )

    # ======================================================
    # Initialize Model, Objective, Optimizer and LR Scheduler
    # ======================================================
    init_ctx = (
        LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
    )
    with init_ctx:
        model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
        # Freeze part of parameters.
        if args.freeze_non_embeds_params:
            freeze_non_embeds_parameters(model=model)

    if args.use_grad_checkpoint:
        model.gradient_checkpointing_enable()
        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
    if args.use_flash_attn:
        replace_with_flash_attention(model=model)
        coordinator.print_on_master(msg="Flash-attention enabled successfully")

    model_numel = get_model_numel(model)
    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")

    optimizer = HybridAdam(
        model_params=filter(lambda p: p.requires_grad, model.parameters())
        if args.freeze_non_embeds_params
        else model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optimizer,
        total_steps=args.num_epochs * len(dataloader),
        warmup_steps=args.warmup_steps
        if args.warmup_steps is not None
        else int(args.num_epochs * len(dataloader) * 0.025),
        eta_min=0.1 * args.lr,
    )

    # Flash attention will be disabled because it does NOT support fp32.
    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )

    torch.set_default_dtype(torch.float)

    if args.load_checkpoint is None:
        coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
        booster.load_model(model, args.pretrained, strict=False)

    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
    )

    start_epoch = 0
    start_step = 0
    sampler_start_idx = 0
    if args.load_checkpoint is not None:
        if "modeling" in args.load_checkpoint:
            coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}")
            booster.load_model(model, args.load_checkpoint)
        else:
            coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}")
            start_epoch, start_step, sampler_start_idx = load_checkpoint(
                load_dir=args.load_checkpoint,
                booster=booster,
                model=model,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,
            )
            coordinator.print_on_master(
                f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}"
            )
            coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")

        coordinator.print_on_master(
            f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
        )
        coordinator.print_on_master(
            f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
        )
        coordinator.print_on_master(
            f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
        )

    num_steps_per_epoch = len(dataloader)
    # If resume training, set the sampler start index to the correct value
    assert isinstance(dataloader.sampler, StatefulDistributedSampler)
    dataloader.sampler.set_start_index(start_index=sampler_start_idx)

    for epoch in range(start_epoch, args.num_epochs):
        dataloader.sampler.set_epoch(epoch=epoch)
        with tqdm(
            iterable=enumerate(dataloader, start=start_step),
            desc=f"Epoch {epoch}",
            disable=not coordinator.is_master(),
            total=num_steps_per_epoch,
            initial=start_step,
        ) as pbar:
            for step, batch in pbar:
                batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}

                batch_output = model(**batch)

                loss = batch_output.loss

                booster.backward(loss=loss, optimizer=optimizer)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                all_reduce_mean(tensor=loss)
                pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
                if coordinator.is_master():
                    global_step = epoch * num_steps_per_epoch + step
                    writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=global_step)
                    writer.add_scalar(
                        tag="Learning Rate",
                        scalar_value=lr_scheduler.get_last_lr()[0],
                        global_step=global_step,
                    )
                # Save modeling.
                if (args.save_interval > 0 and (step + 1) % args.save_interval == 0) or (step + 1) == len(dataloader):
                    coordinator.print_on_master("\nStart saving model checkpoint with running states")
                    save_checkpoint(
                        save_dir=args.save_dir,
                        booster=booster,
                        model=model,
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler,
                        epoch=epoch,
                        step=step + 1,
                        batch_size=args.micro_batch_size,
                        coordinator=coordinator,
                    )
                    coordinator.print_on_master(
                        f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
                    )

                # Delete CUDA cache.
                # del batch, batch_labels, batch_output, loss
                torch.cuda.empty_cache()

        # the continue epochs are not resumed, so we need to reset the sampler start index and start step
        dataloader.sampler.set_start_index(start_index=0)
        start_step = 0

    # Final save.
    coordinator.print_on_master("Start saving final model checkpoint")
    booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
    coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")

    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")


if __name__ == "__main__":
    main()
applications/Colossal-LLaMA-2/version.txt  (new file, 0 → 100644)

0.0.1
\ No newline at end of file
applications/README.md  (modified)

@@ -4,8 +4,9 @@ This directory contains the applications that are powered by Colossal-AI.
 The list of applications include:
-- [X] [Chatbot](./Chat/README.md)
-- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters
+- [X] [Colossal-LLaMA-2](./Colossal-LLaMA-2/): Continual Pre-training of LLaMA-2.
+- [X] [Chatbot](./Chat/README.md): Replication of ChatGPT with RLHF.
+- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters.
 > Please note that the `Chatbot` application is migrated from the original `ChatGPT` folder.