OpenDAS / ColossalAI · Commits

Commit 7377be7a: import chatglm
Authored Jul 07, 2023 by klhhhhh; committed by Hongxin Liu on Aug 15, 2023
Parent: c4928698
Showing 2 changed files with 1327 additions and 0 deletions:
- =2.0 (+134, -0)
- tests/kit/model_zoo/transformers/chatglm2-6b/modeling_chatglm.py (+1193, -0)
=2.0 (new file, mode 100644)
Defaulting to user installation because normal site-packages is not writeable
Collecting protobuf
Using cached protobuf-4.23.4-cp37-abi3-manylinux2014_x86_64.whl (304 kB)
Requirement already satisfied: transformers==4.30.2 in /home/lclk/.local/lib/python3.9/site-packages (4.30.2)
Collecting cpm_kernels
Using cached cpm_kernels-1.0.11-py3-none-any.whl (416 kB)
Requirement already satisfied: torch in /home/lclk/.local/lib/python3.9/site-packages (2.0.0+cu118)
Collecting gradio
Using cached gradio-3.36.0-py3-none-any.whl (19.8 MB)
Collecting mdtex2html
Using cached mdtex2html-1.2.0-py3-none-any.whl (13 kB)
Collecting sentencepiece
Using cached sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
Collecting accelerate
Using cached accelerate-0.20.3-py3-none-any.whl (227 kB)
Requirement already satisfied: pyyaml>=5.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (6.0)
Requirement already satisfied: regex!=2019.12.17 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (2023.6.3)
Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (0.15.1)
Requirement already satisfied: packaging>=20.0 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (23.1)
Requirement already satisfied: requests in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from transformers==4.30.2) (2.25.1)
Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (0.13.3)
Requirement already satisfied: safetensors>=0.3.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (0.3.1)
Requirement already satisfied: filelock in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (3.12.0)
Requirement already satisfied: numpy>=1.17 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (1.24.3)
Requirement already satisfied: tqdm>=4.27 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (4.65.0)
Requirement already satisfied: fsspec in /home/lclk/.local/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2) (2023.6.0)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/lclk/.local/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2) (4.6.3)
Requirement already satisfied: networkx in /home/lclk/.local/lib/python3.9/site-packages (from torch) (3.1)
Requirement already satisfied: sympy in /home/lclk/.local/lib/python3.9/site-packages (from torch) (1.12)
Requirement already satisfied: triton==2.0.0 in /home/lclk/.local/lib/python3.9/site-packages (from torch) (2.0.0)
Requirement already satisfied: jinja2 in /home/lclk/.local/lib/python3.9/site-packages (from torch) (3.1.2)
Requirement already satisfied: lit in /home/lclk/.local/lib/python3.9/site-packages (from triton==2.0.0->torch) (16.0.5.post0)
Requirement already satisfied: cmake in /home/lclk/.local/lib/python3.9/site-packages (from triton==2.0.0->torch) (3.26.3)
Collecting aiofiles
Using cached aiofiles-23.1.0-py3-none-any.whl (14 kB)
Collecting ffmpy
Using cached ffmpy-0.3.0.tar.gz (4.8 kB)
Requirement already satisfied: pillow in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (9.5.0)
Collecting pydub
Using cached pydub-0.25.1-py2.py3-none-any.whl (32 kB)
Requirement already satisfied: pandas in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.0.2)
Collecting python-multipart
Using cached python_multipart-0.0.6-py3-none-any.whl (45 kB)
Collecting semantic-version
Using cached semantic_version-2.10.0-py2.py3-none-any.whl (15 kB)
Collecting pydantic
Using cached pydantic-2.0.2-py3-none-any.whl (359 kB)
Collecting uvicorn>=0.14.0
Using cached uvicorn-0.22.0-py3-none-any.whl (58 kB)
Collecting mdit-py-plugins<=0.3.3
Using cached mdit_py_plugins-0.3.3-py3-none-any.whl (50 kB)
Requirement already satisfied: pygments>=2.12.0 in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.15.1)
Collecting httpx
Using cached httpx-0.24.1-py3-none-any.whl (75 kB)
Collecting orjson
Using cached orjson-3.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (136 kB)
Collecting fastapi
Using cached fastapi-0.99.1-py3-none-any.whl (58 kB)
Collecting altair>=4.2.0
Using cached altair-5.0.1-py3-none-any.whl (471 kB)
Collecting gradio-client>=0.2.7
Using cached gradio_client-0.2.7-py3-none-any.whl (288 kB)
Requirement already satisfied: aiohttp in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (3.8.4)
Requirement already satisfied: matplotlib in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (3.7.1)
Collecting websockets>=10.0
Using cached websockets-11.0.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (129 kB)
Requirement already satisfied: markdown-it-py[linkify]>=2.0.0 in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.2.0)
Requirement already satisfied: markupsafe in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.1.3)
Collecting toolz
Using cached toolz-0.12.0-py3-none-any.whl (55 kB)
Collecting jsonschema>=3.0
Using cached jsonschema-4.18.0-py3-none-any.whl (81 kB)
Collecting rpds-py>=0.7.1
Downloading rpds_py-0.8.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
Collecting referencing>=0.28.4
Using cached referencing-0.29.1-py3-none-any.whl (25 kB)
Collecting jsonschema-specifications>=2023.03.6
Using cached jsonschema_specifications-2023.6.1-py3-none-any.whl (17 kB)
Requirement already satisfied: attrs>=22.2.0 in /home/lclk/.local/lib/python3.9/site-packages (from jsonschema>=3.0->altair>=4.2.0->gradio) (23.1.0)
Requirement already satisfied: mdurl~=0.1 in /home/lclk/.local/lib/python3.9/site-packages (from markdown-it-py[linkify]>=2.0.0->gradio) (0.1.2)
Collecting linkify-it-py<3,>=1
Downloading linkify_it_py-2.0.2-py3-none-any.whl (19 kB)
Collecting uc-micro-py
Downloading uc_micro_py-1.0.2-py3-none-any.whl (6.2 kB)
Requirement already satisfied: pytz>=2020.1 in /home/lclk/.local/lib/python3.9/site-packages (from pandas->gradio) (2023.3)
Requirement already satisfied: tzdata>=2022.1 in /home/lclk/.local/lib/python3.9/site-packages (from pandas->gradio) (2023.3)
Requirement already satisfied: python-dateutil>=2.8.2 in /home/lclk/.local/lib/python3.9/site-packages (from pandas->gradio) (2.8.2)
Requirement already satisfied: six>=1.5 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from python-dateutil>=2.8.2->pandas->gradio) (1.16.0)
Requirement already satisfied: click>=7.0 in /home/lclk/.local/lib/python3.9/site-packages (from uvicorn>=0.14.0->gradio) (8.1.3)
Collecting h11>=0.8
Downloading h11-0.14.0-py3-none-any.whl (58 kB)
Collecting latex2mathml
Downloading latex2mathml-3.76.0-py3-none-any.whl (73 kB)
Collecting markdown
Downloading Markdown-3.4.3-py3-none-any.whl (93 kB)
Requirement already satisfied: psutil in /home/lclk/.local/lib/python3.9/site-packages (from accelerate) (5.9.5)
Requirement already satisfied: multidict<7.0,>=4.5 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (6.0.4)
Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (4.0.2)
Requirement already satisfied: aiosignal>=1.1.2 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (1.3.1)
Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (3.1.0)
Requirement already satisfied: frozenlist>=1.1.1 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (1.3.3)
Requirement already satisfied: yarl<2.0,>=1.0 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (1.9.2)
Requirement already satisfied: idna>=2.0 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from yarl<2.0,>=1.0->aiohttp->gradio) (2.10)
Collecting pydantic
Downloading pydantic-1.10.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
Collecting starlette<0.28.0,>=0.27.0
Downloading starlette-0.27.0-py3-none-any.whl (66 kB)
Collecting anyio<5,>=3.4.0
Downloading anyio-3.7.1-py3-none-any.whl (80 kB)
Collecting sniffio>=1.1
Downloading sniffio-1.3.0-py3-none-any.whl (10 kB)
Requirement already satisfied: exceptiongroup in /home/lclk/.local/lib/python3.9/site-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->gradio) (1.1.1)
Collecting httpcore<0.18.0,>=0.15.0
Downloading httpcore-0.17.3-py3-none-any.whl (74 kB)
Requirement already satisfied: certifi in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from httpx->gradio) (2021.5.30)
Requirement already satisfied: importlib-metadata>=4.4 in /home/lclk/.local/lib/python3.9/site-packages (from markdown->mdtex2html) (6.7.0)
Requirement already satisfied: zipp>=0.5 in /home/lclk/.local/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown->mdtex2html) (3.15.0)
Requirement already satisfied: contourpy>=1.0.1 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (1.1.0)
Requirement already satisfied: fonttools>=4.22.0 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (4.40.0)
Requirement already satisfied: pyparsing>=2.3.1 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (3.1.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (1.4.4)
Requirement already satisfied: importlib-resources>=3.2.0 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (5.12.0)
Requirement already satisfied: cycler>=0.10 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (0.11.0)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from requests->transformers==4.30.2) (1.26.6)
Requirement already satisfied: chardet<5,>=3.0.2 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from requests->transformers==4.30.2) (4.0.0)
Requirement already satisfied: mpmath>=0.19 in /home/lclk/.local/lib/python3.9/site-packages (from sympy->torch) (1.3.0)
Building wheels for collected packages: ffmpy
Building wheel for ffmpy (setup.py): started
Building wheel for ffmpy (setup.py): finished with status 'done'
Created wheel for ffmpy: filename=ffmpy-0.3.0-py3-none-any.whl size=4709 sha256=071cebb58ca6c6947fbc669e1d94509d6f53d1ed45d9d7fb9f060d1a342cfc18
Stored in directory: /home/lclk/.cache/pip/wheels/91/e2/96/f676aa08bfd789328c6576cd0f1fde4a3d686703bb0c247697
Successfully built ffmpy
Installing collected packages: sniffio, rpds-py, referencing, h11, anyio, uc-micro-py, jsonschema-specifications, httpcore, websockets, toolz, starlette, pydantic, linkify-it-py, jsonschema, httpx, uvicorn, semantic-version, python-multipart, pydub, orjson, mdit-py-plugins, markdown, latex2mathml, gradio-client, ffmpy, fastapi, altair, aiofiles, sentencepiece, protobuf, mdtex2html, gradio, cpm-kernels, accelerate
Successfully installed accelerate-0.20.3 aiofiles-23.1.0 altair-5.0.1 anyio-3.7.1 cpm-kernels-1.0.11 fastapi-0.99.1 ffmpy-0.3.0 gradio-3.36.0 gradio-client-0.2.7 h11-0.14.0 httpcore-0.17.3 httpx-0.24.1 jsonschema-4.18.0 jsonschema-specifications-2023.6.1 latex2mathml-3.76.0 linkify-it-py-2.0.2 markdown-3.4.3 mdit-py-plugins-0.3.3 mdtex2html-1.2.0 orjson-3.9.1 protobuf-4.23.4 pydantic-1.10.11 pydub-0.25.1 python-multipart-0.0.6 referencing-0.29.1 rpds-py-0.8.8 semantic-version-2.10.0 sentencepiece-0.1.99 sniffio-1.3.0 starlette-0.27.0 toolz-0.12.0 uc-micro-py-1.0.2 uvicorn-0.22.0 websockets-11.0.3
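As an aside (not part of this commit), the log above records the exact versions this environment was installed with. A minimal sketch for re-checking a local environment against a few of those pins, using only the standard library; the package/version pairs below are transcribed from the log:

from importlib.metadata import PackageNotFoundError, version

# Versions taken from the pip log above.
pinned = {
    "transformers": "4.30.2",
    "gradio": "3.36.0",
    "sentencepiece": "0.1.99",
    "accelerate": "0.20.3",
}

for pkg, want in pinned.items():
    try:
        got = version(pkg)
    except PackageNotFoundError:
        got = None
    status = "ok" if got == want else "MISMATCH"
    print(f"{pkg}: expected {want}, found {got} [{status}]")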
tests/kit/model_zoo/transformers/chatglm2-6b/modeling_chatglm.py (new file, mode 100644)
""" PyTorch ChatGLM model. """
import
math
import
copy
import
warnings
import
re
import
sys
import
torch
import
torch.utils.checkpoint
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch.nn
import
CrossEntropyLoss
,
LayerNorm
from
torch.nn.utils
import
skip_init
from
typing
import
Optional
,
Tuple
,
Union
,
List
,
Callable
,
Dict
,
Any
from
transformers.modeling_outputs
import
(
BaseModelOutputWithPast
,
CausalLMOutputWithPast
,
)
from
transformers.modeling_utils
import
PreTrainedModel
from
transformers.utils
import
logging
from
transformers.generation.logits_process
import
LogitsProcessor
from
transformers.generation.utils
import
LogitsProcessorList
,
StoppingCriteriaList
,
GenerationConfig
,
ModelOutput
from
.configuration_chatglm
import
ChatGLMConfig
# flags required to enable jit fusion kernels
if
sys
.
platform
!=
'darwin'
:
torch
.
_C
.
_jit_set_profiling_mode
(
False
)
torch
.
_C
.
_jit_set_profiling_executor
(
False
)
torch
.
_C
.
_jit_override_can_fuse_on_cpu
(
True
)
torch
.
_C
.
_jit_override_can_fuse_on_gpu
(
True
)
logger
=
logging
.
get_logger
(
__name__
)
_CHECKPOINT_FOR_DOC
=
"THUDM/ChatGLM2-6B"
_CONFIG_FOR_DOC
=
"ChatGLM6BConfig"
CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST
=
[
"THUDM/chatglm2-6b"
,
# See all ChatGLM models at https://huggingface.co/models?filter=chatglm
]
def
default_init
(
cls
,
*
args
,
**
kwargs
):
return
cls
(
*
args
,
**
kwargs
)
class InvalidScoreLogitsProcessor(LogitsProcessor):

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores
class PrefixEncoder(torch.nn.Module):
    """
    The torch.nn model to encode the prefix
    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    """

    def __init__(self, config: ChatGLMConfig):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
            self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(kv_size, config.hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.hidden_size, kv_size)
            )
        else:
            self.embedding = torch.nn.Embedding(config.pre_seq_len,
                                                config.num_layers * config.kv_channels * config.multi_query_group_num * 2)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values
def split_tensor_along_last_dim(
        tensor: torch.Tensor,
        num_partitions: int,
        contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
    """Split a tensor along its last dimension.
    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.
    Returns:
        A list of Tensors
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = tensor.size()[last_dim] // num_partitions
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list
class RotaryEmbedding(nn.Module):

    def __init__(self, dim, original_impl=False, device=None, dtype=None):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl

    def forward_impl(
            self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
    ):
        """Enhanced Transformer with Rotary Position Embedding.
        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=dtype, device=device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

    def forward(self, max_seq_len, offset=0):
        return self.forward_impl(
            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
        )
@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)
class RMSNorm(torch.nn.Module):

    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        return (self.weight * hidden_states).to(input_dtype)
class CoreAttention(torch.nn.Module):

    def __init__(self, config: ChatGLMConfig, layer_number):
        super(CoreAttention, self).__init__()

        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
        self.coeff = coeff

        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        pytorch_major_version = int(torch.__version__.split('.')[0])
        if pytorch_major_version >= 2:
            query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in
                                                   [query_layer, key_layer, value_layer]]
            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                                 is_causal=True)
            else:
                if attention_mask is not None:
                    attention_mask = ~attention_mask
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                                 attention_mask)
            context_layer = context_layer.permute(2, 0, 1, 3)
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.reshape(*new_context_layer_shape)
        else:
            # Raw attention scores

            # [b, np, sq, sk]
            output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

            # [sq, b, np, hn] -> [sq, b * np, hn]
            query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
            # [sk, b, np, hn] -> [sk, b * np, hn]
            key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

            # preallocating input tensor: [b * np, sq, sk]
            matmul_input_buffer = torch.empty(
                output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
                device=query_layer.device
            )

            # Raw attention scores. [b * np, sq, sk]
            matmul_result = torch.baddbmm(
                matmul_input_buffer,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / self.norm_factor),
            )

            # change view to [b, np, sq, sk]
            attention_scores = matmul_result.view(*output_size)

            # ===========================
            # Attention probs and dropout
            # ===========================

            # attention scores and attention mask [b, np, sq, sk]
            if self.attention_softmax_in_fp32:
                attention_scores = attention_scores.float()
            if self.coeff is not None:
                attention_scores = attention_scores * self.coeff
            if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
                attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
                                            device=attention_scores.device, dtype=torch.bool)
                attention_mask.tril_()
                attention_mask = ~attention_mask
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
            attention_probs = F.softmax(attention_scores, dim=-1)
            attention_probs = attention_probs.type_as(value_layer)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            attention_probs = self.attention_dropout(attention_probs)

            # =========================
            # Context layer. [sq, b, hp]
            # =========================

            # value_layer -> context layer.
            # [sk, b, np, hn] --> [b, np, sq, hn]

            # context layer shape: [b, np, sq, hn]
            output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
            # change view [sk, b * np, hn]
            value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
            # change view [b * np, sq, sk]
            attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
            # matmul: [b * np, sq, hn]
            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
            # change view [b, np, sq, hn]
            context_layer = context_layer.view(*output_size)
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
            # [sq, b, np, hn] --> [sq, b, hp]
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer
class SelfAttention(torch.nn.Module):
    """Parallel self-attention layer abstract class.
    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)

        self.projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
            self.num_multi_query_groups_per_partition = config.multi_query_group_num
            self.qkv_hidden_size = (
                self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
            )
        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
                                         device=device, **_config_to_kwargs(config)
                                         )

        self.core_attention = CoreAttention(config, self.layer_number)

        # Output.
        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
                               device=device, **_config_to_kwargs(config)
                               )

    def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
        if self.multi_query_attention:
            num_attention_heads = self.num_multi_query_groups_per_partition
        else:
            num_attention_heads = self.num_attention_heads_per_partition
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            num_attention_heads,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=device,
        )

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
    ):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(
                query_layer.size()[:-1] + (self.num_attention_heads_per_partition,
                                           self.hidden_size_per_attention_head)
            )
            key_layer = key_layer.view(
                key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition,
                                         self.hidden_size_per_attention_head)
            )
            value_layer = value_layer.view(
                value_layer.size()[:-1] + (self.num_multi_query_groups_per_partition,
                                           self.hidden_size_per_attention_head)
            )
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                               (self.num_attention_heads_per_partition,
                                3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # adjust key and value for inference
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        if use_cache:
            kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        if self.multi_query_attention:
            key_layer = key_layer.unsqueeze(-2)
            key_layer = key_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
            )
            key_layer = key_layer.contiguous().view(
                key_layer.size()[:2] + (self.num_attention_heads_per_partition,
                                        self.hidden_size_per_attention_head)
            )
            value_layer = value_layer.unsqueeze(-2)
            value_layer = value_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
            )
            value_layer = value_layer.contiguous().view(
                value_layer.size()[:2] + (self.num_attention_heads_per_partition,
                                          self.hidden_size_per_attention_head)
            )

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)

        return output, kv_cache
def _config_to_kwargs(args):
    common_kwargs = {
        "dtype": args.torch_dtype,
    }
    return common_kwargs
class MLP(torch.nn.Module):
    """MLP.
    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config: ChatGLMConfig, device=None):
        super(MLP, self).__init__()

        self.add_bias = config.add_bias_linear

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )

        def swiglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]

        self.activation_func = swiglu

        # Project back to h.
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output = self.dense_4h_to_h(intermediate_parallel)
        return output
class GLMBlock(torch.nn.Module):
    """A single transformer layer.
    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(GLMBlock, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm

        self.fp32_residual_connection = config.fp32_residual_connection

        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                             dtype=config.torch_dtype)

        # Self attention.
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                      dtype=config.torch_dtype)

        # MLP
        self.mlp = MLP(config, device=device)

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
    ):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, kv_cache = self.self_attention(
            layernorm_output,
            attention_mask,
            rotary_pos_emb,
            kv_cache=kv_cache,
            use_cache=use_cache
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
        layernorm_input = residual + layernorm_input

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
        output = residual + output

        return output, kv_cache
class GLMTransformer(torch.nn.Module):
    """Transformer class."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super(GLMTransformer, self).__init__()

        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        def build_layer(layer_number):
            return GLMBlock(config, layer_number, device=device)

        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])

        if self.post_layer_norm:
            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                 dtype=config.torch_dtype)

        self.gradient_checkpointing = False

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
            use_cache: Optional[bool] = True,
            output_hidden_states: Optional[bool] = False,
    ):
        if not kv_caches:
            kv_caches = [None for _ in range(self.num_layers)]
        presents = () if use_cache else None
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_self_attentions = None
        all_hidden_states = () if output_hidden_states else None
        for index in range(self.num_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer = self._get_layer(index)
            if self.gradient_checkpointing and self.training:
                layer_ret = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_caches[index],
                    use_cache
                )
            else:
                layer_ret = layer(
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_cache=kv_caches[index],
                    use_cache=use_cache
                )
            hidden_states, kv_cache = layer_ret
            if use_cache:
                presents = presents + (kv_cache,)

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states, presents, all_hidden_states, all_self_attentions
class ChatGLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    is_parallelizable = False
    supports_gradient_checkpointing = True
    config_class = ChatGLMConfig
    base_model_prefix = "transformer"
    _no_split_modules = ["GLMBlock"]

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        return

    def get_masks(self, input_ids, past_key_values, padding_mask=None):
        batch_size, seq_length = input_ids.shape
        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
        full_attention_mask.tril_()
        past_length = 0
        if past_key_values:
            past_length = past_key_values[0][0].shape[0]
        if past_length:
            full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
                                                        device=input_ids.device), full_attention_mask), dim=-1)
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
        if not past_length and padding_mask is not None:
            full_attention_mask -= padding_mask.unsqueeze(-1) - 1
        full_attention_mask = (full_attention_mask < 0.5).bool()
        full_attention_mask.unsqueeze_(1)
        return full_attention_mask

    def get_position_ids(self, input_ids, device):
        batch_size, seq_length = input_ids.shape
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
        return position_ids

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, GLMTransformer):
            module.gradient_checkpointing = value
class Embedding(torch.nn.Module):
    """Language model embeddings."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super(Embedding, self).__init__()

        self.hidden_size = config.hidden_size
        # Word embeddings (parallel).
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            self.hidden_size,
            dtype=config.torch_dtype,
            device=device
        )
        self.fp32_residual_connection = config.fp32_residual_connection

    def forward(self, input_ids):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings
        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()
        # If the input flag for fp32 residual connection is set, convert for float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()
        return embeddings
class ChatGLMModel(ChatGLMPreTrainedModel):

    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
        super().__init__(config)
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        init_kwargs = {}
        if device is not None:
            init_kwargs["device"] = device
        self.embedding = init_method(Embedding, config, **init_kwargs)
        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels

        # Rotary positional embeddings
        self.seq_length = config.seq_length
        rotary_dim = (
            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
        )

        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                              dtype=config.torch_dtype)
        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
        self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
                                        dtype=config.torch_dtype, **init_kwargs)
        self.pre_seq_len = config.pre_seq_len
        self.prefix_projection = config.prefix_projection
        if self.pre_seq_len is not None:
            for param in self.parameters():
                param.requires_grad = False
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
            self.prefix_encoder = PrefixEncoder(config)
            self.dropout = torch.nn.Dropout(0.1)

    def get_input_embeddings(self):
        return self.embedding.word_embeddings

    def get_prompt(self, batch_size, device, dtype=torch.half):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.num_layers * 2,
            self.multi_query_group_num,
            self.kv_channels
        )
        # seq_len, b, nh, hidden_size
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
        return past_key_values

    def forward(
            self,
            input_ids,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.BoolTensor] = None,
            full_attention_mask: Optional[torch.BoolTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length = input_ids.shape

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        if self.pre_seq_len is not None:
            if past_key_values is None:
                past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
                                                  dtype=inputs_embeds.dtype)
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
                                            attention_mask], dim=-1)

        if full_attention_mask is None:
            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

        # Rotary positional embeddings
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        if position_ids is not None:
            rotary_pos_emb = rotary_pos_emb[position_ids]
        else:
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        # Run encoder.
        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
        )

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def quantize(self, weight_bit_width: int):
        from .quantization import quantize
        quantize(self.encoder, weight_bit_width)
        return self
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):

    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.max_sequence_length = config.max_length
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
        self.config = config
        self.quantized = False

        if self.config.quantization_bit:
            self.quantize(self.config.quantization_bit, empty_init=True)

    def _update_model_kwargs_for_generation(
            self,
            outputs: ModelOutput,
            model_kwargs: Dict[str, Any],
            is_encoder_decoder: bool = False,
            standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # update past_key_values
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

        # update position ids
        if "position_ids" in model_kwargs:
            position_ids = model_kwargs["position_ids"]
            new_position_id = position_ids[..., -1:].clone()
            new_position_id += 1
            model_kwargs["position_ids"] = torch.cat(
                [position_ids, new_position_id], dim=-1
            )

        model_kwargs["is_first_forward"] = False
        return model_kwargs

    def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            past_key_values: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            is_first_forward: bool = True,
            **kwargs
    ) -> dict:
        # only last token for input_ids if past is not None
        if position_ids is None:
            position_ids = self.get_position_ids(input_ids, device=input_ids.device)
        if not is_first_forward:
            position_ids = position_ids[..., -1:]
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "return_last_logit": True
        }

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            return_last_logit: Optional[bool] = False,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        if return_last_logit:
            hidden_states = hidden_states[-1:]
        lm_logits = self.transformer.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()

        loss = None
        if labels is not None:
            lm_logits = lm_logits.to(torch.float32)

            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
            past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        return tuple(
            (
                layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
                layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
            )
            for layer_past in past
        )

    def process_response(self, response):
        response = response.strip()
        response = response.replace("[[训练时间]]", "2023年")
        return response

    def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
        prompt = tokenizer.build_prompt(query, history=history)
        inputs = tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.device)
        return inputs

    def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
        if history:
            prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
            input_ids = tokenizer.encode(prompt, add_special_tokens=False)
            input_ids = input_ids[1:]
            inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False)
        else:
            prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
            inputs = tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.device)
        return inputs

    @torch.no_grad()
    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
             do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        inputs = self.build_inputs(tokenizer, query, history=history)
        outputs = self.generate(**inputs, **gen_kwargs)
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
        response = tokenizer.decode(outputs)
        response = self.process_response(response)
        history = history + [(query, response)]
        return response, history

    @torch.no_grad()
    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None,
                    max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
                    return_past_key_values=False, **kwargs):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        if past_key_values is None and not return_past_key_values:
            inputs = self.build_inputs(tokenizer, query, history=history)
        else:
            inputs = self.build_stream_inputs(tokenizer, query, history=history)
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[0]
            if self.transformer.pre_seq_len is not None:
                past_length -= self.transformer.pre_seq_len
            inputs.position_ids += past_length
            attention_mask = inputs.attention_mask
            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
            inputs['attention_mask'] = attention_mask
        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
                                            return_past_key_values=return_past_key_values, **gen_kwargs):
            if return_past_key_values:
                outputs, past_key_values = outputs
            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
            response = tokenizer.decode(outputs)
            if response and response[-1] != "�":
                response = self.process_response(response)
                new_history = history + [(query, response)]
                if return_past_key_values:
                    yield response, new_history, past_key_values
                else:
                    yield response, new_history

    @torch.no_grad()
    def stream_generate(
            self,
            input_ids,
            generation_config: Optional[GenerationConfig] = None,
            logits_processor: Optional[LogitsProcessorList] = None,
            stopping_criteria: Optional[StoppingCriteriaList] = None,
            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
            return_past_key_values=False,
            **kwargs,
    ):
        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                logger.warn(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
                    UserWarning,
                )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        scores = None
        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
            if return_past_key_values:
                yield input_ids, outputs.past_key_values
            else:
                yield input_ids
            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                break

    def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
        if bits == 0:
            return

        from .quantization import quantize

        if self.quantized:
            logger.info("Already quantized.")
            return self

        self.quantized = True

        self.config.quantization_bit = bits

        self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
                                            **kwargs)
        return self
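For context, this modeling file follows the remote-code file distributed with the THUDM/chatglm2-6b checkpoint, so the usual quickstart applies. A minimal usage sketch, assuming the accompanying tokenizer and configuration files and a CUDA device (this mirrors the upstream usage, not anything added by this commit):

from transformers import AutoModel, AutoTokenizer

# trust_remote_code=True lets transformers load this modeling file from the checkpoint repo.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().cuda().eval()

# chat() (defined above on ChatGLMForConditionalGeneration) builds the prompt,
# runs generate(), strips the prompt tokens, and returns the decoded reply
# together with the updated history.
response, history = model.chat(tokenizer, "Hello", history=[])
print(response)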