OpenDAS / ColossalAI / Commits

Commit 7bc5a8e3, authored May 05, 2023 by zhuwenwen
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI
Parents: e6748d82, 0f785cb1

Changes: 428 files in total; this page shows 20 changed files with 894 additions and 0 deletions (+894, -0).
applications/Chat/coati/dataset/__init__.py            (+9, -0)
applications/Chat/coati/dataset/prompt_dataset.py      (+51, -0)
applications/Chat/coati/dataset/reward_dataset.py      (+112, -0)
applications/Chat/coati/dataset/sft_dataset.py         (+166, -0)
applications/Chat/coati/dataset/utils.py               (+22, -0)
applications/Chat/coati/experience_maker/__init__.py   (+4, -0)
applications/Chat/coati/experience_maker/base.py       (+77, -0)
applications/Chat/coati/experience_maker/naive.py      (+35, -0)
applications/Chat/coati/kernels/__init__.py            (+6, -0)
applications/Chat/coati/kernels/opt_attn.py            (+87, -0)
applications/Chat/coati/kernels/wrapper.py             (+18, -0)
applications/Chat/coati/models/__init__.py             (+8, -0)
applications/Chat/coati/models/base/__init__.py        (+24, -0)
applications/Chat/coati/models/base/actor.py           (+65, -0)
applications/Chat/coati/models/base/critic.py          (+54, -0)
applications/Chat/coati/models/base/reward_model.py    (+41, -0)
applications/Chat/coati/models/bloom/__init__.py       (+5, -0)
applications/Chat/coati/models/bloom/bloom_actor.py    (+35, -0)
applications/Chat/coati/models/bloom/bloom_critic.py   (+38, -0)
applications/Chat/coati/models/bloom/bloom_rm.py       (+37, -0)
applications/Chat/coati/dataset/__init__.py (new file, 0 → 100644)

```python
from .prompt_dataset import PromptDataset
from .reward_dataset import HhRlhfDataset, RmStaticDataset
from .sft_dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
from .utils import is_rank_0

__all__ = [
    'RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'SupervisedDataset',
    'DataCollatorForSupervisedDataset', 'PromptDataset'
]
```
applications/Chat/coati/dataset/prompt_dataset.py (new file, 0 → 100644)

```python
import copy
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Callable, Dict, Sequence

import torch
import torch.distributed as dist
import transformers
from torch.utils.data import Dataset
from tqdm import tqdm

from colossalai.logging import get_dist_logger

from .utils import is_rank_0, jload

logger = get_dist_logger()


class PromptDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self,
                 data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 max_datasets_size: int = None,
                 max_length: int = 96):
        super(PromptDataset, self).__init__()
        self.keyed_prompt = defaultdict(list)
        logger.info("Loading data...")
        list_data_dict = jload(data_path)
        logger.info(f"Loaded {len(list_data_dict)} examples.")

        if max_datasets_size is not None:
            logger.info(f"Limiting dataset to {max_datasets_size} examples.")
            list_data_dict = list_data_dict[:max_datasets_size]

        for data_dict in list_data_dict:
            token = tokenizer(data_dict["instruction"],
                              return_tensors='pt',
                              max_length=max_length,
                              padding='max_length',
                              truncation=True)
            for k, tensor in token.items():
                self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())

    def __len__(self):
        return len(self.keyed_prompt)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return {k: v[i] for k, v in self.keyed_prompt.items()}
```
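
A minimal usage sketch, not part of the commit: `PromptDataset` eagerly tokenizes every instruction to a fixed `max_length` and moves the tensors to the current CUDA device, so it assumes a GPU is available. The `prompts.json` path is hypothetical; the file is expected to be a JSON list of dicts with an `"instruction"` field (the format `jload` returns). Note also that `__len__` returns `len(self.keyed_prompt)`, i.e. the number of tokenizer output keys (typically 2), not the number of prompts.

```python
# Sketch only: "prompts.json" is a hypothetical path; requires a CUDA device.
from transformers import AutoTokenizer
from coati.dataset import PromptDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token    # GPT-2 has no pad token by default

dataset = PromptDataset(data_path="prompts.json",
                        tokenizer=tokenizer,
                        max_datasets_size=1000,
                        max_length=96)
sample = dataset[0]    # dict with 'input_ids' and 'attention_mask' tensors on GPU
```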
applications/Chat/coati/dataset/reward_dataset.py (new file, 0 → 100644)

```python
from typing import Callable

from torch.utils.data import Dataset
from tqdm import tqdm

from .utils import is_rank_0


# Dahoas/rm-static
class RmStaticDataset(Dataset):
    """
    Dataset for reward model

    Args:
        dataset: dataset for reward model
        tokenizer: tokenizer for reward model
        max_length: max length of input
        special_token: special token at the end of sentence
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.chosen = []
        self.reject = []
        if special_token is None:
            self.end_token = tokenizer.eos_token
        else:
            self.end_token = special_token
        for data in tqdm(dataset, disable=not is_rank_0()):
            prompt = data['prompt']

            chosen = prompt + data['chosen'] + self.end_token
            chosen_token = tokenizer(chosen,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.chosen.append({
                "input_ids": chosen_token['input_ids'],
                "attention_mask": chosen_token['attention_mask']
            })

            reject = prompt + data['rejected'] + self.end_token
            reject_token = tokenizer(reject,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.reject.append({
                "input_ids": reject_token['input_ids'],
                "attention_mask": reject_token['attention_mask']
            })

    def __len__(self):
        length = len(self.chosen)
        return length

    def __getitem__(self, idx):
        return (self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"],
                self.reject[idx]["input_ids"], self.reject[idx]["attention_mask"])


# Anthropic/hh-rlhf
class HhRlhfDataset(Dataset):
    """
    Dataset for reward model

    Args:
        dataset: dataset for reward model
        tokenizer: tokenizer for reward model
        max_length: max length of input
        special_token: special token at the end of sentence
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.chosen = []
        self.reject = []
        if special_token is None:
            self.end_token = tokenizer.eos_token
        else:
            self.end_token = special_token
        for data in tqdm(dataset, disable=not is_rank_0()):
            chosen = data['chosen'] + self.end_token
            chosen_token = tokenizer(chosen,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.chosen.append({
                "input_ids": chosen_token['input_ids'],
                "attention_mask": chosen_token['attention_mask']
            })

            reject = data['rejected'] + self.end_token
            reject_token = tokenizer(reject,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.reject.append({
                "input_ids": reject_token['input_ids'],
                "attention_mask": reject_token['attention_mask']
            })

    def __len__(self):
        length = len(self.chosen)
        return length

    def __getitem__(self, idx):
        return (self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"],
                self.reject[idx]["input_ids"], self.reject[idx]["attention_mask"])
```
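
A usage sketch (an assumption, not part of the commit) wiring the two classes to the Hugging Face datasets they are named after. `RmStaticDataset` expects rows with separate `prompt`/`chosen`/`rejected` fields, while `HhRlhfDataset` expects full `chosen`/`rejected` conversations with no separate prompt; each item comes back as four `(1, max_length)` tensors because the tokenizer outputs are stored without squeezing.

```python
# Sketch only: model name, split, and max_length are illustrative.
from datasets import load_dataset
from transformers import AutoTokenizer
from coati.dataset import HhRlhfDataset, RmStaticDataset

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

rm_static = load_dataset("Dahoas/rm-static", split="train")
train_data = RmStaticDataset(rm_static, tokenizer, max_length=512)

hh_rlhf = load_dataset("Anthropic/hh-rlhf", split="train")
eval_data = HhRlhfDataset(hh_rlhf, tokenizer, max_length=512)

chosen_ids, chosen_mask, reject_ids, reject_mask = train_data[0]
```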
applications/Chat/coati/dataset/sft_dataset.py (new file, 0 → 100644)

```python
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import random
from dataclasses import dataclass, field
from typing import Callable, Dict, Sequence

import torch
import torch.distributed as dist
import transformers
from torch.utils.data import Dataset
from tqdm import tqdm

from colossalai.logging import get_dist_logger

from .utils import is_rank_0, jload

logger = get_dist_logger()

IGNORE_INDEX = -100
PROMPT_DICT = {
    "prompt_input":
        ("Below is an instruction that describes a task, paired with an input that provides further context. "
         "Write a response that appropriately completes the request.\n\n"
         "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
    "prompt_no_input":
        ("Below is an instruction that describes a task. "
         "Write a response that appropriately completes the request.\n\n"
         "### Instruction:\n{instruction}\n\n### Response:"),
}


class SFTDataset(Dataset):
    """
    Dataset for sft model

    Args:
        dataset: dataset for supervised model
        tokenizer: tokenizer for supervised model
        max_length: max length of input
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
        super().__init__()
        self.input_ids = []

        for data in tqdm(dataset, disable=not is_rank_0()):
            prompt = data['prompt'] + data['completion'] + tokenizer.eos_token
            prompt_token = tokenizer(prompt,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.input_ids.append(prompt_token['input_ids'][0])
        self.labels = copy.deepcopy(self.input_ids)

    def __len__(self):
        length = len(self.input_ids)
        return length

    def __getitem__(self, idx):
        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=max_length,
            truncation=True,
        ) for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    max_length: int,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [
        _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
    ]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self,
                 data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 max_datasets_size: int = None,
                 max_length: int = 512):
        super(SupervisedDataset, self).__init__()
        logger.info("Loading data...")
        list_data_dict = jload(data_path)
        logger.info(f"Loaded {len(list_data_dict)} examples.")

        if max_datasets_size is not None:
            logger.info(f"Limiting dataset to {max_datasets_size} examples.")
            list_data_dict = list_data_dict[:max_datasets_size]

        logger.info("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [
            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

        logger.info("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer, max_length)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids,
                                                    batch_first=True,
                                                    padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
```
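
How the pieces fit together: `preprocess` overwrites the prompt portion of each label with `IGNORE_INDEX` (-100), so the cross-entropy loss is computed only on response tokens, and the collator re-pads labels with the same sentinel. A usage sketch under the assumption of an Alpaca-style `data.json` with `instruction`/`input`/`output` fields; the path, model name, and batch size are illustrative.

```python
# Sketch only: "data.json" is a hypothetical Alpaca-style file.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from coati.dataset import DataCollatorForSupervisedDataset, SupervisedDataset

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = SupervisedDataset(data_path="data.json", tokenizer=tokenizer, max_length=512)
collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collator)

batch = next(iter(loader))
# Prompt positions in batch['labels'] hold IGNORE_INDEX (-100), so the LM loss
# only covers the response tokens.
```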
applications/Chat/coati/dataset/utils.py (new file, 0 → 100644)

```python
import io
import json

import torch.distributed as dist


def is_rank_0() -> bool:
    return not dist.is_initialized() or dist.get_rank() == 0


def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f


def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict
```
applications/Chat/coati/experience_maker/__init__.py (new file, 0 → 100644)

```python
from .base import Experience, ExperienceMaker
from .naive import NaiveExperienceMaker

__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker']
```
applications/Chat/coati/experience_maker/base.py (new file, 0 → 100644)

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from coati.models.base import Actor


@dataclass
class Experience:
    """Experience is a batch of data.
    These data should have the sequence length and number of actions.
    Left padding for sequences is applied.

    Shapes of each tensor:
    sequences: (B, S)
    action_log_probs: (B, A)
    values: (B)
    reward: (B)
    advantages: (B)
    attention_mask: (B, S)
    action_mask: (B, A)

    "A" is the number of actions.
    """
    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]

    @torch.no_grad()
    def to_device(self, device: torch.device) -> None:
        self.sequences = self.sequences.to(device)
        self.action_log_probs = self.action_log_probs.to(device)
        self.values = self.values.to(device)
        self.reward = self.reward.to(device)
        self.advantages = self.advantages.to(device)
        if self.attention_mask is not None:
            self.attention_mask = self.attention_mask.to(device)
        if self.action_mask is not None:
            self.action_mask = self.action_mask.to(device)

    def pin_memory(self):
        self.sequences = self.sequences.pin_memory()
        self.action_log_probs = self.action_log_probs.pin_memory()
        self.values = self.values.pin_memory()
        self.reward = self.reward.pin_memory()
        self.advantages = self.advantages.pin_memory()
        if self.attention_mask is not None:
            self.attention_mask = self.attention_mask.pin_memory()
        if self.action_mask is not None:
            self.action_mask = self.action_mask.pin_memory()
        return self


class ExperienceMaker(ABC):

    def __init__(self,
                 actor: Actor,
                 critic: nn.Module,
                 reward_model: nn.Module,
                 initial_model: Actor,
                 kl_coef: float = 0.1) -> None:
        super().__init__()
        self.actor = actor
        self.critic = critic
        self.reward_model = reward_model
        self.initial_model = initial_model
        self.kl_coef = kl_coef

    @abstractmethod
    def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
        pass
```
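
An illustrative sketch (not from the commit) of the `Experience` device helpers, with dummy tensors whose shapes follow the docstring. `pin_memory` returns `self`, so it can be chained before an async host-to-device copy; both helpers assume a CUDA build of PyTorch.

```python
# Sketch only: B, S, A and the tensor contents are arbitrary.
import torch
from coati.experience_maker import Experience

B, S, A = 2, 8, 3
exp = Experience(sequences=torch.randint(0, 100, (B, S)),
                 action_log_probs=torch.randn(B, A),
                 values=torch.randn(B),
                 reward=torch.randn(B),
                 advantages=torch.randn(B),
                 attention_mask=torch.ones(B, S, dtype=torch.long),
                 action_mask=torch.ones(B, A, dtype=torch.bool))

if torch.cuda.is_available():
    exp.pin_memory()                    # returns self, so calls can be chained
    exp.to_device(torch.device('cuda'))
```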
applications/Chat/coati/experience_maker/naive.py (new file, 0 → 100644)

```python
import torch
from coati.models.utils import compute_reward, normalize

from .base import Experience, ExperienceMaker


class NaiveExperienceMaker(ExperienceMaker):
    """
    Naive experience maker.
    """

    @torch.no_grad()
    def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
        self.actor.eval()
        self.critic.eval()
        self.initial_model.eval()
        self.reward_model.eval()

        sequences, attention_mask, action_mask = self.actor.generate(input_ids,
                                                                     return_action_mask=True,
                                                                     **generate_kwargs)
        num_actions = action_mask.size(1)

        action_log_probs = self.actor(sequences, num_actions, attention_mask)
        base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)
        value = self.critic(sequences, action_mask, attention_mask)
        r = self.reward_model(sequences, attention_mask)
        reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)

        advantage = reward - value
        # TODO(ver217): maybe normalize adv
        if advantage.ndim == 1:
            advantage = advantage.unsqueeze(-1)

        return Experience(sequences, action_log_probs, value, reward, advantage, attention_mask, action_mask)
```
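
A wiring sketch showing how the four models feed `NaiveExperienceMaker`. The `pad_token_id`/`eos_token_id` kwargs are grounded in `Actor.generate` below; the remaining generation kwargs go to coati's `generate` helper, which is not shown on this page, so they are assumptions patterned on Hugging Face generation. Model names are illustrative, and everything is assumed to fit on one GPU.

```python
# Sketch only: model choice and generation kwargs are assumptions.
from transformers import AutoTokenizer
from coati.experience_maker import NaiveExperienceMaker
from coati.models.bloom import BLOOMActor, BLOOMCritic, BLOOMRM

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

actor = BLOOMActor(pretrained="bigscience/bloom-560m").cuda()
initial_model = BLOOMActor(pretrained="bigscience/bloom-560m").cuda()    # frozen reference policy
critic = BLOOMCritic(pretrained="bigscience/bloom-560m").cuda()
reward_model = BLOOMRM(pretrained="bigscience/bloom-560m").cuda()

maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef=0.1)

prompts = tokenizer(["How do I bake bread?"], return_tensors="pt").input_ids.cuda()
experience = maker.make_experience(prompts,
                                   max_length=128,
                                   do_sample=True,
                                   pad_token_id=tokenizer.pad_token_id,
                                   eos_token_id=tokenizer.eos_token_id)
```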
applications/Chat/coati/kernels/__init__.py (new file, 0 → 100644)

```python
from .wrapper import convert_to_xformer_model, recover_from_xformer_model

__all__ = [
    'convert_to_xformer_model',
    'recover_from_xformer_model',
]
```
applications/Chat/coati/kernels/opt_attn.py (new file, 0 → 100644)

```python
from typing import Optional, Tuple

import torch
import xformers.ops as xops
from torch import Tensor
from transformers.models.opt.modeling_opt import OPTAttention


# This is modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py
class XOPTAttention(OPTAttention):

    # def _shape(self, tensor: Tensor, seq_len: int, bsz: int):
    #     return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()

    def forward(
        self,
        hidden_states: Tensor,
        key_value_states: Optional[Tensor] = None,
        past_key_value: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        layer_head_mask: Optional[Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]:
        if not self.training:
            return super().forward(hidden_states, key_value_states, past_key_value, attention_mask,
                                   layer_head_mask, output_attentions)
        """Input shape: Batch x Time x Channel"""
        assert layer_head_mask is None, 'Xformers attention does not support layer_head_mask'
        assert not output_attentions, 'Xformers attention does not support output_attentions'

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        query_states = self._shape(query_states, tgt_len, bsz).transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = xops.memory_efficient_attention(query_states,
                                                      key_states,
                                                      value_states,
                                                      attn_bias=xops.LowerTriangularMask(),
                                                      p=self.dropout if self.training else 0.0,
                                                      scale=self.scaling)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        attn_weights_reshaped = None

        return attn_output, attn_weights_reshaped, past_key_value
```
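
The extra `.transpose(1, 2)` calls exist because `xops.memory_efficient_attention` expects a `(batch, seq, heads, head_dim)` layout, whereas HF's `_shape` produces `(batch, heads, seq, head_dim)`. A small equivalence sketch, assuming xformers and a CUDA device are available: the fused kernel with a `LowerTriangularMask` should match plain causal softmax attention up to numerical tolerance.

```python
# Sketch only: shapes and tolerance are illustrative; requires CUDA + xformers.
import math
import torch
import xformers.ops as xops

B, S, H, D = 2, 16, 4, 32
q = torch.randn(B, S, H, D, device='cuda', dtype=torch.float16)
k = torch.randn(B, S, H, D, device='cuda', dtype=torch.float16)
v = torch.randn(B, S, H, D, device='cuda', dtype=torch.float16)

out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())

# Reference: materialize the full (S, S) score matrix with an explicit causal mask.
scores = torch.einsum('bqhd,bkhd->bhqk', q.float(), k.float()) / math.sqrt(D)
causal = torch.triu(torch.ones(S, S, device='cuda', dtype=torch.bool), 1)
scores = scores.masked_fill(causal, float('-inf'))
ref = torch.einsum('bhqk,bkhd->bqhd', scores.softmax(dim=-1), v.float())

assert torch.allclose(out.float(), ref, atol=1e-2)
```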
applications/Chat/coati/kernels/wrapper.py (new file, 0 → 100644)

```python
import torch.nn as nn
from transformers.models.opt.modeling_opt import OPTAttention

from .opt_attn import XOPTAttention


def convert_to_xformer_model(model: nn.Module) -> nn.Module:
    for module in model.modules():
        if isinstance(module, OPTAttention):
            module.__class__ = XOPTAttention
    return model


def recover_from_xformer_model(model: nn.Module) -> nn.Module:
    for module in model.modules():
        if isinstance(module, XOPTAttention):
            module.__class__ = OPTAttention
    return model
```
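
The wrapper swaps each attention module's `__class__` in place, so no weights are copied and the conversion is fully reversible. Since `XOPTAttention.forward` falls back to the stock implementation when `self.training` is false, the converted model also behaves normally in eval mode. A usage sketch; the model name is illustrative.

```python
# Sketch only: "facebook/opt-350m" is an illustrative checkpoint.
from transformers import OPTForCausalLM
from coati.kernels import convert_to_xformer_model, recover_from_xformer_model

model = OPTForCausalLM.from_pretrained("facebook/opt-350m")

model = convert_to_xformer_model(model)    # xformers kernel used in .train() mode
model.train()
# ... training steps using the memory-efficient attention ...

model = recover_from_xformer_model(model)  # back to the stock HF attention
model.eval()
```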
applications/Chat/coati/models/__init__.py (new file, 0 → 100644)

```python
from .base import Actor, Critic, RewardModel
from .lora import LoRAModule, convert_to_lora_module
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss

__all__ = [
    'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss',
    'LoRAModule', 'convert_to_lora_module'
]
```
applications/Chat/coati/models/base/__init__.py (new file, 0 → 100644)

```python
import torch.nn as nn

from .actor import Actor
from .critic import Critic
from .reward_model import RewardModel


def get_base_model(model: nn.Module) -> nn.Module:
    """Get the base model of our wrapper classes.
    For Actor, its base model is ``actor.model`` and it is usually a ``transformers.PreTrainedModel``.
    For Critic and RewardModel, the base model is the module itself.

    Args:
        model (nn.Module): model to get base model from

    Returns:
        nn.Module: the base model
    """
    if isinstance(model, Actor):
        return model.get_base_model()
    return model


__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
```
applications/Chat/coati/models/base/actor.py (new file, 0 → 100644)

```python
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..generation import generate
from ..lora import LoRAModule
from ..utils import log_probs_from_logits


class Actor(LoRAModule):
    """
    Actor model base class.

    Args:
        model (nn.Module): Actor Model.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.model = model
        self.convert_to_lora()

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        return_action_mask: bool = True,
        **kwargs
    ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor,
                                                                torch.BoolTensor]]:
        sequences = generate(self.model, input_ids, **kwargs)
        attention_mask = None
        pad_token_id = kwargs.get('pad_token_id', None)
        if pad_token_id is not None:
            attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
        if not return_action_mask:
            return sequences, attention_mask, None
        input_len = input_ids.size(1)
        eos_token_id = kwargs.get('eos_token_id', None)
        if eos_token_id is None:
            action_mask = torch.ones_like(sequences, dtype=torch.bool)
        else:
            # left padding may be applied, only mask action
            action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
        action_mask[:, :input_len] = False
        action_mask = action_mask[:, 1:]
        return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]

    def forward(self,
                sequences: torch.LongTensor,
                num_actions: int,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Returns action log probs
        """
        output = self.model(sequences, attention_mask=attention_mask)
        logits = output['logits']
        log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
        return log_probs[:, -num_actions:]

    def get_base_model(self):
        return self.model
```
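
The one-position shift in `forward` reflects that the logits at position `t` predict token `t + 1`; keeping the last `num_actions` entries then isolates the generated (action) tokens. `log_probs_from_logits` lives in `coati.models.utils` and is not shown on this page, so the gather-based stand-in below is an assumption about its behavior, not the actual implementation.

```python
# Sketch only: log_probs_from_logits_sketch is a hypothetical stand-in.
import torch
import torch.nn.functional as F


def log_probs_from_logits_sketch(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    log_probs = F.log_softmax(logits, dim=-1)
    return log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)


B, S, V, num_actions = 2, 10, 50, 4
logits = torch.randn(B, S, V)              # model output for sequences of length S
sequences = torch.randint(0, V, (B, S))

# logits[:, t] predicts token t + 1, hence the one-position shift:
log_probs = log_probs_from_logits_sketch(logits[:, :-1, :], sequences[:, 1:])
action_log_probs = log_probs[:, -num_actions:]    # (B, A): log probs of generated tokens
```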
applications/Chat/coati/models/base/critic.py (new file, 0 → 100644)

```python
from typing import Optional

import torch
import torch.nn as nn

from ..lora import LoRAModule
from ..utils import masked_mean


class Critic(LoRAModule):
    """
    Critic model base class.

    Args:
        model (nn.Module): Critic model.
        value_head (nn.Module): Value head to get value.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(
        self,
        model: nn.Module,
        value_head: nn.Module,
        lora_rank: int = 0,
        lora_train_bias: str = 'none',
        use_action_mask: bool = False,
    ) -> None:
        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.model = model
        self.value_head = value_head
        self.use_action_mask = use_action_mask
        self.convert_to_lora()

    def forward(self,
                sequences: torch.LongTensor,
                action_mask: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        outputs = self.model(sequences, attention_mask=attention_mask)
        last_hidden_states = outputs['last_hidden_state']

        values = self.value_head(last_hidden_states).squeeze(-1)

        if action_mask is not None and self.use_action_mask:
            num_actions = action_mask.size(1)
            prompt_mask = attention_mask[:, :-num_actions]
            values = values[:, :-num_actions]
            value = masked_mean(values, prompt_mask, dim=1)
            return value

        values = values[:, :-1]
        value = values.mean(dim=1)
        return value
```
applications/Chat/coati/models/base/reward_model.py (new file, 0 → 100644)

```python
from typing import Optional

import torch
import torch.nn as nn

from ..lora import LoRAModule


class RewardModel(LoRAModule):
    """
    Reward model base class.

    Args:
        model (nn.Module): Reward model.
        value_head (nn.Module): Value head to get reward score.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 model: nn.Module,
                 value_head: Optional[nn.Module] = None,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.model = model
        self.convert_to_lora()

        if value_head is not None:
            if value_head.out_features != 1:
                raise ValueError("The value head of reward model's output dim should be 1!")
            self.value_head = value_head
        else:
            self.value_head = nn.Linear(model.config.n_embd, 1)

    def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        outputs = self.model(sequences, attention_mask=attention_mask)
        last_hidden_states = outputs['last_hidden_state']
        values = self.value_head(last_hidden_states)[:, :-1]
        value = values.mean(dim=1).squeeze(1)    # ensure shape is (B)
        return value
```
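
One caveat worth noting: the default value head reads `model.config.n_embd`, an attribute of GPT-2-style configs. Backbones whose config exposes `hidden_size` instead (BLOOM, OPT, and others) must pass an explicit `value_head`, which is exactly what the BLOOM subclasses below do. A minimal sketch with a randomly initialized config:

```python
# Sketch only: uses a default (untrained) BloomConfig for illustration.
import torch.nn as nn
from transformers import BloomConfig, BloomModel
from coati.models.base import RewardModel

model = BloomModel(BloomConfig())
value_head = nn.Linear(model.config.hidden_size, 1)    # out_features must be 1
rm = RewardModel(model, value_head)
```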
applications/Chat/coati/models/bloom/__init__.py (new file, 0 → 100644)

```python
from .bloom_actor import BLOOMActor
from .bloom_critic import BLOOMCritic
from .bloom_rm import BLOOMRM

__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']
```
applications/Chat/coati/models/bloom/bloom_actor.py (new file, 0 → 100644)

```python
from typing import Optional

import torch
from transformers import BloomConfig, BloomForCausalLM, BloomModel

from ..base import Actor


class BLOOMActor(Actor):
    """
    BLOOM Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (BloomConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: str = None,
                 config: Optional[BloomConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = BloomForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = BloomForCausalLM(config)
        else:
            model = BloomForCausalLM(BloomConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)
```
applications/Chat/coati/models/bloom/bloom_critic.py (new file, 0 → 100644)

```python
from typing import Optional

import torch
import torch.nn as nn
from transformers import BloomConfig, BloomForCausalLM, BloomModel

from ..base import Critic


class BLOOMCritic(Critic):
    """
    BLOOM Critic model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (BloomConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: str = None,
                 config: Optional[BloomConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        if pretrained is not None:
            model = BloomModel.from_pretrained(pretrained)
        elif config is not None:
            model = BloomModel(config)
        else:
            model = BloomModel(BloomConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
```
applications/Chat/coati/models/bloom/bloom_rm.py (new file, 0 → 100644)

```python
from typing import Optional

import torch.nn as nn
from transformers import BloomConfig, BloomForCausalLM, BloomModel

from ..base import RewardModel


class BLOOMRM(RewardModel):
    """
    BLOOM Reward model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (BloomConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: str = None,
                 config: Optional[BloomConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = BloomModel.from_pretrained(pretrained)
        elif config is not None:
            model = BloomModel(config)
        else:
            model = BloomModel(BloomConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
        super().__init__(model, value_head, lora_rank, lora_train_bias)
```
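
All three BLOOM wrappers share the same construction pattern: a pretrained name or path, an explicit `BloomConfig`, or neither (falling back to the default config), plus optional gradient checkpointing and LoRA settings inherited from `LoRAModule`. A construction sketch; the checkpoint name and LoRA rank are illustrative.

```python
# Sketch only: pretrained name and lora_rank are illustrative choices.
from coati.models.bloom import BLOOMActor, BLOOMCritic, BLOOMRM

actor = BLOOMActor(pretrained="bigscience/bloom-560m", checkpoint=True, lora_rank=8)
critic = BLOOMCritic(pretrained="bigscience/bloom-560m", lora_rank=8)
rm = BLOOMRM(pretrained="bigscience/bloom-560m", lora_rank=8)
```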