Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
aaeb520c
Unverified
Commit
aaeb520c
authored
Sep 04, 2023
by
yingliu-hpc
Committed by
GitHub
Sep 04, 2023
Browse files
Merge pull request #4542 from hpcaitech/chatglm
[coati] Add chatglm in coati
parents
8d7b0229
9f852f24
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
2165 additions
and
42 deletions
+2165
-42
.github/workflows/run_chatgpt_examples.yml
.github/workflows/run_chatgpt_examples.yml
+1
-2
.github/workflows/run_chatgpt_unit_tests.yml
.github/workflows/run_chatgpt_unit_tests.yml
+1
-2
applications/Chat/coati/dataset/sft_dataset.py
applications/Chat/coati/dataset/sft_dataset.py
+63
-12
applications/Chat/coati/models/chatglm/__init__.py
applications/Chat/coati/models/chatglm/__init__.py
+3
-0
applications/Chat/coati/models/chatglm/chatglm_actor.py
applications/Chat/coati/models/chatglm/chatglm_actor.py
+34
-0
applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
+446
-0
applications/Chat/coati/models/chatglm/configuration_chatglm.py
...ations/Chat/coati/models/chatglm/configuration_chatglm.py
+107
-0
applications/Chat/coati/models/chatglm/modeling_chatglm.py
applications/Chat/coati/models/chatglm/modeling_chatglm.py
+1439
-0
applications/Chat/coati/trainer/sft.py
applications/Chat/coati/trainer/sft.py
+7
-3
applications/Chat/examples/requirements.txt
applications/Chat/examples/requirements.txt
+1
-0
applications/Chat/examples/train_sft.py
applications/Chat/examples/train_sft.py
+9
-3
applications/Chat/requirements-test.txt
applications/Chat/requirements-test.txt
+1
-0
applications/Chat/requirements.txt
applications/Chat/requirements.txt
+1
-1
applications/Chat/tests/test_dataset.py
applications/Chat/tests/test_dataset.py
+26
-5
applications/Chat/tests/test_models.py
applications/Chat/tests/test_models.py
+26
-14
No files found.
.github/workflows/run_chatgpt_examples.yml
View file @
aaeb520c
...
@@ -28,9 +28,8 @@ jobs:
...
@@ -28,9 +28,8 @@ jobs:
-
name
:
Checkout ColossalAI
-
name
:
Checkout ColossalAI
uses
:
actions/checkout@v2
uses
:
actions/checkout@v2
-
name
:
Install
ColossalAI and
ChatGPT
-
name
:
Install ChatGPT
run
:
|
run
:
|
pip install -e .
cd applications/Chat
cd applications/Chat
pip install -v .
pip install -v .
pip install -r examples/requirements.txt
pip install -r examples/requirements.txt
...
...
.github/workflows/run_chatgpt_unit_tests.yml
View file @
aaeb520c
...
@@ -30,9 +30,8 @@ jobs:
...
@@ -30,9 +30,8 @@ jobs:
-
name
:
Checkout ColossalAI
-
name
:
Checkout ColossalAI
uses
:
actions/checkout@v2
uses
:
actions/checkout@v2
-
name
:
Install
ColossalAI and
ChatGPT
-
name
:
Install ChatGPT
run
:
|
run
:
|
pip install -e .
cd applications/Chat
cd applications/Chat
pip install -v .
pip install -v .
pip install -r requirements-test.txt
pip install -r requirements-test.txt
...
...
applications/Chat/coati/dataset/sft_dataset.py
View file @
aaeb520c
...
@@ -19,7 +19,7 @@ import torch
...
@@ -19,7 +19,7 @@ import torch
from
torch.utils.data
import
Dataset
from
torch.utils.data
import
Dataset
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
from
coati.models.chatglm.chatglm_tokenizer
import
ChatGLMTokenizer
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
.utils
import
is_rank_0
,
jload
from
.utils
import
is_rank_0
,
jload
...
@@ -71,6 +71,42 @@ def _preprocess(sources: Sequence[str],
...
@@ -71,6 +71,42 @@ def _preprocess(sources: Sequence[str],
return
sequences_token
[
"input_ids"
],
labels
,
sequences_token
[
"attention_mask"
]
return
sequences_token
[
"input_ids"
],
labels
,
sequences_token
[
"attention_mask"
]
def
_preprocess_chatglm
(
sources
:
Sequence
[
str
],
targets
:
Sequence
[
str
],
tokenizer
:
PreTrainedTokenizer
,
max_length
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Preprocess the data by tokenizing.
None for attention mask, ChatGLM will calculate attention mask according to input ids
"""
labels
=
[]
input_ids
=
[]
for
source
,
target
in
zip
(
sources
,
targets
):
source_id
=
tokenizer
.
encode
(
text
=
source
,
add_special_tokens
=
False
)
target_id
=
tokenizer
.
encode
(
text
=
target
,
add_special_tokens
=
False
)
input_id
=
tokenizer
.
build_inputs_with_special_tokens
(
source_id
,
target_id
)
# truncate
sp_token_list
=
[
tokenizer
.
gmask_token_id
,
tokenizer
.
bos_token_id
]
truncate_length
=
max
(
0
,
len
(
input_id
)
-
max_length
)
input_id
=
input_id
[
truncate_length
:
]
if
truncate_length
==
len
(
source_id
)
+
1
:
input_id
=
sp_token_list
+
input_id
[
1
:
]
elif
truncate_length
>
len
(
source_id
)
+
1
:
input_id
=
sp_token_list
+
input_id
[
2
:
]
context_length
=
input_id
.
index
(
tokenizer
.
bos_token_id
)
mask_position
=
context_length
-
1
label
=
[
IGNORE_INDEX
]
*
context_length
+
input_id
[
mask_position
+
1
:]
pad_len
=
max_length
-
len
(
input_id
)
input_id
=
input_id
+
[
tokenizer
.
pad_token_id
]
*
pad_len
input_ids
.
append
(
input_id
)
labels
.
append
(
label
+
[
IGNORE_INDEX
]
*
pad_len
)
return
torch
.
tensor
(
input_ids
),
torch
.
tensor
(
labels
),
None
class
SFTDataset
(
Dataset
):
class
SFTDataset
(
Dataset
):
"""
"""
Dataset for sft model
Dataset for sft model
...
@@ -94,18 +130,25 @@ class SFTDataset(Dataset):
...
@@ -94,18 +130,25 @@ class SFTDataset(Dataset):
data
[
"completion"
]
+
tokenizer
.
eos_token
data
[
"completion"
]
+
tokenizer
.
eos_token
for
data
in
tqdm
(
dataset
,
disable
=
not
is_rank_0
())
for
data
in
tqdm
(
dataset
,
disable
=
not
is_rank_0
())
]
]
if
isinstance
(
tokenizer
,
ChatGLMTokenizer
):
self
.
input_ids
,
self
.
labels
,
self
.
attention_mask
=
\
self
.
input_ids
,
self
.
labels
,
self
.
attention_mask
=
\
_preprocess
(
sources
,
targets
,
tokenizer
,
max_length
)
_preprocess_chatglm
(
sources
,
targets
,
tokenizer
,
max_length
)
else
:
self
.
input_ids
,
self
.
labels
,
self
.
attention_mask
=
\
_preprocess
(
sources
,
targets
,
tokenizer
,
max_length
)
def
__len__
(
self
):
def
__len__
(
self
):
length
=
self
.
input_ids
.
shape
[
0
]
length
=
self
.
input_ids
.
shape
[
0
]
return
length
return
length
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
return
dict
(
input_ids
=
self
.
input_ids
[
idx
],
if
self
.
attention_mask
is
not
None
:
labels
=
self
.
labels
[
idx
],
return
dict
(
input_ids
=
self
.
input_ids
[
idx
],
attention_mask
=
self
.
attention_mask
[
idx
])
labels
=
self
.
labels
[
idx
],
attention_mask
=
self
.
attention_mask
[
idx
])
else
:
return
dict
(
input_ids
=
self
.
input_ids
[
idx
],
labels
=
self
.
labels
[
idx
])
class
SupervisedDataset
(
Dataset
):
class
SupervisedDataset
(
Dataset
):
...
@@ -137,14 +180,22 @@ class SupervisedDataset(Dataset):
...
@@ -137,14 +180,22 @@ class SupervisedDataset(Dataset):
]
]
logger
.
info
(
"Tokenizing inputs... This may take some time..."
)
logger
.
info
(
"Tokenizing inputs... This may take some time..."
)
self
.
input_ids
,
self
.
labels
,
self
.
attention_mask
=
\
if
isinstance
(
tokenizer
,
ChatGLMTokenizer
):
_preprocess
(
sources
,
targets
,
tokenizer
,
max_length
)
self
.
input_ids
,
self
.
labels
,
self
.
attention_mask
=
\
_preprocess_chatglm
(
sources
,
targets
,
tokenizer
,
max_length
)
else
:
self
.
input_ids
,
self
.
labels
,
self
.
attention_mask
=
\
_preprocess
(
sources
,
targets
,
tokenizer
,
max_length
)
def
__len__
(
self
):
def
__len__
(
self
):
length
=
self
.
input_ids
.
shape
[
0
]
length
=
self
.
input_ids
.
shape
[
0
]
return
length
return
length
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
return
dict
(
input_ids
=
self
.
input_ids
[
idx
],
if
self
.
attention_mask
is
not
None
:
labels
=
self
.
labels
[
idx
],
return
dict
(
input_ids
=
self
.
input_ids
[
idx
],
attention_mask
=
self
.
attention_mask
[
idx
])
labels
=
self
.
labels
[
idx
],
attention_mask
=
self
.
attention_mask
[
idx
])
else
:
return
dict
(
input_ids
=
self
.
input_ids
[
idx
],
labels
=
self
.
labels
[
idx
])
applications/Chat/coati/models/chatglm/__init__.py
0 → 100644
View file @
aaeb520c
from
.chatglm_actor
import
ChatGLMActor
__all__
=
[
'ChatGLMActor'
]
\ No newline at end of file
applications/Chat/coati/models/chatglm/chatglm_actor.py
0 → 100644
View file @
aaeb520c
from
typing
import
Optional
import
torch
from
.configuration_chatglm
import
ChatGLMConfig
from
.modeling_chatglm
import
ChatGLMForConditionalGeneration
from
..base
import
Actor
class
ChatGLMActor
(
Actor
):
"""
ChatGLM Actor model.
Args:
pretrained (str): Pretrained model name or path.
config (ChatGLMConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
do not support lora for now.
"""
def
__init__
(
self
,
pretrained
:
str
=
None
,
config
:
Optional
[
ChatGLMConfig
]
=
None
,
checkpoint
:
bool
=
False
)
->
None
:
if
pretrained
is
not
None
:
model
=
ChatGLMForConditionalGeneration
.
from_pretrained
(
pretrained
)
elif
config
is
not
None
:
model
=
ChatGLMForConditionalGeneration
(
config
)
else
:
model
=
ChatGLMForConditionalGeneration
(
ChatGLMConfig
())
if
checkpoint
:
model
.
gradient_checkpointing_enable
()
super
().
__init__
(
model
,
lora_rank
=
0
,
lora_train_bias
=
'none'
)
applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
0 → 100644
View file @
aaeb520c
"""
This code is copied from https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py
"""
"""Tokenization classes for ChatGLM."""
from
typing
import
List
,
Optional
,
Union
import
os
from
transformers.tokenization_utils
import
PreTrainedTokenizer
from
transformers.utils
import
logging
,
PaddingStrategy
from
transformers.tokenization_utils_base
import
EncodedInput
,
BatchEncoding
from
typing
import
Dict
import
sentencepiece
as
spm
import
numpy
as
np
logger
=
logging
.
get_logger
(
__name__
)
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
=
{
"THUDM/chatglm-6b"
:
2048
,
}
class
TextTokenizer
:
def
__init__
(
self
,
model_path
):
self
.
sp
=
spm
.
SentencePieceProcessor
()
self
.
sp
.
Load
(
model_path
)
self
.
num_tokens
=
self
.
sp
.
vocab_size
()
def
encode
(
self
,
text
):
return
self
.
sp
.
EncodeAsIds
(
text
)
def
decode
(
self
,
ids
:
List
[
int
]):
return
self
.
sp
.
DecodeIds
(
ids
)
def
tokenize
(
self
,
text
):
return
self
.
sp
.
EncodeAsPieces
(
text
)
def
convert_tokens_to_string
(
self
,
tokens
):
return
self
.
sp
.
DecodePieces
(
tokens
)
def
convert_tokens_to_ids
(
self
,
tokens
):
return
[
self
.
sp
.
PieceToId
(
token
)
for
token
in
tokens
]
def
convert_token_to_id
(
self
,
token
):
return
self
.
sp
.
PieceToId
(
token
)
def
convert_id_to_token
(
self
,
idx
):
return
self
.
sp
.
IdToPiece
(
idx
)
def
__len__
(
self
):
return
self
.
num_tokens
class
SPTokenizer
:
def
__init__
(
self
,
vocab_file
,
num_image_tokens
=
20000
,
max_blank_length
=
80
,
byte_fallback
=
True
,
):
assert
vocab_file
is
not
None
self
.
vocab_file
=
vocab_file
self
.
num_image_tokens
=
num_image_tokens
self
.
special_tokens
=
[
"[MASK]"
,
"[gMASK]"
,
"[sMASK]"
,
"<unused_0>"
,
"<sop>"
,
"<eop>"
,
"<ENC>"
,
"<dBLOCK>"
]
self
.
max_blank_length
=
max_blank_length
self
.
byte_fallback
=
byte_fallback
self
.
text_tokenizer
=
TextTokenizer
(
vocab_file
)
def
_get_text_tokenizer
(
self
):
return
self
.
text_tokenizer
@
staticmethod
def
get_blank_token
(
length
:
int
):
assert
length
>=
2
return
f
"<|blank_
{
length
}
|>"
@
staticmethod
def
get_tab_token
():
return
f
"<|tab|>"
@
property
def
num_text_tokens
(
self
):
return
self
.
text_tokenizer
.
num_tokens
@
property
def
num_tokens
(
self
):
return
self
.
num_image_tokens
+
self
.
num_text_tokens
@
staticmethod
def
_encode_whitespaces
(
text
:
str
,
max_len
:
int
=
80
):
text
=
text
.
replace
(
"
\t
"
,
SPTokenizer
.
get_tab_token
())
for
i
in
range
(
max_len
,
1
,
-
1
):
text
=
text
.
replace
(
" "
*
i
,
SPTokenizer
.
get_blank_token
(
i
))
return
text
def
_preprocess
(
self
,
text
:
str
,
linebreak
=
True
,
whitespaces
=
True
):
if
linebreak
:
text
=
text
.
replace
(
"
\n
"
,
"<n>"
)
if
whitespaces
:
text
=
self
.
_encode_whitespaces
(
text
,
max_len
=
self
.
max_blank_length
)
return
text
def
encode
(
self
,
text
:
str
,
linebreak
=
True
,
whitespaces
=
True
,
add_dummy_prefix
=
True
)
->
List
[
int
]:
"""
@param text: Text to encode.
@param linebreak: Whether to encode newline (
\n
) in text.
@param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
@param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
@param add_dummy_prefix: Whether to add dummy blank space in the beginning.
"""
text
=
self
.
_preprocess
(
text
,
linebreak
,
whitespaces
)
if
not
add_dummy_prefix
:
text
=
"<n>"
+
text
tmp
=
self
.
_get_text_tokenizer
().
encode
(
text
)
tokens
=
[
x
+
self
.
num_image_tokens
for
x
in
tmp
]
return
tokens
if
add_dummy_prefix
else
tokens
[
2
:]
def
postprocess
(
self
,
text
):
text
=
text
.
replace
(
"<n>"
,
"
\n
"
)
text
=
text
.
replace
(
SPTokenizer
.
get_tab_token
(),
"
\t
"
)
for
i
in
range
(
2
,
self
.
max_blank_length
+
1
):
text
=
text
.
replace
(
self
.
get_blank_token
(
i
),
" "
*
i
)
return
text
def
decode
(
self
,
text_ids
:
List
[
int
])
->
str
:
ids
=
[
int
(
_id
)
-
self
.
num_image_tokens
for
_id
in
text_ids
]
ids
=
[
_id
for
_id
in
ids
if
_id
>=
0
]
text
=
self
.
_get_text_tokenizer
().
decode
(
ids
)
text
=
self
.
postprocess
(
text
)
return
text
def
decode_tokens
(
self
,
tokens
:
List
[
str
])
->
str
:
text
=
self
.
_get_text_tokenizer
().
convert_tokens_to_string
(
tokens
)
text
=
self
.
postprocess
(
text
)
return
text
def
tokenize
(
self
,
text
:
str
,
linebreak
=
True
,
whitespaces
=
True
,
add_dummy_prefix
=
True
)
->
List
[
str
]:
"""
@param text: Text to encode.
@param linebreak: Whether to encode newline (
\n
) in text.
@param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
@param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
@param add_dummy_prefix: Whether to add dummy blank space in the beginning.
"""
text
=
self
.
_preprocess
(
text
,
linebreak
,
whitespaces
)
if
not
add_dummy_prefix
:
text
=
"<n>"
+
text
tokens
=
self
.
_get_text_tokenizer
().
tokenize
(
text
)
return
tokens
if
add_dummy_prefix
else
tokens
[
2
:]
def
__getitem__
(
self
,
x
:
Union
[
int
,
str
]):
if
isinstance
(
x
,
int
):
if
x
<
self
.
num_image_tokens
:
return
"<image_{}>"
.
format
(
x
)
else
:
return
self
.
text_tokenizer
.
convert_id_to_token
(
x
-
self
.
num_image_tokens
)
elif
isinstance
(
x
,
str
):
if
x
.
startswith
(
"<image_"
)
and
x
.
endswith
(
">"
)
and
x
[
7
:
-
1
].
isdigit
():
return
int
(
x
[
7
:
-
1
])
else
:
return
self
.
text_tokenizer
.
convert_token_to_id
(
x
)
+
self
.
num_image_tokens
else
:
raise
ValueError
(
"The key should be str or int."
)
class
ChatGLMTokenizer
(
PreTrainedTokenizer
):
"""
Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding.
Args:
vocab_file (`str`):
Path to the vocabulary file.
"""
vocab_files_names
=
{
"vocab_file"
:
"ice_text.model"
}
max_model_input_sizes
=
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names
=
[
"input_ids"
,
"attention_mask"
,
"position_ids"
]
def
__init__
(
self
,
vocab_file
,
do_lower_case
=
False
,
remove_space
=
False
,
bos_token
=
'<sop>'
,
eos_token
=
'<eop>'
,
end_token
=
'</s>'
,
mask_token
=
'[MASK]'
,
gmask_token
=
'[gMASK]'
,
padding_side
=
"left"
,
pad_token
=
"<pad>"
,
unk_token
=
"<unk>"
,
num_image_tokens
=
20000
,
**
kwargs
)
->
None
:
super
().
__init__
(
do_lower_case
=
do_lower_case
,
remove_space
=
remove_space
,
padding_side
=
padding_side
,
bos_token
=
bos_token
,
eos_token
=
eos_token
,
end_token
=
end_token
,
mask_token
=
mask_token
,
gmask_token
=
gmask_token
,
pad_token
=
pad_token
,
unk_token
=
unk_token
,
num_image_tokens
=
num_image_tokens
,
**
kwargs
)
self
.
do_lower_case
=
do_lower_case
self
.
remove_space
=
remove_space
self
.
vocab_file
=
vocab_file
self
.
bos_token
=
bos_token
self
.
eos_token
=
eos_token
self
.
end_token
=
end_token
self
.
mask_token
=
mask_token
self
.
gmask_token
=
gmask_token
self
.
sp_tokenizer
=
SPTokenizer
(
vocab_file
,
num_image_tokens
=
num_image_tokens
)
""" Initialisation """
@
property
def
gmask_token_id
(
self
)
->
Optional
[
int
]:
if
self
.
gmask_token
is
None
:
return
None
return
self
.
convert_tokens_to_ids
(
self
.
gmask_token
)
@
property
def
end_token_id
(
self
)
->
Optional
[
int
]:
"""
`Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been
set.
"""
if
self
.
end_token
is
None
:
return
None
return
self
.
convert_tokens_to_ids
(
self
.
end_token
)
@
property
def
vocab_size
(
self
):
""" Returns vocab size """
return
self
.
sp_tokenizer
.
num_tokens
def
get_vocab
(
self
):
""" Returns vocab as a dict """
vocab
=
{
self
.
_convert_id_to_token
(
i
):
i
for
i
in
range
(
self
.
vocab_size
)}
vocab
.
update
(
self
.
added_tokens_encoder
)
return
vocab
def
preprocess_text
(
self
,
inputs
):
if
self
.
remove_space
:
outputs
=
" "
.
join
(
inputs
.
strip
().
split
())
else
:
outputs
=
inputs
if
self
.
do_lower_case
:
outputs
=
outputs
.
lower
()
return
outputs
def
_tokenize
(
self
,
text
,
**
kwargs
):
""" Returns a tokenized string. """
text
=
self
.
preprocess_text
(
text
)
seq
=
self
.
sp_tokenizer
.
tokenize
(
text
)
return
seq
def
convert_tokens_to_string
(
self
,
tokens
:
List
[
str
])
->
str
:
return
self
.
sp_tokenizer
.
decode_tokens
(
tokens
)
def
_decode
(
self
,
token_ids
:
Union
[
int
,
List
[
int
]],
**
kwargs
)
->
str
:
if
isinstance
(
token_ids
,
int
):
token_ids
=
[
token_ids
]
if
len
(
token_ids
)
==
0
:
return
""
if
self
.
pad_token_id
in
token_ids
:
# remove pad
token_ids
=
list
(
filter
((
self
.
pad_token_id
).
__ne__
,
token_ids
))
return
super
().
_decode
(
token_ids
,
**
kwargs
)
def
_convert_token_to_id
(
self
,
token
):
""" Converts a token (str) in an id using the vocab. """
return
self
.
sp_tokenizer
[
token
]
def
_convert_id_to_token
(
self
,
index
):
"""Converts an index (integer) in a token (str) using the vocab."""
return
self
.
sp_tokenizer
[
index
]
def
save_vocabulary
(
self
,
save_directory
,
filename_prefix
=
None
):
"""
Save the vocabulary and special tokens file to a directory.
Args:
save_directory (`str`):
The directory in which to save the vocabulary.
filename_prefix (`str`, *optional*):
An optional prefix to add to the named of the saved files.
Returns:
`Tuple(str)`: Paths to the files saved.
"""
if
os
.
path
.
isdir
(
save_directory
):
vocab_file
=
os
.
path
.
join
(
save_directory
,
self
.
vocab_files_names
[
"vocab_file"
]
)
else
:
vocab_file
=
save_directory
with
open
(
self
.
vocab_file
,
'rb'
)
as
fin
:
proto_str
=
fin
.
read
()
with
open
(
vocab_file
,
"wb"
)
as
writer
:
writer
.
write
(
proto_str
)
return
(
vocab_file
,)
def
build_inputs_with_special_tokens
(
self
,
token_ids_0
:
List
[
int
],
token_ids_1
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
int
]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A BERT sequence has the following format:
- single sequence: `[CLS] X [SEP]`
- pair of sequences: `[CLS] A [SEP] B [SEP]`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
gmask_id
=
self
.
sp_tokenizer
[
self
.
gmask_token
]
eos_id
=
self
.
sp_tokenizer
[
self
.
eos_token
]
token_ids_0
=
token_ids_0
+
[
gmask_id
,
self
.
sp_tokenizer
[
self
.
bos_token
]]
if
token_ids_1
is
not
None
:
token_ids_0
=
token_ids_0
+
token_ids_1
return
token_ids_0
def
_pad
(
self
,
encoded_inputs
:
Union
[
Dict
[
str
,
EncodedInput
],
BatchEncoding
],
max_length
:
Optional
[
int
]
=
None
,
padding_strategy
:
PaddingStrategy
=
PaddingStrategy
.
DO_NOT_PAD
,
pad_to_multiple_of
:
Optional
[
int
]
=
None
,
return_attention_mask
:
Optional
[
bool
]
=
None
,
)
->
dict
:
"""
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
Args:
encoded_inputs:
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
max_length: maximum length of the returned list and optionally padding length (see below).
Will truncate by taking into account the special tokens.
padding_strategy: PaddingStrategy to use for padding.
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The tokenizer padding sides are defined in self.padding_side:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
`>= 7.5` (Volta).
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
# Load from model defaults
bos_token_id
=
self
.
sp_tokenizer
[
self
.
bos_token
]
mask_token_id
=
self
.
sp_tokenizer
[
self
.
mask_token
]
gmask_token_id
=
self
.
sp_tokenizer
[
self
.
gmask_token
]
assert
self
.
padding_side
==
"left"
required_input
=
encoded_inputs
[
self
.
model_input_names
[
0
]]
seq_length
=
len
(
required_input
)
if
padding_strategy
==
PaddingStrategy
.
LONGEST
:
max_length
=
len
(
required_input
)
if
max_length
is
not
None
and
pad_to_multiple_of
is
not
None
and
(
max_length
%
pad_to_multiple_of
!=
0
):
max_length
=
((
max_length
//
pad_to_multiple_of
)
+
1
)
*
pad_to_multiple_of
needs_to_be_padded
=
padding_strategy
!=
PaddingStrategy
.
DO_NOT_PAD
and
len
(
required_input
)
!=
max_length
# Initialize attention mask if not present.
if
max_length
is
not
None
:
if
"attention_mask"
not
in
encoded_inputs
:
if
bos_token_id
in
required_input
:
context_length
=
required_input
.
index
(
bos_token_id
)
else
:
context_length
=
seq_length
attention_mask
=
np
.
ones
((
1
,
seq_length
,
seq_length
))
attention_mask
=
np
.
tril
(
attention_mask
)
attention_mask
[:,
:,
:
context_length
]
=
1
attention_mask
=
np
.
bool_
(
attention_mask
<
0.5
)
encoded_inputs
[
"attention_mask"
]
=
attention_mask
if
"position_ids"
not
in
encoded_inputs
:
if
bos_token_id
in
required_input
:
context_length
=
required_input
.
index
(
bos_token_id
)
else
:
context_length
=
seq_length
position_ids
=
np
.
arange
(
seq_length
,
dtype
=
np
.
int64
)
mask_token
=
mask_token_id
if
mask_token_id
in
required_input
else
gmask_token_id
if
mask_token
in
required_input
:
mask_position
=
required_input
.
index
(
mask_token
)
position_ids
[
context_length
:]
=
mask_position
block_position_ids
=
np
.
concatenate
(
[
np
.
zeros
(
context_length
,
dtype
=
np
.
int64
),
np
.
arange
(
1
,
seq_length
-
context_length
+
1
,
dtype
=
np
.
int64
)])
encoded_inputs
[
"position_ids"
]
=
np
.
stack
([
position_ids
,
block_position_ids
],
axis
=
0
)
if
needs_to_be_padded
:
difference
=
max_length
-
len
(
required_input
)
if
"attention_mask"
in
encoded_inputs
:
encoded_inputs
[
"attention_mask"
]
=
np
.
pad
(
encoded_inputs
[
"attention_mask"
],
pad_width
=
[(
0
,
0
),
(
difference
,
0
),
(
difference
,
0
)],
mode
=
'constant'
,
constant_values
=
True
)
if
"token_type_ids"
in
encoded_inputs
:
encoded_inputs
[
"token_type_ids"
]
=
[
self
.
pad_token_type_id
]
*
difference
+
encoded_inputs
[
"token_type_ids"
]
if
"special_tokens_mask"
in
encoded_inputs
:
encoded_inputs
[
"special_tokens_mask"
]
=
[
1
]
*
difference
+
encoded_inputs
[
"special_tokens_mask"
]
if
"position_ids"
in
encoded_inputs
:
encoded_inputs
[
"position_ids"
]
=
np
.
pad
(
encoded_inputs
[
"position_ids"
],
pad_width
=
[(
0
,
0
),
(
difference
,
0
)])
encoded_inputs
[
self
.
model_input_names
[
0
]]
=
[
self
.
pad_token_id
]
*
difference
+
required_input
return
encoded_inputs
\ No newline at end of file
applications/Chat/coati/models/chatglm/configuration_chatglm.py
0 → 100644
View file @
aaeb520c
"""
This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/configuration_chatglm.py
"""
""" ChatGLM model configuration """
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
class
ChatGLMConfig
(
PretrainedConfig
):
r
"""
This is the configuration class to store the configuration of a [`~ChatGLMModel`].
It is used to instantiate an ChatGLM model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used
to control the model outputs. Read the documentation from [`PretrainedConfig`]
for more information.
Args:
vocab_size (`int`, *optional*, defaults to 150528):
Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`~ChatGLMModel`] or
[`~TFChatGLMModel`].
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 28):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
inner_hidden_size (`int`, *optional*, defaults to 16384):
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
max_sequence_length (`int`, *optional*, defaults to 512):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
layernorm_epsilon (`float`, *optional*, defaults to 1e-5):
The epsilon used by the layer normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether the model should return the last key/values attentions (not used by all models).
Example:
```python
>>> from configuration_chatglm import ChatGLMConfig
>>> from modeling_chatglm import ChatGLMModel
>>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration
>>> configuration = ChatGLMConfig()
>>> # Initializing a model from the THUDM/ChatGLM-6B style configuration
>>> model = ChatGLMModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type
=
"chatglm"
def
__init__
(
self
,
vocab_size
=
130528
,
hidden_size
=
4096
,
num_layers
=
28
,
num_attention_heads
=
32
,
layernorm_epsilon
=
1e-5
,
use_cache
=
True
,
bos_token_id
=
130004
,
eos_token_id
=
130005
,
mask_token_id
=
130000
,
gmask_token_id
=
130001
,
pad_token_id
=
3
,
max_sequence_length
=
2048
,
inner_hidden_size
=
16384
,
position_encoding_2d
=
True
,
quantization_bit
=
0
,
pre_seq_len
=
None
,
prefix_projection
=
False
,
**
kwargs
):
self
.
num_layers
=
num_layers
self
.
vocab_size
=
vocab_size
self
.
hidden_size
=
hidden_size
self
.
num_attention_heads
=
num_attention_heads
self
.
max_sequence_length
=
max_sequence_length
self
.
layernorm_epsilon
=
layernorm_epsilon
self
.
inner_hidden_size
=
inner_hidden_size
self
.
use_cache
=
use_cache
self
.
bos_token_id
=
bos_token_id
self
.
eos_token_id
=
eos_token_id
self
.
pad_token_id
=
pad_token_id
self
.
mask_token_id
=
mask_token_id
self
.
gmask_token_id
=
gmask_token_id
self
.
position_encoding_2d
=
position_encoding_2d
self
.
quantization_bit
=
quantization_bit
self
.
pre_seq_len
=
pre_seq_len
self
.
prefix_projection
=
prefix_projection
super
().
__init__
(
pad_token_id
=
pad_token_id
,
bos_token_id
=
bos_token_id
,
eos_token_id
=
eos_token_id
,
**
kwargs
)
\ No newline at end of file
applications/Chat/coati/models/chatglm/modeling_chatglm.py
0 → 100644
View file @
aaeb520c
"""
This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/modeling_chatglm.py
"""
""" PyTorch ChatGLM model. """
import
math
import
copy
import
os
import
warnings
import
re
import
sys
import
torch
import
torch.utils.checkpoint
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch.nn
import
CrossEntropyLoss
,
LayerNorm
from
torch.nn.utils
import
skip_init
from
typing
import
Optional
,
Tuple
,
Union
,
List
,
Callable
,
Dict
,
Any
from
transformers.utils
import
(
add_code_sample_docstrings
,
add_start_docstrings
,
add_start_docstrings_to_model_forward
,
)
from
transformers.modeling_outputs
import
(
BaseModelOutputWithPast
,
CausalLMOutputWithPast
,
BaseModelOutputWithPastAndCrossAttentions
,
)
from
transformers.modeling_utils
import
PreTrainedModel
from
transformers.utils
import
logging
from
transformers.generation.logits_process
import
LogitsProcessor
from
transformers.generation.utils
import
LogitsProcessorList
,
StoppingCriteriaList
,
GenerationConfig
,
ModelOutput
from
.configuration_chatglm
import
ChatGLMConfig
# flags required to enable jit fusion kernels
if
sys
.
platform
!=
'darwin'
:
torch
.
_C
.
_jit_set_profiling_mode
(
False
)
torch
.
_C
.
_jit_set_profiling_executor
(
False
)
torch
.
_C
.
_jit_override_can_fuse_on_cpu
(
True
)
torch
.
_C
.
_jit_override_can_fuse_on_gpu
(
True
)
logger
=
logging
.
get_logger
(
__name__
)
_CHECKPOINT_FOR_DOC
=
"THUDM/ChatGLM-6B"
_CONFIG_FOR_DOC
=
"ChatGLM6BConfig"
CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST
=
[
"THUDM/chatglm-6b"
,
# See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm
]
class
InvalidScoreLogitsProcessor
(
LogitsProcessor
):
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
if
torch
.
isnan
(
scores
).
any
()
or
torch
.
isinf
(
scores
).
any
():
scores
.
zero_
()
scores
[...,
5
]
=
5e4
return
scores
def
load_tf_weights_in_chatglm_6b
(
model
,
config
,
tf_checkpoint_path
):
"""Load tf checkpoints in a pytorch model."""
try
:
import
re
import
numpy
as
np
import
tensorflow
as
tf
except
ImportError
:
logger
.
error
(
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
tf_path
=
os
.
path
.
abspath
(
tf_checkpoint_path
)
logger
.
info
(
f
"Converting TensorFlow checkpoint from
{
tf_path
}
"
)
# Load weights from TF model
init_vars
=
tf
.
train
.
list_variables
(
tf_path
)
names
=
[]
arrays
=
[]
for
name
,
shape
in
init_vars
:
logger
.
info
(
f
"Loading TF weight
{
name
}
with shape
{
shape
}
"
)
array
=
tf
.
train
.
load_variable
(
tf_path
,
name
)
names
.
append
(
name
)
arrays
.
append
(
array
)
for
name
,
array
in
zip
(
names
,
arrays
):
name
=
name
.
split
(
"/"
)
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if
any
(
n
in
[
"adam_v"
,
"adam_m"
,
"AdamWeightDecayOptimizer"
,
"AdamWeightDecayOptimizer_1"
,
"global_step"
]
for
n
in
name
):
logger
.
info
(
f
"Skipping
{
'/'
.
join
(
name
)
}
"
)
continue
pointer
=
model
for
m_name
in
name
:
if
re
.
fullmatch
(
r
"[A-Za-z]+_\d+"
,
m_name
):
scope_names
=
re
.
split
(
r
"_(\d+)"
,
m_name
)
else
:
scope_names
=
[
m_name
]
if
scope_names
[
0
]
==
"kernel"
or
scope_names
[
0
]
==
"gamma"
:
pointer
=
getattr
(
pointer
,
"weight"
)
elif
scope_names
[
0
]
==
"output_bias"
or
scope_names
[
0
]
==
"beta"
:
pointer
=
getattr
(
pointer
,
"bias"
)
elif
scope_names
[
0
]
==
"output_weights"
:
pointer
=
getattr
(
pointer
,
"weight"
)
elif
scope_names
[
0
]
==
"squad"
:
pointer
=
getattr
(
pointer
,
"classifier"
)
else
:
try
:
pointer
=
getattr
(
pointer
,
scope_names
[
0
])
except
AttributeError
:
logger
.
info
(
f
"Skipping
{
'/'
.
join
(
name
)
}
"
)
continue
if
len
(
scope_names
)
>=
2
:
num
=
int
(
scope_names
[
1
])
pointer
=
pointer
[
num
]
if
m_name
[
-
11
:]
==
"_embeddings"
:
pointer
=
getattr
(
pointer
,
"weight"
)
elif
m_name
==
"kernel"
:
array
=
np
.
transpose
(
array
)
try
:
assert
(
pointer
.
shape
==
array
.
shape
),
f
"Pointer shape
{
pointer
.
shape
}
and array shape
{
array
.
shape
}
mismatched"
except
AssertionError
as
e
:
e
.
args
+=
(
pointer
.
shape
,
array
.
shape
)
raise
logger
.
info
(
f
"Initialize PyTorch weight
{
name
}
"
)
pointer
.
data
=
torch
.
from_numpy
(
array
)
return
model
class
PrefixEncoder
(
torch
.
nn
.
Module
):
"""
The torch.nn model to encode the prefix
Input shape: (batch-size, prefix-length)
Output shape: (batch-size, prefix-length, 2*layers*hidden)
"""
def
__init__
(
self
,
config
):
super
().
__init__
()
self
.
prefix_projection
=
config
.
prefix_projection
if
self
.
prefix_projection
:
# Use a two-layer MLP to encode the prefix
self
.
embedding
=
torch
.
nn
.
Embedding
(
config
.
pre_seq_len
,
config
.
hidden_size
)
self
.
trans
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
),
torch
.
nn
.
Tanh
(),
torch
.
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_layers
*
config
.
hidden_size
*
2
)
)
else
:
self
.
embedding
=
torch
.
nn
.
Embedding
(
config
.
pre_seq_len
,
config
.
num_layers
*
config
.
hidden_size
*
2
)
def
forward
(
self
,
prefix
:
torch
.
Tensor
):
if
self
.
prefix_projection
:
prefix_tokens
=
self
.
embedding
(
prefix
)
past_key_values
=
self
.
trans
(
prefix_tokens
)
else
:
past_key_values
=
self
.
embedding
(
prefix
)
return
past_key_values
@
torch
.
jit
.
script
def
gelu_impl
(
x
):
"""OpenAI's gelu implementation."""
return
0.5
*
x
*
(
1.0
+
torch
.
tanh
(
0.7978845608028654
*
x
*
(
1.0
+
0.044715
*
x
*
x
)))
def
gelu
(
x
):
return
gelu_impl
(
x
)
class
RotaryEmbedding
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
dim
,
base
=
10000
,
precision
=
torch
.
half
,
learnable
=
False
):
super
().
__init__
()
inv_freq
=
1.
/
(
base
**
(
torch
.
arange
(
0
,
dim
,
2
).
float
()
/
dim
))
inv_freq
=
inv_freq
.
half
()
self
.
learnable
=
learnable
if
learnable
:
self
.
inv_freq
=
torch
.
nn
.
Parameter
(
inv_freq
)
self
.
max_seq_len_cached
=
None
else
:
self
.
register_buffer
(
'inv_freq'
,
inv_freq
)
self
.
max_seq_len_cached
=
None
self
.
cos_cached
=
None
self
.
sin_cached
=
None
self
.
precision
=
precision
def
_load_from_state_dict
(
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
):
pass
def
forward
(
self
,
x
,
seq_dim
=
1
,
seq_len
=
None
):
if
seq_len
is
None
:
seq_len
=
x
.
shape
[
seq_dim
]
if
self
.
max_seq_len_cached
is
None
or
(
seq_len
>
self
.
max_seq_len_cached
):
self
.
max_seq_len_cached
=
None
if
self
.
learnable
else
seq_len
t
=
torch
.
arange
(
seq_len
,
device
=
x
.
device
,
dtype
=
self
.
inv_freq
.
dtype
)
freqs
=
torch
.
einsum
(
'i,j->ij'
,
t
,
self
.
inv_freq
)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb
=
torch
.
cat
((
freqs
,
freqs
),
dim
=-
1
).
to
(
x
.
device
)
if
self
.
precision
==
torch
.
bfloat16
:
emb
=
emb
.
float
()
# [sx, 1 (b * np), hn]
cos_cached
=
emb
.
cos
()[:,
None
,
:]
sin_cached
=
emb
.
sin
()[:,
None
,
:]
if
self
.
precision
==
torch
.
bfloat16
:
cos_cached
=
cos_cached
.
bfloat16
()
sin_cached
=
sin_cached
.
bfloat16
()
if
self
.
learnable
:
return
cos_cached
,
sin_cached
self
.
cos_cached
,
self
.
sin_cached
=
cos_cached
,
sin_cached
return
self
.
cos_cached
[:
seq_len
,
...],
self
.
sin_cached
[:
seq_len
,
...]
def
_apply
(
self
,
fn
):
if
self
.
cos_cached
is
not
None
:
self
.
cos_cached
=
fn
(
self
.
cos_cached
)
if
self
.
sin_cached
is
not
None
:
self
.
sin_cached
=
fn
(
self
.
sin_cached
)
return
super
().
_apply
(
fn
)
def
rotate_half
(
x
):
x1
,
x2
=
x
[...,
:
x
.
shape
[
-
1
]
//
2
],
x
[...,
x
.
shape
[
-
1
]
//
2
:]
return
torch
.
cat
((
-
x2
,
x1
),
dim
=
x1
.
ndim
-
1
)
# dim=-1 triggers a bug in earlier torch versions
@
torch
.
jit
.
script
def
apply_rotary_pos_emb_index
(
q
,
k
,
cos
,
sin
,
position_id
):
# position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
cos
,
sin
=
F
.
embedding
(
position_id
,
cos
.
squeeze
(
1
)).
unsqueeze
(
2
),
\
F
.
embedding
(
position_id
,
sin
.
squeeze
(
1
)).
unsqueeze
(
2
)
q
,
k
=
(
q
*
cos
)
+
(
rotate_half
(
q
)
*
sin
),
(
k
*
cos
)
+
(
rotate_half
(
k
)
*
sin
)
return
q
,
k
def
attention_fn
(
self
,
query_layer
,
key_layer
,
value_layer
,
attention_mask
,
hidden_size_per_partition
,
layer_id
,
layer_past
=
None
,
scaling_attention_score
=
True
,
use_cache
=
False
,
):
if
layer_past
is
not
None
:
past_key
,
past_value
=
layer_past
[
0
],
layer_past
[
1
]
key_layer
=
torch
.
cat
((
past_key
,
key_layer
),
dim
=
0
)
value_layer
=
torch
.
cat
((
past_value
,
value_layer
),
dim
=
0
)
# seqlen, batch, num_attention_heads, hidden_size_per_attention_head
seq_len
,
b
,
nh
,
hidden_size
=
key_layer
.
shape
if
use_cache
:
present
=
(
key_layer
,
value_layer
)
else
:
present
=
None
query_key_layer_scaling_coeff
=
float
(
layer_id
+
1
)
if
scaling_attention_score
:
query_layer
=
query_layer
/
(
math
.
sqrt
(
hidden_size
)
*
query_key_layer_scaling_coeff
)
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size
=
(
query_layer
.
size
(
1
),
query_layer
.
size
(
2
),
query_layer
.
size
(
0
),
key_layer
.
size
(
0
))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer
=
query_layer
.
view
(
output_size
[
2
],
output_size
[
0
]
*
output_size
[
1
],
-
1
)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer
=
key_layer
.
view
(
output_size
[
3
],
output_size
[
0
]
*
output_size
[
1
],
-
1
)
matmul_result
=
torch
.
zeros
(
1
,
1
,
1
,
dtype
=
query_layer
.
dtype
,
device
=
query_layer
.
device
,
)
matmul_result
=
torch
.
baddbmm
(
matmul_result
,
query_layer
.
transpose
(
0
,
1
),
# [b * np, sq, hn]
key_layer
.
transpose
(
0
,
1
).
transpose
(
1
,
2
),
# [b * np, hn, sk]
beta
=
0.0
,
alpha
=
1.0
,
)
# change view to [b, np, sq, sk]
attention_scores
=
matmul_result
.
view
(
*
output_size
)
if
self
.
scale_mask_softmax
:
self
.
scale_mask_softmax
.
scale
=
query_key_layer_scaling_coeff
attention_probs
=
self
.
scale_mask_softmax
(
attention_scores
,
attention_mask
.
contiguous
())
else
:
if
not
(
attention_mask
==
0
).
all
():
# if auto-regressive, skip
attention_scores
.
masked_fill_
(
attention_mask
,
-
10000.0
)
dtype
=
attention_scores
.
dtype
attention_scores
=
attention_scores
.
float
()
attention_scores
=
attention_scores
*
query_key_layer_scaling_coeff
attention_probs
=
F
.
softmax
(
attention_scores
,
dim
=-
1
)
attention_probs
=
attention_probs
.
type
(
dtype
)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size
=
(
value_layer
.
size
(
1
),
value_layer
.
size
(
2
),
query_layer
.
size
(
0
),
value_layer
.
size
(
3
))
# change view [sk, b * np, hn]
value_layer
=
value_layer
.
view
(
value_layer
.
size
(
0
),
output_size
[
0
]
*
output_size
[
1
],
-
1
)
# change view [b * np, sq, sk]
attention_probs
=
attention_probs
.
view
(
output_size
[
0
]
*
output_size
[
1
],
output_size
[
2
],
-
1
)
# matmul: [b * np, sq, hn]
context_layer
=
torch
.
bmm
(
attention_probs
,
value_layer
.
transpose
(
0
,
1
))
# change view [b, np, sq, hn]
context_layer
=
context_layer
.
view
(
*
output_size
)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer
=
context_layer
.
permute
(
2
,
0
,
1
,
3
).
contiguous
()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape
=
context_layer
.
size
()[:
-
2
]
+
(
hidden_size_per_partition
,)
context_layer
=
context_layer
.
view
(
*
new_context_layer_shape
)
outputs
=
(
context_layer
,
present
,
attention_probs
)
return
outputs
def
default_init
(
cls
,
*
args
,
**
kwargs
):
return
cls
(
*
args
,
**
kwargs
)
class
SelfAttention
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
,
num_attention_heads
,
layer_id
,
hidden_size_per_attention_head
=
None
,
bias
=
True
,
params_dtype
=
torch
.
float
,
position_encoding_2d
=
True
,
empty_init
=
True
):
if
empty_init
:
init_method
=
skip_init
else
:
init_method
=
default_init
super
(
SelfAttention
,
self
).
__init__
()
self
.
layer_id
=
layer_id
self
.
hidden_size
=
hidden_size
self
.
hidden_size_per_partition
=
hidden_size
self
.
num_attention_heads
=
num_attention_heads
self
.
num_attention_heads_per_partition
=
num_attention_heads
self
.
position_encoding_2d
=
position_encoding_2d
self
.
rotary_emb
=
RotaryEmbedding
(
self
.
hidden_size
//
(
self
.
num_attention_heads
*
2
)
if
position_encoding_2d
else
self
.
hidden_size
//
self
.
num_attention_heads
,
base
=
10000
,
precision
=
torch
.
half
,
learnable
=
False
,
)
self
.
scale_mask_softmax
=
None
if
hidden_size_per_attention_head
is
None
:
self
.
hidden_size_per_attention_head
=
hidden_size
//
num_attention_heads
else
:
self
.
hidden_size_per_attention_head
=
hidden_size_per_attention_head
self
.
inner_hidden_size
=
num_attention_heads
*
self
.
hidden_size_per_attention_head
# Strided linear layer.
self
.
query_key_value
=
init_method
(
torch
.
nn
.
Linear
,
hidden_size
,
3
*
self
.
inner_hidden_size
,
bias
=
bias
,
dtype
=
params_dtype
,
)
self
.
dense
=
init_method
(
torch
.
nn
.
Linear
,
self
.
inner_hidden_size
,
hidden_size
,
bias
=
bias
,
dtype
=
params_dtype
,
)
@
staticmethod
def
attention_mask_func
(
attention_scores
,
attention_mask
):
attention_scores
.
masked_fill_
(
attention_mask
,
-
10000.0
)
return
attention_scores
def
split_tensor_along_last_dim
(
self
,
tensor
,
num_partitions
,
contiguous_split_chunks
=
False
):
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
"""
# Get the size and dimension.
last_dim
=
tensor
.
dim
()
-
1
last_dim_size
=
tensor
.
size
()[
last_dim
]
//
num_partitions
# Split.
tensor_list
=
torch
.
split
(
tensor
,
last_dim_size
,
dim
=
last_dim
)
# Note: torch.split does not create contiguous tensors by default.
if
contiguous_split_chunks
:
return
tuple
(
chunk
.
contiguous
()
for
chunk
in
tensor_list
)
return
tensor_list
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
position_ids
,
attention_mask
:
torch
.
Tensor
,
layer_id
,
layer_past
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
,
use_cache
:
bool
=
False
,
output_attentions
:
bool
=
False
,
):
"""
hidden_states: [seq_len, batch, hidden_size]
attention_mask: [(1, 1), seq_len, seq_len]
"""
# [seq_len, batch, 3 * hidden_size]
mixed_raw_layer
=
self
.
query_key_value
(
hidden_states
)
# [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head]
new_tensor_shape
=
mixed_raw_layer
.
size
()[:
-
1
]
+
(
self
.
num_attention_heads_per_partition
,
3
*
self
.
hidden_size_per_attention_head
,
)
mixed_raw_layer
=
mixed_raw_layer
.
view
(
*
new_tensor_shape
)
# [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
(
query_layer
,
key_layer
,
value_layer
)
=
self
.
split_tensor_along_last_dim
(
mixed_raw_layer
,
3
)
if
self
.
position_encoding_2d
:
q1
,
q2
=
query_layer
.
chunk
(
2
,
dim
=
(
query_layer
.
ndim
-
1
))
k1
,
k2
=
key_layer
.
chunk
(
2
,
dim
=
(
key_layer
.
ndim
-
1
))
cos
,
sin
=
self
.
rotary_emb
(
q1
,
seq_len
=
position_ids
.
max
()
+
1
)
position_ids
,
block_position_ids
=
position_ids
[:,
0
,
:].
transpose
(
0
,
1
).
contiguous
(),
\
position_ids
[:,
1
,
:].
transpose
(
0
,
1
).
contiguous
()
q1
,
k1
=
apply_rotary_pos_emb_index
(
q1
,
k1
,
cos
,
sin
,
position_ids
)
q2
,
k2
=
apply_rotary_pos_emb_index
(
q2
,
k2
,
cos
,
sin
,
block_position_ids
)
query_layer
=
torch
.
concat
([
q1
,
q2
],
dim
=
(
q1
.
ndim
-
1
))
key_layer
=
torch
.
concat
([
k1
,
k2
],
dim
=
(
k1
.
ndim
-
1
))
else
:
position_ids
=
position_ids
.
transpose
(
0
,
1
)
cos
,
sin
=
self
.
rotary_emb
(
value_layer
,
seq_len
=
position_ids
.
max
()
+
1
)
# [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
query_layer
,
key_layer
=
apply_rotary_pos_emb_index
(
query_layer
,
key_layer
,
cos
,
sin
,
position_ids
)
# [seq_len, batch, hidden_size]
context_layer
,
present
,
attention_probs
=
attention_fn
(
self
=
self
,
query_layer
=
query_layer
,
key_layer
=
key_layer
,
value_layer
=
value_layer
,
attention_mask
=
attention_mask
,
hidden_size_per_partition
=
self
.
hidden_size_per_partition
,
layer_id
=
layer_id
,
layer_past
=
layer_past
,
use_cache
=
use_cache
)
output
=
self
.
dense
(
context_layer
)
outputs
=
(
output
,
present
)
if
output_attentions
:
outputs
+=
(
attention_probs
,)
return
outputs
# output, present, attention_probs
class
GEGLU
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
activation_fn
=
F
.
gelu
def
forward
(
self
,
x
):
# dim=-1 breaks in jit for pt<1.10
x1
,
x2
=
x
.
chunk
(
2
,
dim
=
(
x
.
ndim
-
1
))
return
x1
*
self
.
activation_fn
(
x2
)
class
GLU
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
,
inner_hidden_size
=
None
,
layer_id
=
None
,
bias
=
True
,
activation_func
=
gelu
,
params_dtype
=
torch
.
float
,
empty_init
=
True
):
super
(
GLU
,
self
).
__init__
()
if
empty_init
:
init_method
=
skip_init
else
:
init_method
=
default_init
self
.
layer_id
=
layer_id
self
.
activation_func
=
activation_func
# Project to 4h.
self
.
hidden_size
=
hidden_size
if
inner_hidden_size
is
None
:
inner_hidden_size
=
4
*
hidden_size
self
.
inner_hidden_size
=
inner_hidden_size
self
.
dense_h_to_4h
=
init_method
(
torch
.
nn
.
Linear
,
self
.
hidden_size
,
self
.
inner_hidden_size
,
bias
=
bias
,
dtype
=
params_dtype
,
)
# Project back to h.
self
.
dense_4h_to_h
=
init_method
(
torch
.
nn
.
Linear
,
self
.
inner_hidden_size
,
self
.
hidden_size
,
bias
=
bias
,
dtype
=
params_dtype
,
)
def
forward
(
self
,
hidden_states
):
"""
hidden_states: [seq_len, batch, hidden_size]
"""
# [seq_len, batch, inner_hidden_size]
intermediate_parallel
=
self
.
dense_h_to_4h
(
hidden_states
)
intermediate_parallel
=
self
.
activation_func
(
intermediate_parallel
)
output
=
self
.
dense_4h_to_h
(
intermediate_parallel
)
return
output
class
GLMBlock
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
,
num_attention_heads
,
layernorm_epsilon
,
layer_id
,
inner_hidden_size
=
None
,
hidden_size_per_attention_head
=
None
,
layernorm
=
LayerNorm
,
use_bias
=
True
,
params_dtype
=
torch
.
float
,
num_layers
=
28
,
position_encoding_2d
=
True
,
empty_init
=
True
):
super
(
GLMBlock
,
self
).
__init__
()
# Set output layer initialization if not provided.
self
.
layer_id
=
layer_id
# Layernorm on the input data.
self
.
input_layernorm
=
layernorm
(
hidden_size
,
eps
=
layernorm_epsilon
)
self
.
position_encoding_2d
=
position_encoding_2d
# Self attention.
self
.
attention
=
SelfAttention
(
hidden_size
,
num_attention_heads
,
layer_id
,
hidden_size_per_attention_head
=
hidden_size_per_attention_head
,
bias
=
use_bias
,
params_dtype
=
params_dtype
,
position_encoding_2d
=
self
.
position_encoding_2d
,
empty_init
=
empty_init
)
# Layernorm on the input data.
self
.
post_attention_layernorm
=
layernorm
(
hidden_size
,
eps
=
layernorm_epsilon
)
self
.
num_layers
=
num_layers
# GLU
self
.
mlp
=
GLU
(
hidden_size
,
inner_hidden_size
=
inner_hidden_size
,
bias
=
use_bias
,
layer_id
=
layer_id
,
params_dtype
=
params_dtype
,
empty_init
=
empty_init
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
position_ids
,
attention_mask
:
torch
.
Tensor
,
layer_id
,
layer_past
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
,
use_cache
:
bool
=
False
,
output_attentions
:
bool
=
False
,
):
"""
hidden_states: [seq_len, batch, hidden_size]
attention_mask: [(1, 1), seq_len, seq_len]
"""
# Layer norm at the begining of the transformer layer.
# [seq_len, batch, hidden_size]
attention_input
=
self
.
input_layernorm
(
hidden_states
)
# Self attention.
attention_outputs
=
self
.
attention
(
attention_input
,
position_ids
,
attention_mask
=
attention_mask
,
layer_id
=
layer_id
,
layer_past
=
layer_past
,
use_cache
=
use_cache
,
output_attentions
=
output_attentions
)
attention_output
=
attention_outputs
[
0
]
outputs
=
attention_outputs
[
1
:]
# Residual connection.
alpha
=
(
2
*
self
.
num_layers
)
**
0.5
hidden_states
=
attention_input
*
alpha
+
attention_output
mlp_input
=
self
.
post_attention_layernorm
(
hidden_states
)
# MLP.
mlp_output
=
self
.
mlp
(
mlp_input
)
# Second residual connection.
output
=
mlp_input
*
alpha
+
mlp_output
if
use_cache
:
outputs
=
(
output
,)
+
outputs
else
:
outputs
=
(
output
,)
+
outputs
[
1
:]
return
outputs
# hidden_states, present, attentions
class
ChatGLMPreTrainedModel
(
PreTrainedModel
):
"""
An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
is_parallelizable
=
False
supports_gradient_checkpointing
=
True
config_class
=
ChatGLMConfig
base_model_prefix
=
"transformer"
_no_split_modules
=
[
"GLMBlock"
]
def
__init__
(
self
,
*
inputs
,
**
kwargs
):
super
().
__init__
(
*
inputs
,
**
kwargs
)
def
_init_weights
(
self
,
module
:
nn
.
Module
):
"""Initialize the weights."""
return
def
get_masks
(
self
,
input_ids
,
device
):
batch_size
,
seq_length
=
input_ids
.
shape
context_lengths
=
[
seq
.
tolist
().
index
(
self
.
config
.
bos_token_id
)
for
seq
in
input_ids
]
attention_mask
=
torch
.
ones
((
batch_size
,
seq_length
,
seq_length
),
device
=
device
)
attention_mask
.
tril_
()
for
i
,
context_length
in
enumerate
(
context_lengths
):
attention_mask
[
i
,
:,
:
context_length
]
=
1
attention_mask
.
unsqueeze_
(
1
)
attention_mask
=
(
attention_mask
<
0.5
).
bool
()
return
attention_mask
def
get_position_ids
(
self
,
input_ids
,
mask_positions
,
device
,
use_gmasks
=
None
):
batch_size
,
seq_length
=
input_ids
.
shape
if
use_gmasks
is
None
:
use_gmasks
=
[
False
]
*
batch_size
context_lengths
=
[
seq
.
tolist
().
index
(
self
.
config
.
bos_token_id
)
for
seq
in
input_ids
]
if
self
.
position_encoding_2d
:
position_ids
=
torch
.
arange
(
seq_length
,
dtype
=
torch
.
long
,
device
=
device
).
unsqueeze
(
0
).
repeat
(
batch_size
,
1
)
for
i
,
context_length
in
enumerate
(
context_lengths
):
position_ids
[
i
,
context_length
:]
=
mask_positions
[
i
]
block_position_ids
=
[
torch
.
cat
((
torch
.
zeros
(
context_length
,
dtype
=
torch
.
long
,
device
=
device
),
torch
.
arange
(
seq_length
-
context_length
,
dtype
=
torch
.
long
,
device
=
device
)
+
1
))
for
context_length
in
context_lengths
]
block_position_ids
=
torch
.
stack
(
block_position_ids
,
dim
=
0
)
position_ids
=
torch
.
stack
((
position_ids
,
block_position_ids
),
dim
=
1
)
else
:
position_ids
=
torch
.
arange
(
seq_length
,
dtype
=
torch
.
long
,
device
=
device
).
unsqueeze
(
0
).
repeat
(
batch_size
,
1
)
for
i
,
context_length
in
enumerate
(
context_lengths
):
if
not
use_gmasks
[
i
]:
position_ids
[
i
,
context_length
:]
=
mask_positions
[
i
]
return
position_ids
def
_set_gradient_checkpointing
(
self
,
module
,
value
=
False
):
if
isinstance
(
module
,
ChatGLMModel
):
module
.
gradient_checkpointing
=
value
CHATGLM_6B_START_DOCSTRING
=
r
"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
usage and behavior.
Parameters:
config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
CHATGLM_6B_INPUTS_DOCSTRING
=
r
"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`ChatGLM6BTokenizer`].
See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
[What are token type IDs?](../glossary#token-type-ids)
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range `[0, config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert *input_ids* indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@
add_start_docstrings
(
"The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top."
,
CHATGLM_6B_START_DOCSTRING
,
)
class
ChatGLMModel
(
ChatGLMPreTrainedModel
):
"""
The model can behave as an encoder (with only self-attention) as well
as a decoder, in which case a layer of cross-attention is added between
the self-attention layers, following the architecture described in [Attention is
all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
To behave as an decoder the model needs to be initialized with the
`is_decoder` argument of the configuration set to `True`.
To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder`
argument and `add_cross_attention` set to `True`; an
`encoder_hidden_states` is then expected as an input to the forward pass.
"""
def
__init__
(
self
,
config
:
ChatGLMConfig
,
empty_init
=
True
):
super
().
__init__
(
config
)
if
empty_init
:
init_method
=
skip_init
else
:
init_method
=
default_init
# recording parameters
self
.
max_sequence_length
=
config
.
max_sequence_length
self
.
hidden_size
=
config
.
hidden_size
self
.
params_dtype
=
torch
.
half
self
.
num_attention_heads
=
config
.
num_attention_heads
self
.
vocab_size
=
config
.
vocab_size
self
.
num_layers
=
config
.
num_layers
self
.
layernorm_epsilon
=
config
.
layernorm_epsilon
self
.
inner_hidden_size
=
config
.
inner_hidden_size
self
.
hidden_size_per_attention_head
=
self
.
hidden_size
//
self
.
num_attention_heads
self
.
position_encoding_2d
=
config
.
position_encoding_2d
self
.
pre_seq_len
=
config
.
pre_seq_len
self
.
prefix_projection
=
config
.
prefix_projection
self
.
word_embeddings
=
init_method
(
torch
.
nn
.
Embedding
,
num_embeddings
=
self
.
vocab_size
,
embedding_dim
=
self
.
hidden_size
,
dtype
=
self
.
params_dtype
)
self
.
gradient_checkpointing
=
False
def
get_layer
(
layer_id
):
return
GLMBlock
(
self
.
hidden_size
,
self
.
num_attention_heads
,
self
.
layernorm_epsilon
,
layer_id
,
inner_hidden_size
=
self
.
inner_hidden_size
,
hidden_size_per_attention_head
=
self
.
hidden_size_per_attention_head
,
layernorm
=
LayerNorm
,
use_bias
=
True
,
params_dtype
=
self
.
params_dtype
,
position_encoding_2d
=
self
.
position_encoding_2d
,
empty_init
=
empty_init
)
self
.
layers
=
torch
.
nn
.
ModuleList
(
[
get_layer
(
layer_id
)
for
layer_id
in
range
(
self
.
num_layers
)]
)
# Final layer norm before output.
self
.
final_layernorm
=
LayerNorm
(
self
.
hidden_size
,
eps
=
self
.
layernorm_epsilon
)
if
self
.
pre_seq_len
is
not
None
:
for
param
in
self
.
parameters
():
param
.
requires_grad
=
False
self
.
prefix_tokens
=
torch
.
arange
(
self
.
pre_seq_len
).
long
()
self
.
prefix_encoder
=
PrefixEncoder
(
config
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
0.1
)
# total_params = sum(p.numel() for p in self.parameters())
# trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
# print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params))
def
get_input_embeddings
(
self
):
return
self
.
word_embeddings
def
set_input_embeddings
(
self
,
new_embeddings
:
torch
.
Tensor
):
self
.
word_embeddings
=
new_embeddings
def
get_prompt
(
self
,
batch_size
,
device
,
dtype
=
torch
.
half
):
prefix_tokens
=
self
.
prefix_tokens
.
unsqueeze
(
0
).
expand
(
batch_size
,
-
1
).
to
(
device
)
past_key_values
=
self
.
prefix_encoder
(
prefix_tokens
).
type
(
dtype
)
past_key_values
=
past_key_values
.
view
(
batch_size
,
self
.
pre_seq_len
,
self
.
num_layers
*
2
,
self
.
num_attention_heads
,
self
.
hidden_size
//
self
.
num_attention_heads
)
# seq_len, b, nh, hidden_size
past_key_values
=
self
.
dropout
(
past_key_values
)
past_key_values
=
past_key_values
.
permute
([
2
,
1
,
0
,
3
,
4
]).
split
(
2
)
# past_key_values = [(v[0], v[1]) for v in past_key_values]
return
past_key_values
@
add_start_docstrings_to_model_forward
(
CHATGLM_6B_INPUTS_DOCSTRING
.
format
(
"batch_size, sequence_length"
))
@
add_code_sample_docstrings
(
checkpoint
=
_CHECKPOINT_FOR_DOC
,
output_type
=
BaseModelOutputWithPastAndCrossAttentions
,
config_class
=
_CONFIG_FOR_DOC
,
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
past_key_values
:
Optional
[
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
...]]
=
None
,
inputs_embeds
:
Optional
[
torch
.
LongTensor
]
=
None
,
use_cache
:
Optional
[
bool
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
Tuple
[
torch
.
Tensor
,
...],
BaseModelOutputWithPast
]:
output_attentions
=
output_attentions
if
output_attentions
is
not
None
else
self
.
config
.
output_attentions
output_hidden_states
=
(
output_hidden_states
if
output_hidden_states
is
not
None
else
self
.
config
.
output_hidden_states
)
use_cache
=
use_cache
if
use_cache
is
not
None
else
self
.
config
.
use_cache
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
if
self
.
gradient_checkpointing
and
self
.
training
:
if
use_cache
:
logger
.
warning_once
(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache
=
False
if
input_ids
is
not
None
and
inputs_embeds
is
not
None
:
raise
ValueError
(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif
input_ids
is
not
None
:
batch_size
,
seq_length
=
input_ids
.
shape
[:
2
]
elif
inputs_embeds
is
not
None
:
batch_size
,
seq_length
=
inputs_embeds
.
shape
[:
2
]
else
:
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
word_embeddings
(
input_ids
)
if
past_key_values
is
None
:
if
self
.
pre_seq_len
is
not
None
:
past_key_values
=
self
.
get_prompt
(
batch_size
=
input_ids
.
shape
[
0
],
device
=
input_ids
.
device
,
dtype
=
inputs_embeds
.
dtype
)
else
:
past_key_values
=
tuple
([
None
]
*
len
(
self
.
layers
))
if
attention_mask
is
None
:
attention_mask
=
self
.
get_masks
(
input_ids
,
device
=
input_ids
.
device
)
if
position_ids
is
None
:
MASK
,
gMASK
=
self
.
config
.
mask_token_id
,
self
.
config
.
gmask_token_id
seqs
=
input_ids
.
tolist
()
mask_positions
,
use_gmasks
=
[],
[]
for
seq
in
seqs
:
mask_token
=
gMASK
if
gMASK
in
seq
else
MASK
use_gmask
=
mask_token
==
gMASK
mask_positions
.
append
(
seq
.
index
(
mask_token
))
use_gmasks
.
append
(
use_gmask
)
position_ids
=
self
.
get_position_ids
(
input_ids
,
mask_positions
=
mask_positions
,
device
=
input_ids
.
device
,
use_gmasks
=
use_gmasks
)
if
self
.
pre_seq_len
is
not
None
and
attention_mask
is
not
None
:
prefix_attention_mask
=
torch
.
ones
(
batch_size
,
1
,
input_ids
.
size
(
-
1
),
self
.
pre_seq_len
).
to
(
attention_mask
.
device
)
prefix_attention_mask
=
(
prefix_attention_mask
<
0.5
).
bool
()
attention_mask
=
torch
.
cat
((
prefix_attention_mask
,
attention_mask
),
dim
=
3
)
# [seq_len, batch, hidden_size]
hidden_states
=
inputs_embeds
.
transpose
(
0
,
1
)
presents
=
()
if
use_cache
else
None
all_self_attentions
=
()
if
output_attentions
else
None
all_hidden_states
=
()
if
output_hidden_states
else
None
if
attention_mask
is
None
:
attention_mask
=
torch
.
zeros
(
1
,
1
,
device
=
input_ids
.
device
).
bool
()
else
:
attention_mask
=
attention_mask
.
to
(
hidden_states
.
device
)
for
i
,
layer
in
enumerate
(
self
.
layers
):
if
output_hidden_states
:
all_hidden_states
=
all_hidden_states
+
(
hidden_states
,)
layer_past
=
past_key_values
[
i
]
if
self
.
gradient_checkpointing
and
self
.
training
:
layer_ret
=
torch
.
utils
.
checkpoint
.
checkpoint
(
layer
,
hidden_states
,
position_ids
,
attention_mask
,
torch
.
tensor
(
i
),
layer_past
,
use_cache
,
output_attentions
)
else
:
layer_ret
=
layer
(
hidden_states
,
position_ids
=
position_ids
,
attention_mask
=
attention_mask
,
layer_id
=
torch
.
tensor
(
i
),
layer_past
=
layer_past
,
use_cache
=
use_cache
,
output_attentions
=
output_attentions
)
hidden_states
=
layer_ret
[
0
]
if
use_cache
:
presents
=
presents
+
(
layer_ret
[
1
],)
if
output_attentions
:
all_self_attentions
=
all_self_attentions
+
(
layer_ret
[
2
if
use_cache
else
1
],)
# Final layer norm.
hidden_states
=
self
.
final_layernorm
(
hidden_states
)
if
output_hidden_states
:
all_hidden_states
=
all_hidden_states
+
(
hidden_states
,)
if
not
return_dict
:
return
tuple
(
v
for
v
in
[
hidden_states
,
presents
,
all_hidden_states
,
all_self_attentions
]
if
v
is
not
None
)
return
BaseModelOutputWithPast
(
last_hidden_state
=
hidden_states
,
past_key_values
=
presents
,
hidden_states
=
all_hidden_states
,
attentions
=
all_self_attentions
,
)
class
ChatGLMForConditionalGeneration
(
ChatGLMPreTrainedModel
):
def
__init__
(
self
,
config
:
ChatGLMConfig
,
empty_init
=
True
):
super
().
__init__
(
config
)
if
empty_init
:
init_method
=
skip_init
else
:
init_method
=
default_init
# self.hidden_size = config.hidden_size
# self.params_dtype = torch.half
# self.vocab_size = config.vocab_size
self
.
max_sequence_length
=
config
.
max_sequence_length
self
.
position_encoding_2d
=
config
.
position_encoding_2d
self
.
transformer
=
ChatGLMModel
(
config
,
empty_init
=
empty_init
)
self
.
lm_head
=
init_method
(
nn
.
Linear
,
config
.
hidden_size
,
config
.
vocab_size
,
bias
=
False
,
dtype
=
torch
.
half
)
self
.
config
=
config
self
.
quantized
=
False
if
self
.
config
.
quantization_bit
:
self
.
quantize
(
self
.
config
.
quantization_bit
,
empty_init
=
True
)
def
get_output_embeddings
(
self
):
return
self
.
lm_head
def
set_output_embeddings
(
self
,
new_embeddings
):
self
.
lm_head
=
new_embeddings
def
_update_model_kwargs_for_generation
(
self
,
outputs
:
ModelOutput
,
model_kwargs
:
Dict
[
str
,
Any
],
is_encoder_decoder
:
bool
=
False
,
standardize_cache_format
:
bool
=
False
,
)
->
Dict
[
str
,
Any
]:
# update past_key_values
model_kwargs
[
"past_key_values"
]
=
self
.
_extract_past_from_model_output
(
outputs
,
standardize_cache_format
=
standardize_cache_format
)
# update attention mask
if
"attention_mask"
in
model_kwargs
:
attention_mask
=
model_kwargs
[
"attention_mask"
]
if
attention_mask
is
not
None
and
attention_mask
.
dtype
==
torch
.
bool
:
attention_mask
=
torch
.
cat
(
[
attention_mask
,
attention_mask
.
new_ones
((
*
attention_mask
.
shape
[:
3
],
1
))],
dim
=
3
)
new_attention_mask
=
attention_mask
[:,
:,
-
1
:].
clone
()
new_attention_mask
[...,
-
1
]
=
False
model_kwargs
[
"attention_mask"
]
=
torch
.
cat
(
[
attention_mask
,
new_attention_mask
],
dim
=
2
)
# update position ids
if
"position_ids"
in
model_kwargs
:
position_ids
=
model_kwargs
[
"position_ids"
]
new_position_id
=
position_ids
[...,
-
1
:].
clone
()
new_position_id
[:,
1
,
:]
+=
1
model_kwargs
[
"position_ids"
]
=
torch
.
cat
(
[
position_ids
,
new_position_id
],
dim
=-
1
)
return
model_kwargs
def
prepare_inputs_for_generation
(
self
,
input_ids
:
torch
.
LongTensor
,
past
:
Optional
[
torch
.
Tensor
]
=
None
,
past_key_values
:
Optional
[
torch
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
)
->
dict
:
batch_size
,
seq_length
=
input_ids
.
shape
MASK
,
gMASK
=
self
.
config
.
mask_token_id
,
self
.
config
.
gmask_token_id
seqs
=
input_ids
.
tolist
()
mask_positions
,
use_gmasks
=
[],
[]
for
seq
in
seqs
:
mask_token
=
gMASK
if
gMASK
in
seq
else
MASK
use_gmask
=
mask_token
==
gMASK
mask_positions
.
append
(
seq
.
index
(
mask_token
))
use_gmasks
.
append
(
use_gmask
)
# only last token for input_ids if past is not None
if
past
is
not
None
or
past_key_values
is
not
None
:
last_token
=
input_ids
[:,
-
1
].
unsqueeze
(
-
1
)
if
attention_mask
is
not
None
and
attention_mask
.
dtype
==
torch
.
bool
:
attention_mask
=
attention_mask
[:,
:,
-
1
:]
else
:
attention_mask
=
None
if
position_ids
is
not
None
:
position_ids
=
position_ids
[...,
-
1
:]
else
:
context_lengths
=
[
seq
.
index
(
self
.
config
.
bos_token_id
)
for
seq
in
seqs
]
if
self
.
position_encoding_2d
:
position_ids
=
torch
.
tensor
(
[[
mask_position
,
seq_length
-
context_length
]
for
mask_position
,
context_length
in
zip
(
mask_positions
,
context_lengths
)],
dtype
=
torch
.
long
,
device
=
input_ids
.
device
).
unsqueeze
(
-
1
)
else
:
position_ids
=
torch
.
tensor
([
mask_position
for
mask_position
in
mask_positions
],
dtype
=
torch
.
long
,
device
=
input_ids
.
device
).
unsqueeze
(
-
1
)
if
past
is
None
:
past
=
past_key_values
return
{
"input_ids"
:
last_token
,
"past_key_values"
:
past
,
"position_ids"
:
position_ids
,
"attention_mask"
:
attention_mask
}
else
:
if
attention_mask
is
not
None
and
attention_mask
.
dtype
!=
torch
.
bool
:
logger
.
warning_once
(
f
"The dtype of attention mask (
{
attention_mask
.
dtype
}
) is not bool"
)
attention_mask
=
None
if
attention_mask
is
None
:
attention_mask
=
self
.
get_masks
(
input_ids
,
device
=
input_ids
.
device
)
if
position_ids
is
None
:
position_ids
=
self
.
get_position_ids
(
input_ids
,
device
=
input_ids
.
device
,
mask_positions
=
mask_positions
,
use_gmasks
=
use_gmasks
)
return
{
"input_ids"
:
input_ids
,
"past_key_values"
:
past
,
"position_ids"
:
position_ids
,
"attention_mask"
:
attention_mask
}
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
past_key_values
:
Optional
[
Tuple
[
torch
.
FloatTensor
]]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
labels
:
Optional
[
torch
.
Tensor
]
=
None
,
use_cache
:
Optional
[
bool
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
):
use_cache
=
use_cache
if
use_cache
is
not
None
else
self
.
config
.
use_cache
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
transformer_outputs
=
self
.
transformer
(
input_ids
=
input_ids
,
position_ids
=
position_ids
,
attention_mask
=
attention_mask
,
past_key_values
=
past_key_values
,
inputs_embeds
=
inputs_embeds
,
use_cache
=
use_cache
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
)
hidden_states
=
transformer_outputs
[
0
]
lm_logits
=
self
.
lm_head
(
hidden_states
).
permute
(
1
,
0
,
2
).
contiguous
()
loss
=
None
if
labels
is
not
None
:
lm_logits
=
lm_logits
.
to
(
torch
.
float32
)
# Shift so that tokens < n predict n
shift_logits
=
lm_logits
[...,
:
-
1
,
:].
contiguous
()
shift_labels
=
labels
[...,
1
:].
contiguous
()
# Flatten the tokens
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
100
)
loss
=
loss_fct
(
shift_logits
.
view
(
-
1
,
shift_logits
.
size
(
-
1
)),
shift_labels
.
view
(
-
1
))
lm_logits
=
lm_logits
.
to
(
hidden_states
.
dtype
)
loss
=
loss
.
to
(
hidden_states
.
dtype
)
if
not
return_dict
:
output
=
(
lm_logits
,)
+
transformer_outputs
[
1
:]
return
((
loss
,)
+
output
)
if
loss
is
not
None
else
output
return
CausalLMOutputWithPast
(
loss
=
loss
,
logits
=
lm_logits
,
past_key_values
=
transformer_outputs
.
past_key_values
,
hidden_states
=
transformer_outputs
.
hidden_states
,
attentions
=
transformer_outputs
.
attentions
,
)
@
staticmethod
def
_reorder_cache
(
past
:
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
...],
beam_idx
:
torch
.
LongTensor
)
->
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
Output shares the same memory storage as `past`.
"""
return
tuple
(
(
layer_past
[
0
].
index_select
(
1
,
beam_idx
.
to
(
layer_past
[
0
].
device
)),
layer_past
[
1
].
index_select
(
1
,
beam_idx
.
to
(
layer_past
[
1
].
device
)),
)
for
layer_past
in
past
)
def
process_response
(
self
,
response
):
response
=
response
.
strip
()
response
=
response
.
replace
(
"[[训练时间]]"
,
"2023年"
)
punkts
=
[
[
","
,
","
],
[
"!"
,
"!"
],
[
":"
,
":"
],
[
";"
,
";"
],
[
"\?"
,
"?"
],
]
for
item
in
punkts
:
response
=
re
.
sub
(
r
"([\u4e00-\u9fff])%s"
%
item
[
0
],
r
"\1%s"
%
item
[
1
],
response
)
response
=
re
.
sub
(
r
"%s([\u4e00-\u9fff])"
%
item
[
0
],
r
"%s\1"
%
item
[
1
],
response
)
return
response
@
torch
.
no_grad
()
def
chat
(
self
,
tokenizer
,
query
:
str
,
history
:
List
[
Tuple
[
str
,
str
]]
=
None
,
max_length
:
int
=
2048
,
num_beams
=
1
,
do_sample
=
True
,
top_p
=
0.7
,
temperature
=
0.95
,
logits_processor
=
None
,
**
kwargs
):
if
history
is
None
:
history
=
[]
if
logits_processor
is
None
:
logits_processor
=
LogitsProcessorList
()
logits_processor
.
append
(
InvalidScoreLogitsProcessor
())
gen_kwargs
=
{
"max_length"
:
max_length
,
"num_beams"
:
num_beams
,
"do_sample"
:
do_sample
,
"top_p"
:
top_p
,
"temperature"
:
temperature
,
"logits_processor"
:
logits_processor
,
**
kwargs
}
if
not
history
:
prompt
=
query
else
:
prompt
=
""
for
i
,
(
old_query
,
response
)
in
enumerate
(
history
):
prompt
+=
"[Round {}]
\n
问:{}
\n
答:{}
\n
"
.
format
(
i
,
old_query
,
response
)
prompt
+=
"[Round {}]
\n
问:{}
\n
答:"
.
format
(
len
(
history
),
query
)
inputs
=
tokenizer
([
prompt
],
return_tensors
=
"pt"
)
inputs
=
inputs
.
to
(
self
.
device
)
outputs
=
self
.
generate
(
**
inputs
,
**
gen_kwargs
)
outputs
=
outputs
.
tolist
()[
0
][
len
(
inputs
[
"input_ids"
][
0
]):]
response
=
tokenizer
.
decode
(
outputs
)
response
=
self
.
process_response
(
response
)
history
=
history
+
[(
query
,
response
)]
return
response
,
history
@
torch
.
no_grad
()
def
stream_chat
(
self
,
tokenizer
,
query
:
str
,
history
:
List
[
Tuple
[
str
,
str
]]
=
None
,
max_length
:
int
=
2048
,
do_sample
=
True
,
top_p
=
0.7
,
temperature
=
0.95
,
logits_processor
=
None
,
**
kwargs
):
if
history
is
None
:
history
=
[]
if
logits_processor
is
None
:
logits_processor
=
LogitsProcessorList
()
logits_processor
.
append
(
InvalidScoreLogitsProcessor
())
gen_kwargs
=
{
"max_length"
:
max_length
,
"do_sample"
:
do_sample
,
"top_p"
:
top_p
,
"temperature"
:
temperature
,
"logits_processor"
:
logits_processor
,
**
kwargs
}
if
not
history
:
prompt
=
query
else
:
prompt
=
""
for
i
,
(
old_query
,
response
)
in
enumerate
(
history
):
prompt
+=
"[Round {}]
\n
问:{}
\n
答:{}
\n
"
.
format
(
i
,
old_query
,
response
)
prompt
+=
"[Round {}]
\n
问:{}
\n
答:"
.
format
(
len
(
history
),
query
)
inputs
=
tokenizer
([
prompt
],
return_tensors
=
"pt"
)
inputs
=
inputs
.
to
(
self
.
device
)
for
outputs
in
self
.
stream_generate
(
**
inputs
,
**
gen_kwargs
):
outputs
=
outputs
.
tolist
()[
0
][
len
(
inputs
[
"input_ids"
][
0
]):]
response
=
tokenizer
.
decode
(
outputs
)
response
=
self
.
process_response
(
response
)
new_history
=
history
+
[(
query
,
response
)]
yield
response
,
new_history
@
torch
.
no_grad
()
def
stream_generate
(
self
,
input_ids
,
generation_config
:
Optional
[
GenerationConfig
]
=
None
,
logits_processor
:
Optional
[
LogitsProcessorList
]
=
None
,
stopping_criteria
:
Optional
[
StoppingCriteriaList
]
=
None
,
prefix_allowed_tokens_fn
:
Optional
[
Callable
[[
int
,
torch
.
Tensor
],
List
[
int
]]]
=
None
,
**
kwargs
,
):
batch_size
,
input_ids_seq_length
=
input_ids
.
shape
[
0
],
input_ids
.
shape
[
-
1
]
if
generation_config
is
None
:
generation_config
=
self
.
generation_config
generation_config
=
copy
.
deepcopy
(
generation_config
)
model_kwargs
=
generation_config
.
update
(
**
kwargs
)
bos_token_id
,
eos_token_id
=
generation_config
.
bos_token_id
,
generation_config
.
eos_token_id
if
isinstance
(
eos_token_id
,
int
):
eos_token_id
=
[
eos_token_id
]
has_default_max_length
=
kwargs
.
get
(
"max_length"
)
is
None
and
generation_config
.
max_length
is
not
None
if
has_default_max_length
and
generation_config
.
max_new_tokens
is
None
:
warnings
.
warn
(
f
"Using `max_length`'s default (
{
generation_config
.
max_length
}
) to control the generation length. "
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
" recommend using `max_new_tokens` to control the maximum length of the generation."
,
UserWarning
,
)
elif
generation_config
.
max_new_tokens
is
not
None
:
generation_config
.
max_length
=
generation_config
.
max_new_tokens
+
input_ids_seq_length
if
not
has_default_max_length
:
logger
.
warn
(
f
"Both `max_new_tokens` (=
{
generation_config
.
max_new_tokens
}
) and `max_length`(="
f
"
{
generation_config
.
max_length
}
) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
,
UserWarning
,
)
if
input_ids_seq_length
>=
generation_config
.
max_length
:
input_ids_string
=
"decoder_input_ids"
if
self
.
config
.
is_encoder_decoder
else
"input_ids"
logger
.
warning
(
f
"Input length of
{
input_ids_string
}
is
{
input_ids_seq_length
}
, but `max_length` is set to"
f
"
{
generation_config
.
max_length
}
. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
)
# 2. Set generation parameters if not already defined
logits_processor
=
logits_processor
if
logits_processor
is
not
None
else
LogitsProcessorList
()
stopping_criteria
=
stopping_criteria
if
stopping_criteria
is
not
None
else
StoppingCriteriaList
()
logits_processor
=
self
.
_get_logits_processor
(
generation_config
=
generation_config
,
input_ids_seq_length
=
input_ids_seq_length
,
encoder_input_ids
=
input_ids
,
prefix_allowed_tokens_fn
=
prefix_allowed_tokens_fn
,
logits_processor
=
logits_processor
,
)
stopping_criteria
=
self
.
_get_stopping_criteria
(
generation_config
=
generation_config
,
stopping_criteria
=
stopping_criteria
)
logits_warper
=
self
.
_get_logits_warper
(
generation_config
)
unfinished_sequences
=
input_ids
.
new
(
input_ids
.
shape
[
0
]).
fill_
(
1
)
scores
=
None
while
True
:
model_inputs
=
self
.
prepare_inputs_for_generation
(
input_ids
,
**
model_kwargs
)
# forward pass to get next token
outputs
=
self
(
**
model_inputs
,
return_dict
=
True
,
output_attentions
=
False
,
output_hidden_states
=
False
,
)
next_token_logits
=
outputs
.
logits
[:,
-
1
,
:]
# pre-process distribution
next_token_scores
=
logits_processor
(
input_ids
,
next_token_logits
)
next_token_scores
=
logits_warper
(
input_ids
,
next_token_scores
)
# sample
probs
=
nn
.
functional
.
softmax
(
next_token_scores
,
dim
=-
1
)
if
generation_config
.
do_sample
:
next_tokens
=
torch
.
multinomial
(
probs
,
num_samples
=
1
).
squeeze
(
1
)
else
:
next_tokens
=
torch
.
argmax
(
probs
,
dim
=-
1
)
# update generated ids, model inputs, and length for next step
input_ids
=
torch
.
cat
([
input_ids
,
next_tokens
[:,
None
]],
dim
=-
1
)
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
)
unfinished_sequences
=
unfinished_sequences
.
mul
((
sum
(
next_tokens
!=
i
for
i
in
eos_token_id
)).
long
())
# stop when each sentence is finished, or if we exceed the maximum length
if
unfinished_sequences
.
max
()
==
0
or
stopping_criteria
(
input_ids
,
scores
):
break
yield
input_ids
def
quantize
(
self
,
bits
:
int
,
empty_init
=
False
,
**
kwargs
):
if
bits
==
0
:
return
from
.quantization
import
quantize
if
self
.
quantized
:
logger
.
info
(
"Already quantized."
)
return
self
self
.
quantized
=
True
self
.
config
.
quantization_bit
=
bits
self
.
transformer
=
quantize
(
self
.
transformer
,
bits
,
empty_init
=
empty_init
,
**
kwargs
)
return
self
applications/Chat/coati/trainer/sft.py
View file @
aaeb520c
...
@@ -52,9 +52,13 @@ class SFTTrainer(SLTrainer):
...
@@ -52,9 +52,13 @@ class SFTTrainer(SLTrainer):
for
batch_id
,
batch
in
enumerate
(
self
.
train_dataloader
):
for
batch_id
,
batch
in
enumerate
(
self
.
train_dataloader
):
batch
=
to_device
(
batch
,
torch
.
cuda
.
current_device
())
batch
=
to_device
(
batch
,
torch
.
cuda
.
current_device
())
outputs
=
self
.
model
(
batch
[
"input_ids"
],
if
"attention_mask"
in
batch
:
attention_mask
=
batch
[
"attention_mask"
],
outputs
=
self
.
model
(
batch
[
"input_ids"
],
labels
=
batch
[
"labels"
])
attention_mask
=
batch
[
"attention_mask"
],
labels
=
batch
[
"labels"
])
else
:
outputs
=
self
.
model
(
batch
[
"input_ids"
],
labels
=
batch
[
"labels"
])
loss
=
outputs
.
loss
loss
=
outputs
.
loss
loss
=
loss
/
self
.
accumulation_steps
loss
=
loss
/
self
.
accumulation_steps
...
...
applications/Chat/examples/requirements.txt
View file @
aaeb520c
pandas>=1.4.1
pandas>=1.4.1
sentencepiece
sentencepiece
colossalai==0.3.1
\ No newline at end of file
applications/Chat/examples/train_sft.py
View file @
aaeb520c
...
@@ -9,13 +9,15 @@ from coati.models.bloom import BLOOMActor
...
@@ -9,13 +9,15 @@ from coati.models.bloom import BLOOMActor
from
coati.models.gpt
import
GPTActor
from
coati.models.gpt
import
GPTActor
from
coati.models.llama
import
LlamaActor
from
coati.models.llama
import
LlamaActor
from
coati.models.opt
import
OPTActor
from
coati.models.opt
import
OPTActor
from
coati.models.chatglm
import
ChatGLMActor
from
coati.trainer
import
SFTTrainer
from
coati.trainer
import
SFTTrainer
from
coati.trainer.strategies
import
DDPStrategy
,
GeminiStrategy
,
LowLevelZeroStrategy
from
coati.trainer.strategies
import
DDPStrategy
,
GeminiStrategy
,
LowLevelZeroStrategy
from
datasets
import
load_dataset
from
datasets
import
load_dataset
from
torch.optim
import
Adam
from
torch.optim
import
Adam
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
torch.utils.data.distributed
import
DistributedSampler
from
torch.utils.data.distributed
import
DistributedSampler
from
transformers
import
AutoTokenizer
,
BloomTokenizerFast
,
LlamaTokenizer
from
transformers
import
AutoTokenizer
,
BloomTokenizerFast
,
LlamaTokenizer
,
AutoModel
from
coati.models.chatglm.chatglm_tokenizer
import
ChatGLMTokenizer
from
transformers.models.gpt2.tokenization_gpt2
import
GPT2Tokenizer
from
transformers.models.gpt2.tokenization_gpt2
import
GPT2Tokenizer
from
transformers.trainer
import
get_scheduler
from
transformers.trainer
import
get_scheduler
...
@@ -58,6 +60,8 @@ def train(args):
...
@@ -58,6 +60,8 @@ def train(args):
model
=
LlamaActor
(
pretrained
=
args
.
pretrain
,
model
=
LlamaActor
(
pretrained
=
args
.
pretrain
,
lora_rank
=
args
.
lora_rank
,
lora_rank
=
args
.
lora_rank
,
checkpoint
=
args
.
grad_checkpoint
)
checkpoint
=
args
.
grad_checkpoint
)
elif
args
.
model
==
'chatglm'
:
model
=
ChatGLMActor
(
pretrained
=
args
.
pretrain
)
else
:
else
:
raise
ValueError
(
f
'Unsupported model "
{
args
.
model
}
"'
)
raise
ValueError
(
f
'Unsupported model "
{
args
.
model
}
"'
)
...
@@ -81,6 +85,9 @@ def train(args):
...
@@ -81,6 +85,9 @@ def train(args):
"hf-internal-testing/llama-tokenizer"
if
args
.
tokenizer
is
None
else
args
.
tokenizer
)
"hf-internal-testing/llama-tokenizer"
if
args
.
tokenizer
is
None
else
args
.
tokenizer
)
tokenizer
.
eos_token
=
'<\s>'
tokenizer
.
eos_token
=
'<\s>'
tokenizer
.
pad_token
=
tokenizer
.
unk_token
tokenizer
.
pad_token
=
tokenizer
.
unk_token
elif
args
.
model
==
'chatglm'
:
tokenizer
=
ChatGLMTokenizer
.
from_pretrained
(
"THUDM/chatglm-6b"
if
args
.
tokenizer
is
None
else
args
.
tokenizer
,
trust_remote_code
=
True
)
else
:
else
:
raise
ValueError
(
f
'Unsupported model "
{
args
.
model
}
"'
)
raise
ValueError
(
f
'Unsupported model "
{
args
.
model
}
"'
)
...
@@ -99,7 +106,6 @@ def train(args):
...
@@ -99,7 +106,6 @@ def train(args):
optim
=
HybridAdam
(
model
.
parameters
(),
lr
=
args
.
lr
,
clipping_norm
=
1.0
)
optim
=
HybridAdam
(
model
.
parameters
(),
lr
=
args
.
lr
,
clipping_norm
=
1.0
)
else
:
else
:
optim
=
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
optim
=
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
logger
=
get_dist_logger
()
logger
=
get_dist_logger
()
# configure dataset
# configure dataset
...
@@ -185,7 +191,7 @@ if __name__ == '__main__':
...
@@ -185,7 +191,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--strategy'
,
parser
.
add_argument
(
'--strategy'
,
choices
=
[
'ddp'
,
'colossalai_gemini'
,
'colossalai_zero2'
,
'colossalai_zero2_cpu'
],
choices
=
[
'ddp'
,
'colossalai_gemini'
,
'colossalai_zero2'
,
'colossalai_zero2_cpu'
],
default
=
'colossalai_zero2'
)
default
=
'colossalai_zero2'
)
parser
.
add_argument
(
'--model'
,
choices
=
[
'gpt2'
,
'bloom'
,
'opt'
,
'llama'
],
default
=
'bloom'
)
parser
.
add_argument
(
'--model'
,
choices
=
[
'gpt2'
,
'bloom'
,
'opt'
,
'llama'
,
'chatglm'
],
default
=
'bloom'
)
parser
.
add_argument
(
'--tokenizer'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--tokenizer'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--pretrain'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--pretrain'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--dataset'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--dataset'
,
type
=
str
,
default
=
None
)
...
...
applications/Chat/requirements-test.txt
View file @
aaeb520c
pytest
pytest
colossalai==0.3.1
\ No newline at end of file
applications/Chat/requirements.txt
View file @
aaeb520c
...
@@ -2,7 +2,7 @@ transformers>=4.20.1
...
@@ -2,7 +2,7 @@ transformers>=4.20.1
tqdm
tqdm
datasets
datasets
loralib
loralib
colossalai
>
=0.
2.4
colossalai
=
=0.
3.1
torch<2.0.0, >=1.12.1
torch<2.0.0, >=1.12.1
langchain
langchain
tokenizers
tokenizers
...
...
applications/Chat/tests/test_dataset.py
View file @
aaeb520c
...
@@ -11,7 +11,7 @@ from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDatase
...
@@ -11,7 +11,7 @@ from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDatase
from
datasets
import
load_dataset
from
datasets
import
load_dataset
from
transformers
import
AutoTokenizer
,
BloomTokenizerFast
,
LlamaTokenizer
,
PreTrainedTokenizer
from
transformers
import
AutoTokenizer
,
BloomTokenizerFast
,
LlamaTokenizer
,
PreTrainedTokenizer
from
transformers.models.gpt2.tokenization_gpt2
import
GPT2Tokenizer
from
transformers.models.gpt2.tokenization_gpt2
import
GPT2Tokenizer
from
coati.models.chatglm.chatglm_tokenizer
import
ChatGLMTokenizer
SFT_DATASET
=
[
SFT_DATASET
=
[
{
{
"instruction"
:
"Provide a list of the top 10 most popular mobile games in Asia"
,
"instruction"
:
"Provide a list of the top 10 most popular mobile games in Asia"
,
...
@@ -66,6 +66,8 @@ def make_tokenizer(model: str):
...
@@ -66,6 +66,8 @@ def make_tokenizer(model: str):
elif
model
==
"llama"
:
elif
model
==
"llama"
:
tokenizer
=
LlamaTokenizer
.
from_pretrained
(
"hf-internal-testing/llama-tokenizer"
)
tokenizer
=
LlamaTokenizer
.
from_pretrained
(
"hf-internal-testing/llama-tokenizer"
)
tokenizer
.
pad_token
=
tokenizer
.
unk_token
tokenizer
.
pad_token
=
tokenizer
.
unk_token
elif
model
==
"chatglm"
:
tokenizer
=
ChatGLMTokenizer
.
from_pretrained
(
"THUDM/chatglm-6b"
,
trust_remote_code
=
True
)
else
:
else
:
raise
ValueError
(
f
"Unsupported model '
{
model
}
'"
)
raise
ValueError
(
f
"Unsupported model '
{
model
}
'"
)
return
tokenizer
return
tokenizer
...
@@ -81,13 +83,19 @@ def check_content(input_ids_stripped: torch.Tensor,
...
@@ -81,13 +83,19 @@ def check_content(input_ids_stripped: torch.Tensor,
elif
model
==
"llama"
:
elif
model
==
"llama"
:
assert
input_ids_stripped
[
0
]
==
tokenizer
.
bos_token_id
assert
input_ids_stripped
[
0
]
==
tokenizer
.
bos_token_id
input_ids_stripped
=
input_ids_stripped
[
1
:]
input_ids_stripped
=
input_ids_stripped
[
1
:]
elif
model
==
"chatglm"
:
assert
input_ids_stripped
[
0
]
==
tokenizer
.
bos_token_id
assert
input_ids_stripped
[
-
1
]
==
tokenizer
.
eos_token_id
input_ids_stripped
=
input_ids_stripped
[
1
:
-
1
]
assert
torch
.
all
(
input_ids_stripped
!=
tokenizer
.
pad_token_id
)
assert
torch
.
all
(
input_ids_stripped
!=
tokenizer
.
pad_token_id
)
assert
torch
.
all
(
input_ids_stripped
!=
tokenizer
.
bos_token_id
)
assert
torch
.
all
(
input_ids_stripped
!=
tokenizer
.
bos_token_id
)
assert
torch
.
all
(
input_ids_stripped
!=
tokenizer
.
eos_token_id
)
assert
torch
.
all
(
input_ids_stripped
!=
tokenizer
.
eos_token_id
)
assert
input_ids_stripped
!=
tokenizer
.
sep_token_id
assert
input_ids_stripped
!=
tokenizer
.
sep_token_id
assert
input_ids_stripped
!=
tokenizer
.
cls_token_id
assert
input_ids_stripped
!=
tokenizer
.
cls_token_id
assert
input_ids_stripped
!=
tokenizer
.
mask_token_id
if
model
==
"chatglm"
:
assert
torch
.
all
(
input_ids_stripped
!=
tokenizer
.
mask_token_id
)
else
:
assert
input_ids_stripped
!=
tokenizer
.
mask_token_id
@
pytest
.
mark
.
cpu
@
pytest
.
mark
.
cpu
...
@@ -189,7 +197,7 @@ def test_reward_dataset(model: str,
...
@@ -189,7 +197,7 @@ def test_reward_dataset(model: str,
@
pytest
.
mark
.
cpu
@
pytest
.
mark
.
cpu
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"gpt2"
,
"bloom"
,
"opt"
,
"llama"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"gpt2"
,
"bloom"
,
"opt"
,
"llama"
,
"chatglm"
])
@
pytest
.
mark
.
parametrize
(
"dataset_path"
,
[
"yizhongw/self_instruct"
,
None
])
@
pytest
.
mark
.
parametrize
(
"dataset_path"
,
[
"yizhongw/self_instruct"
,
None
])
@
pytest
.
mark
.
parametrize
(
"max_dataset_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"max_dataset_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"max_length"
,
[
32
,
1024
])
@
pytest
.
mark
.
parametrize
(
"max_length"
,
[
32
,
1024
])
...
@@ -213,6 +221,19 @@ def test_sft_dataset(model: str,
...
@@ -213,6 +221,19 @@ def test_sft_dataset(model: str,
max_length
=
max_length
)
max_length
=
max_length
)
assert
len
(
sft_dataset
)
==
min
(
max_dataset_size
,
len
(
SFT_DATASET
))
assert
len
(
sft_dataset
)
==
min
(
max_dataset_size
,
len
(
SFT_DATASET
))
if
isinstance
(
tokenizer
,
ChatGLMTokenizer
):
for
i
in
range
(
max_dataset_size
):
assert
isinstance
(
sft_dataset
[
i
],
dict
)
assert
list
(
sft_dataset
[
i
].
keys
())
==
[
"input_ids"
,
"labels"
]
input_ids
=
sft_dataset
[
i
][
"input_ids"
]
labels
=
sft_dataset
[
i
][
"labels"
]
assert
input_ids
.
shape
==
labels
.
shape
==
torch
.
Size
([
max_length
])
ignore_mask
=
labels
==
IGNORE_INDEX
assert
input_ids
.
masked_select
(
torch
.
logical_not
(
ignore_mask
))[
0
]
==
tokenizer
.
bos_token_id
check_content
(
input_ids
.
masked_select
(
torch
.
logical_not
(
ignore_mask
)),
tokenizer
,
model
)
return
for
i
in
range
(
max_dataset_size
):
for
i
in
range
(
max_dataset_size
):
assert
isinstance
(
sft_dataset
[
i
],
dict
)
assert
isinstance
(
sft_dataset
[
i
],
dict
)
assert
list
(
sft_dataset
[
i
].
keys
())
==
[
"input_ids"
,
"labels"
,
"attention_mask"
]
assert
list
(
sft_dataset
[
i
].
keys
())
==
[
"input_ids"
,
"labels"
,
"attention_mask"
]
...
@@ -245,4 +266,4 @@ if __name__ == "__main__":
...
@@ -245,4 +266,4 @@ if __name__ == "__main__":
test_prompt_dataset
(
model
=
"opt"
,
test_prompt_dataset
(
model
=
"opt"
,
max_datasets_size
=
2
,
max_datasets_size
=
2
,
max_length
=
128
)
max_length
=
128
)
\ No newline at end of file
applications/Chat/tests/test_models.py
View file @
aaeb520c
...
@@ -9,11 +9,12 @@ from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
...
@@ -9,11 +9,12 @@ from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from
coati.models.generation
import
generate
from
coati.models.generation
import
generate
from
coati.models.gpt
import
GPTRM
,
GPTActor
,
GPTCritic
from
coati.models.gpt
import
GPTRM
,
GPTActor
,
GPTCritic
from
coati.models.llama
import
LlamaActor
,
LlamaCritic
,
LlamaRM
from
coati.models.llama
import
LlamaActor
,
LlamaCritic
,
LlamaRM
from
coati.models.chatglm
import
ChatGLMActor
from
coati.models.lora
import
LoraLinear
,
convert_to_lora_module
from
coati.models.lora
import
LoraLinear
,
convert_to_lora_module
from
coati.models.loss
import
GPTLMLoss
,
LogExpLoss
,
LogSigLoss
,
PolicyLoss
,
ValueLoss
from
coati.models.loss
import
GPTLMLoss
,
LogExpLoss
,
LogSigLoss
,
PolicyLoss
,
ValueLoss
from
coati.models.opt
import
OPTRM
,
OPTActor
,
OPTCritic
from
coati.models.opt
import
OPTRM
,
OPTActor
,
OPTCritic
from
coati.models.utils
import
calc_action_log_probs
,
compute_reward
,
masked_mean
from
coati.models.utils
import
calc_action_log_probs
,
compute_reward
,
masked_mean
from
coati.models.chatglm.chatglm_tokenizer
import
ChatGLMTokenizer
@
pytest
.
mark
.
gpu
@
pytest
.
mark
.
gpu
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
...
@@ -23,7 +24,8 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mea
...
@@ -23,7 +24,8 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mea
lambda
:
GPTActor
(),
lambda
:
GPTActor
(),
# HACK: skip llama due to long execution time
# HACK: skip llama due to long execution time
# lambda: LlamaActor(),
# lambda: LlamaActor(),
lambda
:
OPTActor
()
lambda
:
OPTActor
(),
# lambda: ChatGLMActor(),
])
])
@
pytest
.
mark
.
parametrize
(
"generate_kwargs"
,
[{
@
pytest
.
mark
.
parametrize
(
"generate_kwargs"
,
[{
"max_length"
:
64
,
"max_length"
:
64
,
...
@@ -129,12 +131,12 @@ def test_lora(lora_rank: int,
...
@@ -129,12 +131,12 @@ def test_lora(lora_rank: int,
# HACK: skip llama due to long execution time
# HACK: skip llama due to long execution time
# lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
# lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
lambda
:
(
OPTActor
(),
OPTCritic
(),
OPTRM
()),
lambda
:
(
OPTActor
(),
OPTCritic
(),
OPTRM
()),
lambda
:
(
ChatGLMActor
(),
None
,
None
),
])
])
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
test_models
(
models_maker
:
Callable
[[],
Tuple
[
Actor
,
Critic
,
RewardModel
]],
def
test_models
(
models_maker
:
Callable
[[],
Tuple
[
Actor
,
Critic
,
RewardModel
]],
batch_size
:
int
,
batch_size
:
int
,
seq_len
:
int
):
seq_len
:
int
):
actor_input
=
{
actor_input
=
{
"input_ids"
:
torch
.
randint
(
0
,
100
,
(
batch_size
,
seq_len
)),
"input_ids"
:
torch
.
randint
(
0
,
100
,
(
batch_size
,
seq_len
)),
"attention_mask"
:
torch
.
randint
(
0
,
2
,
(
batch_size
,
seq_len
))
"attention_mask"
:
torch
.
randint
(
0
,
2
,
(
batch_size
,
seq_len
))
...
@@ -150,20 +152,30 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
...
@@ -150,20 +152,30 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
}
}
actor
,
critic
,
rm
=
models_maker
()
actor
,
critic
,
rm
=
models_maker
()
if
isinstance
(
actor
,
ChatGLMActor
):
actor
=
actor
.
float
()
tokenizer
=
ChatGLMTokenizer
.
from_pretrained
(
"THUDM/chatglm-6b"
,
trust_remote_code
=
True
)
chatglm_special_token
=
torch
.
tensor
([
tokenizer
.
gmask_token_id
,
tokenizer
.
bos_token_id
]).
repeat
(
batch_size
,
1
)
actor_input
=
{
"input_ids"
:
torch
.
cat
((
torch
.
randint
(
0
,
100
,
(
batch_size
,
seq_len
//
2
)),
chatglm_special_token
,
torch
.
randint
(
0
,
100
,
(
batch_size
,
seq_len
//
2
-
2
))),
dim
=
1
),
"attention_mask"
:
torch
.
randint
(
0
,
2
,
(
batch_size
,
1
,
seq_len
,
seq_len
))
}
assert
isinstance
(
actor
,
Actor
)
assert
isinstance
(
actor
,
Actor
)
base_actor_model
=
get_base_model
(
actor
)
base_actor_model
=
get_base_model
(
actor
)
assert
isinstance
(
critic
,
Critic
)
base_critic_model
=
get_base_model
(
critic
)
assert
isinstance
(
rm
,
RewardModel
)
base_rm_model
=
get_base_model
(
rm
)
actor_output
=
actor
(
**
actor_input
)
actor_output
=
actor
(
**
actor_input
)
critic_output
=
critic
(
**
critic_input
)
rm_output
=
rm
(
**
rm_input
)
assert
actor_output
.
logits
.
shape
[:
2
]
==
(
batch_size
,
seq_len
)
assert
actor_output
.
logits
.
shape
[:
2
]
==
(
batch_size
,
seq_len
)
assert
critic_output
.
shape
==
(
batch_size
,
)
assert
rm_output
.
shape
==
(
batch_size
,
)
if
critic
:
assert
isinstance
(
critic
,
Critic
)
base_critic_model
=
get_base_model
(
critic
)
critic_output
=
critic
(
**
critic_input
)
assert
critic_output
.
shape
==
(
batch_size
,
)
if
rm
:
assert
isinstance
(
rm
,
RewardModel
)
base_rm_model
=
get_base_model
(
rm
)
rm_output
=
rm
(
**
rm_input
)
assert
rm_output
.
shape
==
(
batch_size
,
)
@
pytest
.
mark
.
cpu
@
pytest
.
mark
.
cpu
...
@@ -232,4 +244,4 @@ if __name__ == "__main__":
...
@@ -232,4 +244,4 @@ if __name__ == "__main__":
batch_size
=
8
,
batch_size
=
8
,
seq_len
=
128
)
seq_len
=
128
)
test_loss
(
batch_size
=
8
,
seq_len
=
128
,
num_labels
=
100
)
test_loss
(
batch_size
=
8
,
seq_len
=
128
,
num_labels
=
100
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment