ModelZoo / ChatGLM-6B_pytorch

Commit 7a88fd0a: "Update modeling_chatglm.py", authored Aug 14, 2023 by zhaoying1 (parent 0427b1ea)

Showing 1 changed file with 1477 additions and 0 deletions:
chatglm-6b testing/modeling_chatglm.py (view file @ 7a88fd0a)

""" PyTorch ChatGLM model. """
import
math
import
copy
import
os
import
warnings
import
re
import
sys
import
torch
import
torch.utils.checkpoint
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch.nn
import
CrossEntropyLoss
,
LayerNorm
from
torch.nn.utils
import
skip_init
from
typing
import
Optional
,
Tuple
,
Union
,
List
,
Callable
,
Dict
,
Any
from
time
import
perf_counter
import
numpy
as
np
from
transformers.utils
import
(
add_code_sample_docstrings
,
add_start_docstrings
,
add_start_docstrings_to_model_forward
,
)
from
transformers.modeling_outputs
import
(
BaseModelOutputWithPast
,
CausalLMOutputWithPast
,
BaseModelOutputWithPastAndCrossAttentions
,
)
from
transformers.modeling_utils
import
PreTrainedModel
from
transformers.utils
import
logging
from
transformers.generation.logits_process
import
LogitsProcessor
from
transformers.generation.utils
import
LogitsProcessorList
,
StoppingCriteriaList
,
GenerationConfig
,
ModelOutput
from
.configuration_chatglm
import
ChatGLMConfig
# flags required to enable jit fusion kernels
if sys.platform != 'darwin':
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B"
_CONFIG_FOR_DOC = "ChatGLM6BConfig"

CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "THUDM/chatglm-6b",
    # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm
]

class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores

def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
        # which are not required for using pretrained model
        if any(
                n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
                for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            assert (
                    pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model

class PrefixEncoder(torch.nn.Module):
    """
    The torch.nn model to encode the prefix
    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    """

    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
            )
        else:
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values

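# Illustrative shape check for PrefixEncoder. This is a minimal sketch; the
# config values below are made-up example numbers, not values from this repo.
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(prefix_projection=False, pre_seq_len=4, hidden_size=8, num_layers=2)
#     enc = PrefixEncoder(cfg)
#     prefix = torch.arange(4).unsqueeze(0)       # (batch=1, prefix_len=4)
#     out = enc(prefix)
#     assert out.shape == (1, 4, 2 * 2 * 8)       # (batch, prefix_len, 2 * layers * hidden)
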
@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))


def gelu(x):
    return gelu_impl(x)

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        inv_freq = inv_freq.half()
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            self.register_buffer('inv_freq', inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        pass

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

    def _apply(self, fn):
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        if self.sin_cached is not None:
            self.sin_cached = fn(self.sin_cached)
        return super()._apply(fn)

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in earlier torch versions

@torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
    # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
        F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
    return q, k

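# Illustrative sketch of the rotary helpers above, with arbitrary example sizes.
# The .float() cast is only so the sketch also runs on CPU, where the half
# buffers used by the real model are not always supported.
#
#     sq, b, heads, hn = 5, 1, 2, 8                   # hn must be even
#     rot = RotaryEmbedding(hn, precision=torch.float).float()
#     q = torch.randn(sq, b, heads, hn)
#     k = torch.randn(sq, b, heads, hn)
#     cos, sin = rot(q, seq_len=sq)                   # each: [sq, 1, hn]
#     position_id = torch.arange(sq).unsqueeze(1)     # [sq, b]
#     q_rot, k_rot = apply_rotary_pos_emb_index(q, k, cos, sin, position_id)
#     assert q_rot.shape == q.shape and k_rot.shape == k.shape
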
def attention_fn(
        self,
        query_layer,
        key_layer,
        value_layer,
        attention_mask,
        hidden_size_per_partition,
        layer_id,
        layer_past=None,
        scaling_attention_score=True,
        use_cache=False,
):
    if layer_past is not None:
        past_key, past_value = layer_past[0], layer_past[1]
        key_layer = torch.cat((past_key, key_layer), dim=0)
        value_layer = torch.cat((past_value, value_layer), dim=0)

    # seqlen, batch, num_attention_heads, hidden_size_per_attention_head
    seq_len, b, nh, hidden_size = key_layer.shape

    if use_cache:
        present = (key_layer, value_layer)
    else:
        present = None

    query_key_layer_scaling_coeff = float(layer_id + 1)
    if scaling_attention_score:
        query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff)

    # ===================================
    # Raw attention scores. [b, np, s, s]
    # ===================================

    # [b, np, sq, sk]
    output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

    # [sq, b, np, hn] -> [sq, b * np, hn]
    query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
    # [sk, b, np, hn] -> [sk, b * np, hn]
    key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

    matmul_result = torch.zeros(
        1, 1, 1,
        dtype=query_layer.dtype,
        device=query_layer.device,
    )

    matmul_result = torch.baddbmm(
        matmul_result,
        query_layer.transpose(0, 1),  # [b * np, sq, hn]
        key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
        beta=0.0,
        alpha=1.0,
    )

    # change view to [b, np, sq, sk]
    attention_scores = matmul_result.view(*output_size)

    if self.scale_mask_softmax:
        self.scale_mask_softmax.scale = query_key_layer_scaling_coeff
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous())
    else:
        if not (attention_mask == 0).all():
            # if auto-regressive, skip
            attention_scores.masked_fill_(attention_mask, -10000.0)
        dtype = attention_scores.dtype

        attention_scores = attention_scores.float()
        attention_scores = attention_scores * query_key_layer_scaling_coeff

        attention_probs = F.softmax(attention_scores, dim=-1)

        attention_probs = attention_probs.type(dtype)

    # =========================
    # Context layer. [sq, b, hp]
    # =========================

    # value_layer -> context layer.
    # [sk, b, np, hn] --> [b, np, sq, hn]

    # context layer shape: [b, np, sq, hn]
    output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))

    # change view [sk, b * np, hn]
    value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

    # change view [b * np, sq, sk]
    attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

    # matmul: [b * np, sq, hn]
    context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

    # change view [b, np, sq, hn]
    context_layer = context_layer.view(*output_size)

    # [b, np, sq, hn] --> [sq, b, np, hn]
    context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

    # [sq, b, np, hn] --> [sq, b, hp]
    new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)
    context_layer = context_layer.view(*new_context_layer_shape)

    outputs = (context_layer, present, attention_probs)

    return outputs

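# Minimal shape sketch for attention_fn with made-up sizes. The SimpleNamespace
# stands in for a SelfAttention module whose fused scale_mask_softmax is
# disabled (scale_mask_softmax=None), as it is constructed in this file.
#
#     from types import SimpleNamespace
#     sq, b, heads, hn = 4, 2, 3, 8
#     fake_self = SimpleNamespace(scale_mask_softmax=None)
#     q = torch.randn(sq, b, heads, hn)
#     k = torch.randn(sq, b, heads, hn)
#     v = torch.randn(sq, b, heads, hn)
#     mask = torch.zeros(b, 1, sq, sq, dtype=torch.bool)        # nothing masked
#     ctx, present, probs = attention_fn(fake_self, q, k, v, mask,
#                                        hidden_size_per_partition=heads * hn,
#                                        layer_id=0, use_cache=True)
#     assert ctx.shape == (sq, b, heads * hn)                   # [sq, b, hp]
#     assert present[0].shape == (sq, b, heads, hn)             # cached keys
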
def default_init(cls, *args, **kwargs):
    return cls(*args, **kwargs)

class SelfAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_attention_heads,
                 layer_id, hidden_size_per_attention_head=None, bias=True,
                 params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        super(SelfAttention, self).__init__()

        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.hidden_size_per_partition = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_heads_per_partition = num_attention_heads
        self.position_encoding_2d = position_encoding_2d
        self.rotary_emb = RotaryEmbedding(
            self.hidden_size // (self.num_attention_heads * 2)
            if position_encoding_2d
            else self.hidden_size // self.num_attention_heads,
            base=10000,
            precision=torch.half,
            learnable=False,
        )

        self.scale_mask_softmax = None

        if hidden_size_per_attention_head is None:
            self.hidden_size_per_attention_head = hidden_size // num_attention_heads
        else:
            self.hidden_size_per_attention_head = hidden_size_per_attention_head

        self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head

        # Strided linear layer.
        self.query_key_value = init_method(
            torch.nn.Linear,
            hidden_size,
            3 * self.inner_hidden_size,
            bias=bias,
            dtype=params_dtype,
        )

        self.dense = init_method(
            torch.nn.Linear,
            self.inner_hidden_size,
            hidden_size,
            bias=bias,
            dtype=params_dtype,
        )

    @staticmethod
    def attention_mask_func(attention_scores, attention_mask):
        attention_scores.masked_fill_(attention_mask, -10000.0)
        return attention_scores

    def split_tensor_along_last_dim(self, tensor, num_partitions,
                                    contiguous_split_chunks=False):
        """Split a tensor along its last dimension.
        Arguments:
            tensor: input tensor.
            num_partitions: number of partitions to split the tensor
            contiguous_split_chunks: If True, make each chunk contiguous
                in memory.
        """
        # Get the size and dimension.
        last_dim = tensor.dim() - 1
        last_dim_size = tensor.size()[last_dim] // num_partitions
        # Split.
        tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
        # Note: torch.split does not create contiguous tensors by default.
        if contiguous_split_chunks:
            return tuple(chunk.contiguous() for chunk in tensor_list)

        return tensor_list

    def forward(
            self,
            hidden_states: torch.Tensor,
            position_ids,
            attention_mask: torch.Tensor,
            layer_id,
            layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            use_cache: bool = False,
            output_attentions: bool = False,
    ):
        """
        hidden_states: [seq_len, batch, hidden_size]
        attention_mask: [(1, 1), seq_len, seq_len]
        """

        # [seq_len, batch, 3 * hidden_size]
        mixed_raw_layer = self.query_key_value(hidden_states)

        # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head]
        new_tensor_shape = mixed_raw_layer.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
        mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape)

        # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
        (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3)

        if self.position_encoding_2d:
            q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
            k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
            cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
            position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \
                position_ids[:, 1, :].transpose(0, 1).contiguous()
            q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
            q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
            query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
            key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
        else:
            position_ids = position_ids.transpose(0, 1)
            cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1)
            # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
            query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids)

        # [seq_len, batch, hidden_size]
        context_layer, present, attention_probs = attention_fn(
            self=self,
            query_layer=query_layer,
            key_layer=key_layer,
            value_layer=value_layer,
            attention_mask=attention_mask,
            hidden_size_per_partition=self.hidden_size_per_partition,
            layer_id=layer_id,
            layer_past=layer_past,
            use_cache=use_cache
        )

        output = self.dense(context_layer)

        outputs = (output, present)

        if output_attentions:
            outputs += (attention_probs,)

        return outputs  # output, present, attention_probs

class GEGLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.activation_fn = F.gelu

    def forward(self, x):
        # dim=-1 breaks in jit for pt<1.10
        x1, x2 = x.chunk(2, dim=(x.ndim - 1))
        return x1 * self.activation_fn(x2)

class GLU(torch.nn.Module):
    def __init__(self, hidden_size, inner_hidden_size=None,
                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
        super(GLU, self).__init__()
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        self.layer_id = layer_id
        self.activation_func = activation_func

        # Project to 4h.
        self.hidden_size = hidden_size
        if inner_hidden_size is None:
            inner_hidden_size = 4 * hidden_size
        self.inner_hidden_size = inner_hidden_size
        self.dense_h_to_4h = init_method(
            torch.nn.Linear,
            self.hidden_size,
            self.inner_hidden_size,
            bias=bias,
            dtype=params_dtype,
        )
        # Project back to h.
        self.dense_4h_to_h = init_method(
            torch.nn.Linear,
            self.inner_hidden_size,
            self.hidden_size,
            bias=bias,
            dtype=params_dtype,
        )

    def forward(self, hidden_states):
        """
        hidden_states: [seq_len, batch, hidden_size]
        """

        # [seq_len, batch, inner_hidden_size]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)

        intermediate_parallel = self.activation_func(intermediate_parallel)

        output = self.dense_4h_to_h(intermediate_parallel)

        return output

class GLMBlock(torch.nn.Module):
    def __init__(
            self,
            hidden_size,
            num_attention_heads,
            layernorm_epsilon,
            layer_id,
            inner_hidden_size=None,
            hidden_size_per_attention_head=None,
            layernorm=LayerNorm,
            use_bias=True,
            params_dtype=torch.float,
            num_layers=28,
            position_encoding_2d=True,
            empty_init=True
    ):
        super(GLMBlock, self).__init__()
        # Set output layer initialization if not provided.

        self.layer_id = layer_id

        # Layernorm on the input data.
        self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)

        self.position_encoding_2d = position_encoding_2d

        # Self attention.
        self.attention = SelfAttention(
            hidden_size,
            num_attention_heads,
            layer_id,
            hidden_size_per_attention_head=hidden_size_per_attention_head,
            bias=use_bias,
            params_dtype=params_dtype,
            position_encoding_2d=self.position_encoding_2d,
            empty_init=empty_init
        )

        # Layernorm on the attention output.
        self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)

        self.num_layers = num_layers

        # GLU
        self.mlp = GLU(
            hidden_size,
            inner_hidden_size=inner_hidden_size,
            bias=use_bias,
            layer_id=layer_id,
            params_dtype=params_dtype,
            empty_init=empty_init
        )

    def forward(
            self,
            hidden_states: torch.Tensor,
            position_ids,
            attention_mask: torch.Tensor,
            layer_id,
            layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            use_cache: bool = False,
            output_attentions: bool = False,
    ):
        """
        hidden_states: [seq_len, batch, hidden_size]
        attention_mask: [(1, 1), seq_len, seq_len]
        """

        # Layer norm at the beginning of the transformer layer.
        # [seq_len, batch, hidden_size]
        attention_input = self.input_layernorm(hidden_states)

        # Self attention.
        attention_outputs = self.attention(
            attention_input,
            position_ids,
            attention_mask=attention_mask,
            layer_id=layer_id,
            layer_past=layer_past,
            use_cache=use_cache,
            output_attentions=output_attentions
        )

        attention_output = attention_outputs[0]

        outputs = attention_outputs[1:]

        # Residual connection.
        alpha = (2 * self.num_layers) ** 0.5
        hidden_states = attention_input * alpha + attention_output

        mlp_input = self.post_attention_layernorm(hidden_states)

        # MLP.
        mlp_output = self.mlp(mlp_input)

        # Second residual connection.
        output = mlp_input * alpha + mlp_output

        if use_cache:
            outputs = (output,) + outputs
        else:
            outputs = (output,) + outputs[1:]

        return outputs  # hidden_states, present, attentions

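# Note on the residual connections above: both branches scale the *normalized*
# input by alpha = (2 * num_layers) ** 0.5 before adding the sublayer output,
# a DeepNorm-style stabilization. For the default num_layers=28 this gives
# alpha = sqrt(56), roughly 7.48:
#
#     alpha = (2 * 28) ** 0.5
#     hidden_states = attention_input * alpha + attention_output
#     output = mlp_input * alpha + mlp_output
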
class ChatGLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    is_parallelizable = False
    supports_gradient_checkpointing = True
    config_class = ChatGLMConfig
    base_model_prefix = "transformer"
    _no_split_modules = ["GLMBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        return

    def get_masks(self, input_ids, device):
        batch_size, seq_length = input_ids.shape
        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
        attention_mask.tril_()
        for i, context_length in enumerate(context_lengths):
            attention_mask[i, :, :context_length] = 1
        attention_mask.unsqueeze_(1)
        attention_mask = (attention_mask < 0.5).bool()

        return attention_mask

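    # Illustrative sketch for get_masks: with seq_length == 5 and the bos token
    # at index 3 (context_length == 3), the boolean mask returned for that
    # sequence is True where attention is blocked:
    #
    #     [[False, False, False, True,  True ],
    #      [False, False, False, True,  True ],
    #      [False, False, False, True,  True ],
    #      [False, False, False, False, True ],
    #      [False, False, False, False, False]]
    #
    # i.e. the prompt (everything before bos) is visible to every position and
    # the part from bos onwards attends causally.
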
    def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
        batch_size, seq_length = input_ids.shape
        if use_gmasks is None:
            use_gmasks = [False] * batch_size
        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
        if self.position_encoding_2d:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
            for i, context_length in enumerate(context_lengths):
                position_ids[i, context_length:] = mask_positions[i]
            block_position_ids = [torch.cat((
                torch.zeros(context_length, dtype=torch.long, device=device),
                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
            )) for context_length in context_lengths]
            block_position_ids = torch.stack(block_position_ids, dim=0)
            position_ids = torch.stack((position_ids, block_position_ids), dim=1)
        else:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
            for i, context_length in enumerate(context_lengths):
                if not use_gmasks[i]:
                    position_ids[i, context_length:] = mask_positions[i]

        return position_ids

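    # Illustrative sketch for the 2D position ids: for one sequence of length 6
    # with context_length == 3 and its mask token at index 2, position_ids[0] is
    #
    #     [[0, 1, 2, 2, 2, 2],    # positions, frozen at the mask position after the context
    #      [0, 0, 0, 1, 2, 3]]    # block positions, counting inside the generated span
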
    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, ChatGLMModel):
            module.gradient_checkpointing = value


CHATGLM_6B_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
    usage and behavior.

    Parameters:
        config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CHATGLM_6B_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using [`ChatGLM6BTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range `[0, config.max_position_embeddings - 1]`.
            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert *input_ids* indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.",
    CHATGLM_6B_START_DOCSTRING,
)
class ChatGLMModel(ChatGLMPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well
    as a decoder, in which case a layer of cross-attention is added between
    the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,
    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the
    `is_decoder` argument of the configuration set to `True`.
    To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder`
    argument and `add_cross_attention` set to `True`; an
    `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    def __init__(self, config: ChatGLMConfig, empty_init=True):
        super().__init__(config)
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        # recording parameters
        self.max_sequence_length = config.max_sequence_length
        self.hidden_size = config.hidden_size
        self.params_dtype = torch.half
        self.num_attention_heads = config.num_attention_heads
        self.vocab_size = config.vocab_size
        self.num_layers = config.num_layers
        self.layernorm_epsilon = config.layernorm_epsilon
        self.inner_hidden_size = config.inner_hidden_size
        self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
        self.position_encoding_2d = config.position_encoding_2d
        self.pre_seq_len = config.pre_seq_len
        self.prefix_projection = config.prefix_projection

        self.word_embeddings = init_method(
            torch.nn.Embedding,
            num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
            dtype=self.params_dtype
        )
        self.gradient_checkpointing = False

        def get_layer(layer_id):
            return GLMBlock(
                self.hidden_size,
                self.num_attention_heads,
                self.layernorm_epsilon,
                layer_id,
                inner_hidden_size=self.inner_hidden_size,
                hidden_size_per_attention_head=self.hidden_size_per_attention_head,
                layernorm=LayerNorm,
                use_bias=True,
                params_dtype=self.params_dtype,
                position_encoding_2d=self.position_encoding_2d,
                empty_init=empty_init
            )

        self.layers = torch.nn.ModuleList(
            [get_layer(layer_id) for layer_id in range(self.num_layers)]
        )

        # Final layer norm before output.
        self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)

        if self.pre_seq_len is not None:
            for param in self.parameters():
                param.requires_grad = False
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
            self.prefix_encoder = PrefixEncoder(config)
            self.dropout = torch.nn.Dropout(0.1)

            # total_params = sum(p.numel() for p in self.parameters())
            # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
            # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params))

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.word_embeddings = new_embeddings

    def get_prompt(self, batch_size, device, dtype=torch.half):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.num_layers * 2,
            self.num_attention_heads,
            self.hidden_size // self.num_attention_heads
        )
        # seq_len, b, nh, hidden_size
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
        # past_key_values = [(v[0], v[1]) for v in past_key_values]
        return past_key_values

    @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if past_key_values is None:
            if self.pre_seq_len is not None:
                past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,
                                                  dtype=inputs_embeds.dtype)
            else:
                past_key_values = tuple([None] * len(self.layers))

            if attention_mask is None:
                attention_mask = self.get_masks(
                    input_ids,
                    device=input_ids.device
                )

            if position_ids is None:
                MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
                seqs = input_ids.tolist()

                mask_positions, use_gmasks = [], []
                for seq in seqs:
                    mask_token = gMASK if gMASK in seq else MASK
                    use_gmask = mask_token == gMASK
                    mask_positions.append(seq.index(mask_token))
                    use_gmasks.append(use_gmask)

                position_ids = self.get_position_ids(
                    input_ids,
                    mask_positions=mask_positions,
                    device=input_ids.device,
                    use_gmasks=use_gmasks
                )

        if self.pre_seq_len is not None and attention_mask is not None:
            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
                attention_mask.device)
            prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)

        # [seq_len, batch, hidden_size]
        hidden_states = inputs_embeds.transpose(0, 1)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        if attention_mask is None:
            attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        for i, layer in enumerate(self.layers):

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_past = past_key_values[i]

            if self.gradient_checkpointing and self.training:
                layer_ret = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    position_ids,
                    attention_mask,
                    torch.tensor(i),
                    layer_past,
                    use_cache,
                    output_attentions
                )
            else:
                layer_ret = layer(
                    hidden_states,
                    position_ids=position_ids,
                    attention_mask=attention_mask,
                    layer_id=torch.tensor(i),
                    layer_past=layer_past,
                    use_cache=use_cache,
                    output_attentions=output_attentions
                )

            hidden_states = layer_ret[0]

            if use_cache:
                presents = presents + (layer_ret[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],)

        # Final layer norm.
        hidden_states = self.final_layernorm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, empty_init=True):
        super().__init__(config)
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init

        # self.hidden_size = config.hidden_size
        # self.params_dtype = torch.half
        # self.vocab_size = config.vocab_size

        self.max_sequence_length = config.max_sequence_length

        self.position_encoding_2d = config.position_encoding_2d

        self.transformer = ChatGLMModel(config, empty_init=empty_init)

        self.lm_head = init_method(
            nn.Linear,
            config.hidden_size,
            config.vocab_size,
            bias=False,
            dtype=torch.half
        )

        self.config = config

        self.quantized = False

        if self.config.quantization_bit:
            self.quantize(self.config.quantization_bit, empty_init=True)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _update_model_kwargs_for_generation(
            self,
            outputs: ModelOutput,
            model_kwargs: Dict[str, Any],
            is_encoder_decoder: bool = False,
            standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # update past_key_values
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            if attention_mask is not None and attention_mask.dtype == torch.bool:
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
                new_attention_mask = attention_mask[:, :, -1:].clone()
                new_attention_mask[..., -1] = False
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, new_attention_mask], dim=2
                )

        # update position ids
        if "position_ids" in model_kwargs:
            position_ids = model_kwargs["position_ids"]
            new_position_id = position_ids[..., -1:].clone()
            new_position_id[:, 1, :] += 1
            model_kwargs["position_ids"] = torch.cat(
                [position_ids, new_position_id], dim=-1
            )

        return model_kwargs

    def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            past: Optional[torch.Tensor] = None,
            past_key_values: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            **kwargs
    ) -> dict:
        batch_size, seq_length = input_ids.shape
        MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
        seqs = input_ids.tolist()
        mask_positions, use_gmasks = [], []
        for seq in seqs:
            mask_token = gMASK if gMASK in seq else MASK
            use_gmask = mask_token == gMASK
            mask_positions.append(seq.index(mask_token))
            use_gmasks.append(use_gmask)

        # only last token for input_ids if past is not None
        if past is not None or past_key_values is not None:
            last_token = input_ids[:, -1].unsqueeze(-1)
            if attention_mask is not None and attention_mask.dtype == torch.bool:
                attention_mask = attention_mask[:, :, -1:]
            else:
                attention_mask = None
            if position_ids is not None:
                position_ids = position_ids[..., -1:]
            else:
                context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
                if self.position_encoding_2d:
                    position_ids = torch.tensor(
                        [[mask_position, seq_length - context_length] for mask_position, context_length in
                         zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
                else:
                    position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
                                                device=input_ids.device).unsqueeze(-1)

            if past is None:
                past = past_key_values
            return {
                "input_ids": last_token,
                "past_key_values": past,
                "position_ids": position_ids,
                "attention_mask": attention_mask
            }
        else:
            if attention_mask is not None and attention_mask.dtype != torch.bool:
                logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
                attention_mask = None
            if attention_mask is None:
                attention_mask = self.get_masks(
                    input_ids,
                    device=input_ids.device
                )
            if position_ids is None:
                position_ids = self.get_position_ids(
                    input_ids,
                    device=input_ids.device,
                    mask_positions=mask_positions,
                    use_gmasks=use_gmasks
                )

            return {
                "input_ids": input_ids,
                "past_key_values": past,
                "position_ids": position_ids,
                "attention_mask": attention_mask
            }

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous()

        loss = None
        if labels is not None:
            lm_logits = lm_logits.to(torch.float32)

            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
            past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        return tuple(
            (
                layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
                layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
            )
            for layer_past in past
        )

    def process_response(self, response):
        response = response.strip()
        response = response.replace("[[训练时间]]", "2023年")
        punkts = [
            [",", ","],
            ["!", "!"],
            [":", ":"],
            [";", ";"],
            ["\?", "?"],
        ]
        for item in punkts:
            response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
            response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
        return response

    @torch.no_grad()
    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
             do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        if not history:
            prompt = query
        else:
            prompt = ""
            for i, (old_query, response) in enumerate(history):
                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
        inputs = tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.device)
        outputs = self.generate(**inputs, **gen_kwargs)
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
        response = tokenizer.decode(outputs)
        response = self.process_response(response)
        history = history + [(query, response)]
        return response, history

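    # Minimal usage sketch for chat(), assuming the matching ChatGLM-6B tokenizer
    # and weights have already been loaded (see the loading sketch at the end of
    # this file) and moved to a device that supports the half-precision weights:
    #
    #     response, history = model.chat(tokenizer, "Hello", history=[])
    #     response, history = model.chat(tokenizer, "Tell me more", history=history)
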
    @torch.no_grad()
    def measure_latency(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
                        num_beams=1, do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        if not history:
            prompt = query
        else:
            prompt = ""
            for i, (old_query, response) in enumerate(history):
                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
        inputs = tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.device)
        print(inputs)
        latencies = []
        # Two untimed warm-up generations.
        for _ in range(2):
            _ = self.generate(**inputs, **gen_kwargs)
        # Ten timed generations.
        for _ in range(10):
            start_time = perf_counter()
            outputs = self.generate(**inputs, **gen_kwargs)
            latency = perf_counter() - start_time
            latencies.append(latency)
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
        response = tokenizer.decode(outputs)
        response = self.process_response(response)
        print(response)
        print(len(response))
        time_avg_ms = 1000 * np.mean(latencies)  # mean latency (ms)
        time_std_ms = 1000 * np.std(latencies)  # latency standard deviation (ms)
        time_p95_ms = 1000 * np.percentile(latencies, 95)  # 95th-percentile latency (ms)
        print(f"P95 latency (ms) - {time_p95_ms}; mean latency (ms) - {time_avg_ms:.2f} +/- {time_std_ms:.2f};")
        print(f"Mean latency per generated character (ms) - {time_avg_ms / len(response)}; ")
        return response, history

    @torch.no_grad()
    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
                    do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        if not history:
            prompt = query
        else:
            prompt = ""
            for i, (old_query, response) in enumerate(history):
                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
        inputs = tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.device)
        for outputs in self.stream_generate(**inputs, **gen_kwargs):
            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
            response = tokenizer.decode(outputs)
            response = self.process_response(response)
            new_history = history + [(query, response)]
            yield response, new_history

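    # Minimal usage sketch for stream_chat(), under the same loading assumptions
    # as the chat() example above; each iteration yields the response decoded so far:
    #
    #     for partial_response, new_history in model.stream_chat(tokenizer, "Hello", history=[]):
    #         print(partial_response)
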
    @torch.no_grad()
    def stream_generate(
            self,
            input_ids,
            generation_config: Optional[GenerationConfig] = None,
            logits_processor: Optional[LogitsProcessorList] = None,
            stopping_criteria: Optional[StoppingCriteriaList] = None,
            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
            **kwargs,
    ):
        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                logger.warn(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
                    UserWarning,
                )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        scores = None
        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                break
            yield input_ids

    def quantize(self, bits: int, empty_init=False, **kwargs):
        if bits == 0:
            return

        from .quantization import quantize

        if self.quantized:
            logger.info("Already quantized.")
            return self

        self.quantized = True

        self.config.quantization_bit = bits

        self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs)
        return self

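# End-to-end loading sketch (hedged: it presumes the THUDM/chatglm-6b checkpoint
# and its tokenizer are available, and that this file is the remote code shipped
# with that checkpoint, loaded through transformers' trust_remote_code path):
#
#     from transformers import AutoModel, AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
#     model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
#     model = model.eval()
#     response, history = model.chat(tokenizer, "Hello", history=[])
#     print(response)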