OpenDAS / Lmdeploy / Commits

Commit fe851fbc authored Mar 24, 2024 by zhouxiang

    Add the supplementary files introduced in version 0.2.6

parent e2d98ddc
Showing 20 changed files with 8718 additions and 0 deletions (+8718, -0)
lmdeploy/pytorch/modeling/modeling_baichuan.py    +824   -0
lmdeploy/pytorch/modeling/modeling_internlm.py    +1171  -0
lmdeploy/pytorch/modeling/modeling_internlm2.py   +1372  -0
lmdeploy/pytorch/modeling/modeling_llama.py       +1297  -0
lmdeploy/pytorch/models/__init__.py               +5     -0
lmdeploy/pytorch/models/baichuan.py               +418   -0
lmdeploy/pytorch/models/chatglm2.py               +364   -0
lmdeploy/pytorch/models/deepseek.py               +139   -0
lmdeploy/pytorch/models/falcon.py                 +373   -0
lmdeploy/pytorch/models/functional.py             +363   -0
lmdeploy/pytorch/models/gemma.py                  +233   -0
lmdeploy/pytorch/models/internlm.py               +136   -0
lmdeploy/pytorch/models/internlm2.py              +231   -0
lmdeploy/pytorch/models/llama.py                  +456   -0
lmdeploy/pytorch/models/mistral.py                +145   -0
lmdeploy/pytorch/models/mixtral.py                +278   -0
lmdeploy/pytorch/models/module_map.py             +186   -0
lmdeploy/pytorch/models/patch.py                  +290   -0
lmdeploy/pytorch/models/peft.py                   +282   -0
lmdeploy/pytorch/models/q_modules.py              +155   -0
lmdeploy/pytorch/modeling/modeling_baichuan.py  0 → 100644
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                            CausalLMOutputWithPast)

from lmdeploy.pytorch.modeling.convert_to_qmodules import convert_to_qmodules
from lmdeploy.utils import get_logger

from .configuration_baichuan import BaiChuanConfig

logger = get_logger('lmdeploy')
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size,
                      dtype: torch.dtype,
                      device: torch.device,
                      past_key_values_length: int = 0):
    """Make causal mask used for bi-directional self-attention."""
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len),
                      torch.tensor(torch.finfo(dtype).min, device=device),
                      device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([
            torch.zeros(tgt_len,
                        past_key_values_length,
                        dtype=dtype,
                        device=device), mask
        ],
                         dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len,
                                         tgt_len + past_key_values_length)
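
A minimal sketch of what `_make_causal_mask` produces, for readers skimming the diff (the toy sizes and the device are invented for illustration): each query position may attend to all cached positions and to key positions at or before itself.

```python
import torch

# Illustrative only: 3 new tokens, 1 cached token.
mask = _make_causal_mask(torch.Size([1, 3]), torch.float32,
                         device=torch.device('cpu'),
                         past_key_values_length=1)
print(mask.shape)         # torch.Size([1, 1, 3, 4])
print((mask == 0).int())  # lower-triangular pattern, shifted by the cache length
```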
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor,
                 dtype: torch.dtype,
                 tgt_len: Optional[int] = None):
    """Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len,
    src_seq_len]`."""
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len,
                                                  src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool),
                                     torch.finfo(dtype).min)
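
A quick illustration of the helper above (the padding pattern is made up): a `[bsz, seq_len]` keep/pad mask becomes an additive mask whose padded positions hold the dtype minimum, so they vanish after softmax.

```python
import torch

pad_mask = torch.tensor([[1, 1, 0]])          # last position is padding
additive = _expand_mask(pad_mask, torch.float32)
print(additive.shape)                          # torch.Size([1, 1, 3, 3])
print(additive[0, 0, 0])                       # tensor([0., 0., -3.4028e+38])
```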
class RMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        """RMSNorm is equivalent to T5LayerNorm."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1,
                                                               keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance +
                                                    self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
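
As a sanity check (not part of the commit), the class above should agree numerically with the explicit RMSNorm formula `weight * x / sqrt(mean(x**2) + eps)` for a float32 input:

```python
import torch

x = torch.randn(2, 5, 16)
norm = RMSNorm(16, eps=1e-6)
expected = norm.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), expected, atol=1e-6)
```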
class RotaryEmbedding(torch.nn.Module):
    """RotaryEmbedding for Baichuan Model.

    This module generates sine and cosine positional encodings based on
    the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding".
    The purpose of this class is to provide positional embeddings to the
    input tensors. It utilizes a cache mechanism to store precomputed
    sine and cosine values for speedup.

    Args:
        dim (int): The dimensionality of the embeddings.
        max_position_embeddings (int, optional): The maximum number of
            position embeddings. Default is 2048.
        base (int, optional): The base value for the inverse frequency
            calculation. Default is 10000.
        device (str, optional): The device to run operations on.
            If None, defaults to the device of the model.
    """

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None):
        super().__init__()
        index = (torch.arange(0, dim, 2).float().to(device) / dim)
        inv_freq = 1.0 / (base**index)
        self.register_buffer('inv_freq', inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached,
                         device=self.inv_freq.device,
                         dtype=self.inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached',
                             emb.cos()[None, None, :, :],
                             persistent=False)
        self.register_buffer('sin_cached',
                             emb.sin()[None, None, :, :],
                             persistent=False)

    def forward(self, x, seq_len=None):
        """Forward propagation method for the embedding layer.

        Generates positional embeddings for the given input tensor.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in
        # `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached,
                             device=x.device,
                             dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # Different from paper, but it uses a different permutation in
            # order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer('cos_cached',
                                 emb.cos()[None, None, :, :],
                                 persistent=False)
            self.register_buffer('sin_cached',
                                 emb.sin()[None, None, :, :],
                                 persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary positional embeddings to query and key tensors.

    This function applies the cosine and sine positional embeddings on the
    input query (q) and key (k) tensors using element-wise multiplication and
    addition.
    """
    # The first two dimensions of cos and sin are always 1,
    # so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
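
A minimal sketch of how the two helpers above are wired together (all sizes here are invented for the example; the real shapes follow the comments in the code):

```python
import torch

bsz, n_heads, seq_len, head_dim = 1, 2, 4, 8
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)

rope = RotaryEmbedding(head_dim)
cos, sin = rope(q, seq_len=seq_len)         # each [1, 1, seq_len, head_dim]
position_ids = torch.arange(seq_len)[None]  # [1, seq_len]

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```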
class MLP(nn.Module):
    """MLP for Baichuan Model."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(
            self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper."""

    def __init__(self, config: BaiChuanConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                'hidden_size must be divisible by num_heads '
                f'(got `hidden_size`: {self.hidden_size}'
                f' and `num_heads`: {self.num_heads}).')
        self.W_pack = nn.Linear(self.hidden_size,
                                3 * self.hidden_size,
                                bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim,
                                self.hidden_size,
                                bias=False)
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Forward propagation method for the attention layer."""
        bsz, q_len, _ = hidden_states.size()

        proj = self.W_pack(hidden_states)
        proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(
            0).transpose(0, -2).squeeze(-2)
        # batch_size x source_len x hidden_size
        query_states = proj[0].view(bsz, q_len, self.num_heads,
                                    self.head_dim).transpose(1, 2)
        # batch_size x target_len x head_size
        key_states = proj[1].view(bsz, q_len, self.num_heads,
                                  self.head_dim).transpose(1, 2)
        # batch_size x source_len x hidden_size
        value_states = proj[2].view(bsz, q_len, self.num_heads,
                                    self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(query_states, key_states.transpose(
            2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                'Attention weights should be of size '
                f'{(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
                f' {attn_weights.size()}')

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    'Attention mask should be of size '
                    f'{(bsz, 1, q_len, kv_seq_len)},'
                    f' but is {attention_mask.size()}')
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights,
                torch.tensor(torch.finfo(attn_weights.dtype).min))

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1,
                                             dtype=torch.float32).to(
                                                 query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                '`attn_output` should be of size '
                f'{(bsz, self.num_heads, q_len, self.head_dim)}, but is'
                f' {attn_output.size()}')

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
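
The packed QKV projection (`W_pack`) and the `unflatten`/`transpose`/`squeeze` dance above can look opaque; the standalone sketch below (sizes invented for illustration) shows that it is equivalent to slicing the projection into three `hidden_size`-wide chunks:

```python
import torch

hidden_size, bsz, q_len = 16, 2, 5
W_pack = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=False)
x = torch.randn(bsz, q_len, hidden_size)

proj = W_pack(x)
packed = proj.unflatten(-1, (3, hidden_size)).unsqueeze(0).transpose(
    0, -2).squeeze(-2)                        # [3, bsz, q_len, hidden_size]
chunks = proj.split(hidden_size, dim=-1)      # plain slicing gives the same q/k/v
for i in range(3):
    assert torch.allclose(packed[i], chunks[i])
```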
class DecoderLayer(nn.Module):
    """Decoder layer for Baichuan Model."""

    def __init__(self, config: BaiChuanConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Attention(config=config)
        self.mlp = MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
                                                 torch.FloatTensor]]]:
        """ # noqa: E501
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, )

        if output_attentions:
            outputs += (self_attn_weights, )

        if use_cache:
            outputs += (present_key_value, )

        return outputs
class PreTrainedModel(PreTrainedModel):
    config_class = BaiChuanConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['DecoderLayer']
    _keys_to_ignore_on_load_unexpected = [r'decoder\.version']

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, Model):
            module.gradient_checkpointing = value
class Model(PreTrainedModel):
    """Transformer decoder consisting of *config.num_hidden_layers* layers.
    Each layer is a [`DecoderLayer`]

    Args:
        config: BaiChuanConfig
    """

    def __init__(self, config: BaiChuanConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [DecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder.
    # prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
                                        inputs_embeds,
                                        past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype,
                tgt_len=input_shape[-1]).to(inputs_embeds.device)
            combined_attention_mask = (expanded_attn_mask
                                       if combined_attention_mask is None else
                                       expanded_attn_mask +
                                       combined_attention_mask)

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (output_attentions if output_attentions
                             is not None else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = (use_cache
                     if use_cache is not None else self.config.use_cache)

        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both decoder_input_ids '
                             'and decoder_inputs_embeds at the same time')
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError('You have to specify either decoder_input_ids '
                             'or decoder_inputs_embeds')

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = (seq_length_with_past +
                                    past_key_values_length)

        if position_ids is None:
            device = (input_ids.device
                      if input_ids is not None else inputs_embeds.device)
            position_ids = torch.arange(past_key_values_length,
                                        seq_length + past_key_values_length,
                                        dtype=torch.long,
                                        device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past),
                                        dtype=torch.bool,
                                        device=inputs_embeds.device)
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds,
            past_key_values_length)

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    '`use_cache=True` is incompatible with gradient '
                    'checkpointing. Setting `use_cache=False`...')
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            past_key_value = past_key_values[
                idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):

                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (
                    layer_outputs[2 if output_attentions else 1], )

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v for v in
                [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
class BaiChuanForCausalLM(PreTrainedModel):
    """This class extends the `PreTrainedModel` to enable causal language
    modeling.

    It wraps the basic Baichuan model (`Model`) and includes a linear layer as
    a language model head (`lm_head`). The purpose is to predict token
    probabilities, given the previous tokens in the sequence.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model = Model(config)

        self.lm_head = nn.Linear(config.hidden_size,
                                 config.vocab_size,
                                 bias=False)

        # Initialize weights and apply final processing
        self.post_init()
        convert_to_qmodules(self)

    def get_input_embeddings(self):
        """Get the token embedding layer."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Set the token embedding layer."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Get the output embedding layer."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Set the output embedding layer."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Set the decoder model."""
        self.model = decoder

    def get_decoder(self):
        """Get the decoder model."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r""" # noqa: E501
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ModelForCausalLM

        >>> model = ModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you consciours? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
        ```"""
        output_attentions = (output_attentions if output_attentions
                             is not None else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # decoder outputs consists of
        # (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      past_key_values=None,
                                      attention_mask=None,
                                      inputs_embeds=None,
                                      **kwargs):
        """Prepare inputs for generating sequences using the model.

        Args:
            input_ids (torch.Tensor): Input token ids.
            past_key_values (list[torch.Tensor], optional): List of past key
                and value states.
            attention_mask (torch.Tensor, optional): Mask indicating which
                tokens should be attended to.
            inputs_embeds (torch.FloatTensor, optional): Optionally,
                the input embeddings instead of token ids.

        Returns:
            dict: Dictionary containing prepared inputs for model generation.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get('position_ids', None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed,
        # we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            model_inputs = {'input_ids': input_ids}

        model_inputs.update({
            'position_ids': position_ids,
            'past_key_values': past_key_values,
            'use_cache': kwargs.get('use_cache'),
            'attention_mask': attention_mask,
        })
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder cached past key-values during generation using beam search.

        This function reorders the cached past key-values according to the
        given indices. It's useful in beam search where the order of
        hypotheses can change from one time-step to another.
        """
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx)
                for past_state in layer_past), )
        return reordered_past
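
Before moving on to the next file, a brief illustration of the `position_ids` trick used in `prepare_inputs_for_generation` above (the attention mask is a made-up example with left padding): cumulative summation turns the keep/pad mask into positions, and padded slots get a harmless dummy value.

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])     # one left-padded row
position_ids = attention_mask.long().cumsum(-1) - 1  # [-1, -1, 0, 1, 2]
position_ids.masked_fill_(attention_mask == 0, 1)    # pads get a dummy position
print(position_ids)                                  # tensor([[1, 1, 0, 1, 2]])
```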
lmdeploy/pytorch/modeling/modeling_internlm.py  0 → 100644
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch InternLM model."""
import math
import queue
import threading
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.generation.streamers import BaseStreamer
from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                            CausalLMOutputWithPast,
                                            SequenceClassifierOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (add_start_docstrings,
                                add_start_docstrings_to_model_forward,
                                replace_return_docstrings)

from lmdeploy.pytorch.modeling.convert_to_qmodules import convert_to_qmodules
from lmdeploy.utils import get_logger

from .configuration_internlm import InternLMConfig

logger = get_logger('lmdeploy')

_CONFIG_FOR_DOC = 'InternLMConfig'
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size,
                      dtype: torch.dtype,
                      device: torch.device,
                      past_key_values_length: int = 0):
    """Make causal mask used for bi-directional self-attention."""
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len),
                      torch.tensor(torch.finfo(dtype).min, device=device),
                      device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([
            torch.zeros(tgt_len,
                        past_key_values_length,
                        dtype=dtype,
                        device=device), mask
        ],
                         dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len,
                                         tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor,
                 dtype: torch.dtype,
                 tgt_len: Optional[int] = None):
    """Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len,
    src_seq_len]`."""
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len,
                                                  src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool),
                                     torch.finfo(dtype).min)
class InternLMRMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        """InternLMRMSNorm is equivalent to T5LayerNorm."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1,
                                                               keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance +
                                                    self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
class InternLMRotaryEmbedding(torch.nn.Module):
    """RotaryEmbedding for InternLM Model.

    This module generates sine and cosine positional encodings based on
    the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding".
    The purpose of this class is to provide positional embeddings to the
    input tensors. It utilizes a cache mechanism to store precomputed
    sine and cosine values for speedup.

    Args:
        dim (int): The dimensionality of the embeddings.
        max_position_embeddings (int, optional): The maximum number of
            position embeddings. Default is 2048.
        base (int, optional): The base value for the inverse frequency
            calculation. Default is 10000.
        device (str, optional): The device to run operations on.
            If None, defaults to the device of the model.
    """

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None):
        super().__init__()
        index = (torch.arange(0, dim, 2).float().to(device) / dim)
        inv_freq = 1.0 / (base**index)
        self.register_buffer('inv_freq', inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached,
                         device=self.inv_freq.device,
                         dtype=self.inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached',
                             emb.cos()[None, None, :, :],
                             persistent=False)
        self.register_buffer('sin_cached',
                             emb.sin()[None, None, :, :],
                             persistent=False)

    def forward(self, x, seq_len=None):
        """Forward propagation method for the embedding layer.

        Generates positional embeddings for the given input tensor.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in
        # `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached,
                             device=x.device,
                             dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # Different from paper, but it uses a different permutation in
            # order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer('cos_cached',
                                 emb.cos()[None, None, :, :],
                                 persistent=False)
            self.register_buffer('sin_cached',
                                 emb.sin()[None, None, :, :],
                                 persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary positional embeddings to query and key tensors.

    This function applies the cosine and sine positional embeddings on the
    input query (q) and key (k) tensors using element-wise multiplication and
    addition.
    """
    # The first two dimensions of cos and sin are always 1, so we can
    # `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
class InternLMMLP(nn.Module):
    """MLP for InternLM Model."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(
            self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class InternLMAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper."""

    def __init__(self, config: InternLMConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                'hidden_size must be divisible by num_heads '
                f'(got `hidden_size`: {self.hidden_size}'
                f' and `num_heads`: {self.num_heads}).')
        self.q_proj = nn.Linear(self.hidden_size,
                                self.num_heads * self.head_dim,
                                bias=config.bias)
        self.k_proj = nn.Linear(self.hidden_size,
                                self.num_heads * self.head_dim,
                                bias=config.bias)
        self.v_proj = nn.Linear(self.hidden_size,
                                self.num_heads * self.head_dim,
                                bias=config.bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim,
                                self.hidden_size,
                                bias=config.bias)
        self.rotary_emb = InternLMRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Forward propagation method for the attention layer."""
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(
            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(
            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(
            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(query_states, key_states.transpose(
            2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                'Attention weights should be of size '
                f'{(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
                f' {attn_weights.size()}')

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    'Attention mask should be of size '
                    f'{(bsz, 1, q_len, kv_seq_len)}, '
                    f'but is {attention_mask.size()}')
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights,
                torch.tensor(torch.finfo(attn_weights.dtype).min))

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1,
                                             dtype=torch.float32).to(
                                                 query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                'attn_output` should be of size '
                f'`{(bsz, self.num_heads, q_len, self.head_dim)}, but is'
                f' {attn_output.size()}')

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
class InternLMDecoderLayer(nn.Module):
    """Decoder layer for InternLM Model."""

    def __init__(self, config: InternLMConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = InternLMAttention(config=config)
        self.mlp = InternLMMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = InternLMRMSNorm(config.hidden_size,
                                               eps=config.rms_norm_eps)
        self.post_attention_layernorm = InternLMRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
                                                 torch.FloatTensor]]]:
        """ # noqa: E501
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, )

        if output_attentions:
            outputs += (self_attn_weights, )

        if use_cache:
            outputs += (present_key_value, )

        return outputs
INTERNLM_START_DOCSTRING = r""" # noqa: E501
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`InternLMConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
    'The bare InternLM Model outputting raw hidden-states without any specific head on top.',  # noqa: E501
    INTERNLM_START_DOCSTRING,
)
class InternLMPreTrainedModel(PreTrainedModel):
    config_class = InternLMConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['InternLMDecoderLayer']
    _keys_to_ignore_on_load_unexpected = [r'decoder\.version']

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, InternLMModel):
            module.gradient_checkpointing = value
INTERNLM_INPUTS_DOCSTRING = r""" # noqa: E501
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    'The bare InternLM Model outputting raw hidden-states without any specific head on top.',  # noqa: E501
    INTERNLM_START_DOCSTRING,
)
class InternLMModel(InternLMPreTrainedModel):
    """Transformer decoder consisting of *config.num_hidden_layers* layers.
    Each layer is a [`InternLMDecoderLayer`]

    Args:
        config: InternLMConfig
    """

    _auto_class = 'AutoModel'

    def __init__(self, config: InternLMConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([
            InternLMDecoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = InternLMRMSNorm(config.hidden_size,
                                    eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder.
    # prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
                                        inputs_embeds,
                                        past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype,
                tgt_len=input_shape[-1]).to(inputs_embeds.device)
            combined_attention_mask = (expanded_attn_mask
                                       if combined_attention_mask is None else
                                       expanded_attn_mask +
                                       combined_attention_mask)

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions or self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                or self.config.output_hidden_states)
        use_cache = use_cache or self.config.use_cache

        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both decoder_input_ids '
                             'and decoder_inputs_embeds at the same time')
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError('You have to specify either decoder_input_ids '
                             'or decoder_inputs_embeds')

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = (seq_length_with_past +
                                    past_key_values_length)

        if position_ids is None:
            device = (input_ids.device
                      if input_ids is not None else inputs_embeds.device)
            position_ids = torch.arange(past_key_values_length,
                                        seq_length + past_key_values_length,
                                        dtype=torch.long,
                                        device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past),
                                        dtype=torch.bool,
                                        device=inputs_embeds.device)
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds,
            past_key_values_length)

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    '`use_cache=True` is incompatible with gradient '
                    'checkpointing. Setting `use_cache=False`...')
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            past_key_value = past_key_values[
                idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):

                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (
                    layer_outputs[2 if output_attentions else 1], )

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v for v in
                [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
class InternLMForCausalLM(InternLMPreTrainedModel):
    """This class extends the `InternLMPreTrainedModel` to enable causal
    language modeling.

    It wraps the basic InternLM model (`InternLMModel`) and includes a linear
    layer as a language model head (`lm_head`). The purpose is to predict token
    probabilities, given the previous tokens in the sequence.
    """
    _auto_class = 'AutoModelForCausalLM'

    def __init__(self, config):
        super().__init__(config)
        self.model = InternLMModel(config)
        self.lm_head = nn.Linear(config.hidden_size,
                                 config.vocab_size,
                                 bias=False)

        # Initialize weights and apply final processing
        self.post_init()
        convert_to_qmodules(self)

    def get_input_embeddings(self):
        """Get the token embedding layer."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Set the token embedding layer."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Get the output embedding layer."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Set the output embedding layer."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Set the decoder model."""
        self.model = decoder

    def get_decoder(self):
        """Get the decoder model."""
        return self.model

    @add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast,
                               config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r""" # noqa: E501
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, InternLMForCausalLM

        >>> model = InternLMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = (output_attentions
                             or self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # decoder outputs consists of
        # (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
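
The loss branch above uses the standard next-token shift: the logits at position t are scored against the label at position t + 1. A minimal sketch of that shift with toy tensors (all sizes below are made up and independent of any real checkpoint):

```python
import torch
from torch.nn import CrossEntropyLoss

# Toy shapes: batch=2, seq_len=5, vocab=11 (hypothetical, for illustration only).
logits = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))

# Same shift as in the forward above: position t predicts token t + 1.
shift_logits = logits[..., :-1, :].contiguous().view(-1, 11)
shift_labels = labels[..., 1:].contiguous().view(-1)

loss = CrossEntropyLoss()(shift_logits, shift_labels)
print(loss.item())  # scalar cross-entropy over 2 * 4 shifted positions
```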
    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      past_key_values=None,
                                      attention_mask=None,
                                      inputs_embeds=None,
                                      **kwargs):
        """Prepare inputs for generating sequences using the model.

        Args:
            input_ids (torch.Tensor): Input token ids.
            past_key_values (list[torch.Tensor], optional): List of past key
                and value states.
            attention_mask (torch.Tensor, optional): Mask indicating which
                tokens should be attended to.
            inputs_embeds (torch.FloatTensor, optional): Optionally,
                the input embeddings instead of token ids.

        Returns:
            dict: Dictionary containing prepared inputs for model generation.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get('position_ids', None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed,
        # we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            model_inputs = {'input_ids': input_ids}

        model_inputs.update({
            'position_ids': position_ids,
            'past_key_values': past_key_values,
            'use_cache': kwargs.get('use_cache'),
            'attention_mask': attention_mask,
        })
        return model_inputs
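
The fallback above that derives `position_ids` from `attention_mask` relies on a cumulative sum over the mask, so left-padded rows still start counting from 0 at their first real token. A small sketch of just that trick (toy mask, not tied to any tokenizer):

```python
import torch

# Two left-padded rows of length 6 (1 = real token, 0 = padding); toy data.
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1],
                               [0, 1, 1, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2, 3],
#         [1, 0, 1, 2, 3, 4]])
```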
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder cached past key-values during generation using beam search.

        This function reorders the cached past key-values according to the
        given indices. It's useful in beam search where the order of hypotheses
        can change from one time-step to another.
        """
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx)
                for past_state in layer_past), )
        return reordered_past

    def build_inputs(self,
                     tokenizer,
                     query: str,
                     history: List[Tuple[str, str]] = []):
        """Builds the input for the model."""
        prompt = ''
        for record in history:
            prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""  # noqa: E501
        prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
        return tokenizer([prompt], return_tensors='pt')
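
`build_inputs` renders the `<|User|>`/`<|Bot|>` turn format before tokenization. A quick sketch of the string it produces for a one-turn history (no tokenizer involved; the history content is invented):

```python
history = [('Hi', 'Hello! How can I help you?')]
query = 'What is LMDeploy?'

prompt = ''
for record in history:
    prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
print(prompt)
# <|User|>:Hi<eoh>
# <|Bot|>:Hello! How can I help you?<eoa>
# <|User|>:What is LMDeploy?<eoh>
# <|Bot|>:
```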
    @torch.no_grad()
    def chat(self,
             tokenizer,
             query: str,
             history: List[Tuple[str, str]] = [],
             streamer: Optional[BaseStreamer] = None,
             max_new_tokens: int = 1024,
             do_sample: bool = True,
             temperature: float = 0.8,
             top_p: float = 0.8,
             **kwargs):
        """Provides a chatting functionality for the model."""
        inputs = self.build_inputs(tokenizer, query, history)
        inputs = {
            k: v.to(self.device)
            for k, v in inputs.items() if torch.is_tensor(v)
        }
        outputs = self.generate(**inputs,
                                streamer=streamer,
                                max_new_tokens=max_new_tokens,
                                do_sample=do_sample,
                                temperature=temperature,
                                top_p=top_p,
                                **kwargs)
        outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):]
        response = tokenizer.decode(outputs, skip_special_tokens=True)
        response = response.split('<eoa>')[0]
        history = history + [(query, response)]
        return response, history
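
A hedged usage sketch of `chat()`. It assumes a locally available InternLM chat checkpoint whose remote code exposes the same `chat()` signature as the class above; the path is a placeholder, not a real location:

```python
# Hypothetical usage; `PATH_TO_MODEL` is a placeholder for a real checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

PATH_TO_MODEL = '/path/to/internlm-chat'  # placeholder
tokenizer = AutoTokenizer.from_pretrained(PATH_TO_MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    PATH_TO_MODEL, trust_remote_code=True).eval()

response, history = model.chat(tokenizer, 'Hello!', history=[])
print(response)
# The returned history can be fed back in for multi-turn conversations.
response, history = model.chat(tokenizer, 'Tell me more.', history=history)
```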
    @torch.no_grad()
    def stream_chat(self,
                    tokenizer,
                    query: str,
                    history: List[Tuple[str, str]] = [],
                    max_new_tokens: int = 1024,
                    do_sample: bool = True,
                    temperature: float = 0.8,
                    top_p: float = 0.8,
                    **kwargs):
        """Return a generator in format: (response, history) Eg.

        ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
        ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
        """
        response_queue = queue.Queue(maxsize=20)

        class ChatStreamer(BaseStreamer):

            def __init__(self, tokenizer) -> None:
                super().__init__()
                self.tokenizer = tokenizer
                self.queue = response_queue
                self.query = query
                self.history = history
                self.response = ''
                self.received_inputs = False
                self.queue.put(
                    (self.response, history + [(self.query, self.response)]))

            def put(self, value):
                if len(value.shape) > 1 and value.shape[0] > 1:
                    raise ValueError('ChatStreamer only supports batch size 1')
                elif len(value.shape) > 1:
                    value = value[0]

                if not self.received_inputs:
                    # The first received value is input_ids, ignore here
                    self.received_inputs = True
                    return

                token = self.tokenizer.decode([value[-1]],
                                              skip_special_tokens=True)
                if token.strip() != '<eoa>':
                    self.response = self.response + token
                    history = self.history + [(self.query, self.response)]
                    self.queue.put((self.response, history))

            def end(self):
                self.queue.put(None)

        def stream_producer():
            return self.chat(tokenizer=tokenizer,
                             query=query,
                             streamer=ChatStreamer(tokenizer=tokenizer),
                             history=history,
                             max_new_tokens=max_new_tokens,
                             do_sample=do_sample,
                             temperature=temperature,
                             top_p=top_p,
                             **kwargs)

        def consumer():
            producer = threading.Thread(target=stream_producer)
            producer.start()
            while True:
                res = response_queue.get()
                if res is None:
                    return
                yield res

        return consumer()
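
`stream_chat` runs generation on a background thread and returns a generator that yields `(response, history)` pairs as tokens arrive. A hedged consumption sketch, assuming `model` and `tokenizer` were loaded as in the earlier placeholder example:

```python
# Hypothetical usage; `model` and `tokenizer` come from the previous sketch.
for response, history in model.stream_chat(tokenizer, 'Hello!', history=[]):
    # Each yield carries the partial response accumulated so far.
    print(response, end='\r')
print()
```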
@add_start_docstrings(
    """ # noqa: E501
    The InternLM Model transformer with a sequence classification head on top (linear layer).

    [`InternLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    INTERNLM_START_DOCSTRING,
)
class InternLMForSequenceClassification(InternLMPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r'lm_head.weight']

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = InternLMModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r""" # noqa: E501
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError(
                'Cannot handle batch sizes > 1 if no padding token is defined.'
            )
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.ne(
                    input_ids, self.config.pad_token_id).sum(-1) -
                                    1).to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = 'regression'
                elif self.num_labels > 1 and (labels.dtype == torch.long
                                              or labels.dtype == torch.int):
                    self.config.problem_type = 'single_label_classification'
                else:
                    self.config.problem_type = 'multi_label_classification'

            if self.config.problem_type == 'regression':
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == 'single_label_classification':
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels),
                                labels.view(-1))
            elif self.config.problem_type == 'multi_label_classification':
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits, ) + transformer_outputs[1:]
            return ((loss, ) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
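
Classification pools the logits at the last non-padding token of each row; with right padding, that index is recovered by counting non-pad tokens. A toy sketch of that indexing (hypothetical pad id 0, random logits):

```python
import torch

pad_token_id = 0  # hypothetical
input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 6, 8, 2]])
logits = torch.randn(2, 5, 3)  # (batch, seq_len, num_labels)

sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1
print(sequence_lengths)        # tensor([2, 4]) -> last real token per row
pooled_logits = logits[torch.arange(2), sequence_lengths]
print(pooled_logits.shape)     # torch.Size([2, 3])
```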
lmdeploy/pytorch/modeling/modeling_internlm2.py
0 → 100644
View file @
fe851fbc
# # Copyright (c) InternLM. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch InternLM2 model."""
import math
import queue
import threading
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from einops import rearrange
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                           CausalLMOutputWithPast,
                                           SequenceClassifierOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (add_start_docstrings,
                                add_start_docstrings_to_model_forward,
                                replace_return_docstrings)

from lmdeploy.pytorch.modeling.convert_to_qmodules import convert_to_qmodules
from lmdeploy.utils import get_logger

try:
    from transformers.generation.streamers import BaseStreamer
except:  # noqa # pylint: disable=bare-except
    BaseStreamer = None

from .configuration_internlm2 import InternLM2Config

logger = get_logger('lmdeploy')

_CONFIG_FOR_DOC = 'InternLM2Config'


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size,
                      dtype: torch.dtype,
                      device: torch.device,
                      past_key_values_length: int = 0):
    """Make causal mask used for bi-directional self-attention."""
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len),
                      torch.tensor(torch.finfo(dtype).min, device=device),
                      device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([
            torch.zeros(
                tgt_len, past_key_values_length, dtype=dtype, device=device),
            mask
        ],
                         dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len,
                                         tgt_len + past_key_values_length)
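
`_make_causal_mask` builds the additive mask: 0 on and below the diagonal (and over all cached positions), a large negative value above it. A tiny shape check reusing the helper defined above (toy sizes only):

```python
import torch

# 2 sequences of length 4, with 3 cached (past) tokens per sequence.
mask = _make_causal_mask(torch.Size([2, 4]),
                         torch.float32,
                         device=torch.device('cpu'),
                         past_key_values_length=3)
print(mask.shape)               # torch.Size([2, 1, 4, 7])
print((mask[0, 0] == 0).sum())  # 22 allowed positions: 3*4 cached + 10 causal
```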
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor,
                 dtype: torch.dtype,
                 tgt_len: Optional[int] = None):
    """Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len,
    src_seq_len]`."""
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len,
                                                  src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool),
                                     torch.finfo(dtype).min)


class InternLM2RMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        """InternLM2RMSNorm is equivalent to T5LayerNorm."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance +
                                                    self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
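
`InternLM2RMSNorm` rescales each hidden vector by the reciprocal of its root-mean-square and then applies a learned per-channel gain. A quick numeric check with the module above (toy hidden size, default unit gain):

```python
import torch

norm = InternLM2RMSNorm(hidden_size=8, eps=1e-6)  # toy hidden size
x = torch.randn(2, 3, 8) * 5.0
y = norm(x)

# With the default weight of ones, every normalized vector has RMS ~= 1.
rms = y.pow(2).mean(-1).sqrt()
print(rms)  # values close to 1.0 regardless of the input scale
```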
class InternLM2RotaryEmbedding(nn.Module):

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base
                          **(torch.arange(0, self.dim, 2).float().to(device) /
                             self.dim))
        self.register_buffer('inv_freq', inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(seq_len=max_position_embeddings,
                                device=self.inv_freq.device,
                                dtype=torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached,
                         device=device,
                         dtype=self.inv_freq.dtype)

        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached',
                             emb.cos().to(dtype),
                             persistent=False)
        self.register_buffer('sin_cached',
                             emb.sin().to(dtype),
                             persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len,
                                    device=x.device,
                                    dtype=torch.float32)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
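
The rotary module precomputes `cos`/`sin` tables of shape `(seq_len, dim)` and regrows the cache on demand when a longer sequence is requested. A small sketch of querying it with toy sizes:

```python
import torch

rope = InternLM2RotaryEmbedding(dim=16, max_position_embeddings=32)  # toy sizes
x = torch.randn(1, 2, 48, 16)   # (bs, heads, seq_len, head_dim); seq_len > cache
cos, sin = rope(x, seq_len=48)  # triggers _set_cos_sin_cache(48, ...) above
print(cos.shape, sin.shape)     # torch.Size([48, 16]) torch.Size([48, 16])
```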
class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
    """InternLM2RotaryEmbedding extended with linear scaling.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None,
                 scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached,
                         device=device,
                         dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor

        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached',
                             emb.cos().to(dtype),
                             persistent=False)
        self.register_buffer('sin_cached',
                             emb.sin().to(dtype),
                             persistent=False)


class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
    """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla.
    """

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None,
                 scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * ((self.scaling_factor * seq_len /
                                 self.max_position_embeddings) -
                                (self.scaling_factor - 1))**(self.dim /
                                                             (self.dim - 2))
            inv_freq = 1.0 / (
                base**(torch.arange(0, self.dim, 2).float().to(device) /
                       self.dim))
            self.register_buffer('inv_freq', inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached,
                         device=device,
                         dtype=self.inv_freq.dtype)

        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached',
                             emb.cos().to(dtype),
                             persistent=False)
        self.register_buffer('sin_cached',
                             emb.sin().to(dtype),
                             persistent=False)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    cos = cos[position_ids].unsqueeze(1)
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
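
`apply_rotary_pos_emb` rotates query/key channels pairwise using the cached tables; `cos[position_ids]` has shape `(bs, seq, dim)` and is broadcast across heads after `unsqueeze(1)`. A toy end-to-end application using the helpers above (sizes are invented):

```python
import torch

bs, heads, seq, dim = 1, 2, 6, 16  # toy sizes
rope = InternLM2RotaryEmbedding(dim=dim, max_position_embeddings=32)

q = torch.randn(bs, heads, seq, dim)
k = torch.randn(bs, heads, seq, dim)
position_ids = torch.arange(seq).unsqueeze(0)  # (bs, seq)

cos, sin = rope(q, seq_len=seq)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_rot.shape, k_rot.shape)  # unchanged: (1, 2, 6, 16)

# Each channel pair is rotated by a position-dependent angle, so norms are kept.
print(torch.allclose(q.norm(), q_rot.norm(), atol=1e-4))  # True
```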
class InternLM2MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.w1 = nn.Linear(self.hidden_size,
                            self.intermediate_size,
                            bias=False)
        self.w3 = nn.Linear(self.hidden_size,
                            self.intermediate_size,
                            bias=False)
        self.w2 = nn.Linear(self.intermediate_size,
                            self.hidden_size,
                            bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))

        return down_proj
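
The MLP is the gated form `w2(act(w1(x)) * w3(x))`. A toy forward pass through the module above, assuming a minimal stand-in config object (the real fields come from `InternLM2Config`; the sizes here are made up):

```python
import torch
from types import SimpleNamespace

# Hypothetical miniature config carrying only the fields InternLM2MLP reads.
cfg = SimpleNamespace(hidden_size=16, intermediate_size=40, hidden_act='silu')
mlp = InternLM2MLP(cfg)

x = torch.randn(2, 5, 16)
print(mlp(x).shape)  # torch.Size([2, 5, 16]) -- projected back to hidden_size
```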
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """This is the equivalent of torch.repeat_interleave(x, dim=1,
    repeats=n_rep).

    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
                                 head_dim)
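
`repeat_kv` implements the grouped-query trick: each key/value head is tiled `n_rep` times so the KV tensors line up with the query heads. A toy shape check:

```python
import torch

kv = torch.randn(2, 4, 7, 16)      # (batch, num_kv_heads=4, seq, head_dim)
expanded = repeat_kv(kv, n_rep=8)  # e.g. 32 query heads / 4 kv heads = 8
print(expanded.shape)              # torch.Size([2, 32, 7, 16])

# Heads 0..7 of the expanded tensor all come from kv head 0, and so on.
print(torch.equal(expanded[:, 0], expanded[:, 7]))  # True
```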
class InternLM2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper."""

    def __init__(self, config: InternLM2Config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
                f' and `num_heads`: {self.num_heads}).')

        self.wqkv = nn.Linear(
            self.hidden_size,
            (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
            bias=config.bias,
        )

        self.wo = nn.Linear(self.num_heads * self.head_dim,
                            self.hidden_size,
                            bias=config.bias)
        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = InternLM2RotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.config.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling['type']
            scaling_factor = self.config.rope_scaling['factor']
            if scaling_type == 'dynamic':
                self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    base=self.config.rope_theta,
                    scaling_factor=scaling_factor)
            elif scaling_type == 'linear':
                self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    base=self.config.rope_theta,
                    scaling_factor=scaling_factor)
            else:
                raise ValueError(
                    "Currently we only support rotary embedding's type being 'dynamic' or 'linear'."
                )
        return self.rotary_emb

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        if 'padding_mask' in kwargs:
            warnings.warn(
                'Passing `padding_mask` is deprecated and will be removed in v4.37. '
                'Please make sure use `attention_mask` instead.`')

        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.wqkv(hidden_states)

        qkv_states = rearrange(
            qkv_states,
            'b q (h gs d) -> b q h gs d',
            gs=2 + self.num_key_value_groups,
            d=self.head_dim,
        )

        query_states = qkv_states[..., :self.num_key_value_groups, :]
        query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
        key_states = qkv_states[..., -2, :]
        value_states = qkv_states[..., -1, :]

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(
            2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
                f' {attn_weights.size()}')

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1,
                                             dtype=torch.float32).to(
                                                 query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
                f' {attn_output.size()}')

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.wo(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
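
The standard attention path above packs Q, K and V into the single `wqkv` projection and unpacks it with `rearrange`: each KV head owns `num_key_value_groups` query slices plus one key slice and one value slice (`gs = 2 + num_key_value_groups`). A standalone sketch of that unpacking with toy sizes (not tied to any checkpoint):

```python
import torch
from einops import rearrange

# Toy GQA layout: 8 query heads, 2 kv heads -> 4 query groups per kv head.
bsz, q_len, head_dim = 2, 5, 16
num_heads, num_kv_heads = 8, 2
groups = num_heads // num_kv_heads  # 4

packed = torch.randn(bsz, q_len, (num_heads + 2 * num_kv_heads) * head_dim)
qkv = rearrange(packed, 'b q (h gs d) -> b q h gs d',
                gs=2 + groups, d=head_dim)

query = rearrange(qkv[..., :groups, :], 'b q h gs d -> b q (h gs) d')
key = qkv[..., -2, :]
value = qkv[..., -1, :]
print(query.shape, key.shape, value.shape)
# torch.Size([2, 5, 8, 16]) torch.Size([2, 5, 2, 16]) torch.Size([2, 5, 2, 16])
```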
class InternLM2FlashAttention2(InternLM2Attention):
    """InternLM2 flash attention module.

    This module inherits from `InternLM2Attention` as the weights of the module
    stays untouched. The only required change would be on the forward pass
    where it needs to correctly call the public API of flash attention and deal
    with padding tokens in case the input contains any of them.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        # InternLM2FlashAttention2 attention does not support output_attentions
        if 'padding_mask' in kwargs:
            warnings.warn(
                'Passing `padding_mask` is deprecated and will be removed in v4.37. '
                'Please make sure use `attention_mask` instead.`')

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop('padding_mask')

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.wqkv(hidden_states)

        qkv_states = rearrange(
            qkv_states,
            'b q (h gs d) -> b q h gs d',
            gs=self.num_heads + 2 * self.num_key_value_heads,
            d=self.head_dim,
            q=q_len,
        )

        query_states = qkv_states[..., :self.num_key_value_groups, :]
        query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
        key_states = qkv_states[..., -2, :]
        value_states = qkv_states[..., -1, :]

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (InternLM2RMSNorm handles it correctly)
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Handle the case where the model is quantized
            if hasattr(self.config, '_pre_quantization_dtype'):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f'The input hidden states seems to be silently casted in float32, this might be related to'
                f' the fact you have upcasted embedding or layer norm layers in float32. We will cast back '
                f'the input in {target_dtype}.')

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(query_states,
                                                    key_states,
                                                    value_states,
                                                    attention_mask,
                                                    q_len,
                                                    dropout=dropout_rate)

        attn_output = attn_output.reshape(bsz, q_len,
                                          self.hidden_size).contiguous()
        attn_output = self.wo(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
class InternLM2DecoderLayer(nn.Module):

    def __init__(self, config: InternLM2Config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.attention = (InternLM2Attention(config=config)
                          if not getattr(config, '_flash_attn_2_enabled',
                                         False) else
                          InternLM2FlashAttention2(config=config))
        self.feed_forward = InternLM2MLP(config)
        self.attention_norm = InternLM2RMSNorm(config.hidden_size,
                                               eps=config.rms_norm_eps)
        self.ffn_norm = InternLM2RMSNorm(config.hidden_size,
                                         eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
                                                 torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        if 'padding_mask' in kwargs:
            warnings.warn(
                'Passing `padding_mask` is deprecated and will be removed in v4.37. '
                'Please make sure use `attention_mask` instead.`')

        residual = hidden_states

        hidden_states = self.attention_norm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.ffn_norm(hidden_states)
        hidden_states = self.feed_forward(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, )

        if output_attentions:
            outputs += (self_attn_weights, )

        if use_cache:
            outputs += (present_key_value, )

        return outputs


InternLM2_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`InternLM2Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
    InternLM2_START_DOCSTRING,
)
class InternLM2PreTrainedModel(PreTrainedModel):
    config_class = InternLM2Config
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['InternLM2DecoderLayer']
    _skip_keys_device_placement = 'past_key_values'
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


InternLM2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
            when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
    InternLM2_START_DOCSTRING,
)
class InternLM2Model(InternLM2PreTrainedModel):
    """Transformer decoder consisting of *config.num_hidden_layers* layers.
    Each layer is a [`InternLM2DecoderLayer`]

    Args:
        config: InternLM2Config
    """

    _auto_class = 'AutoModel'

    def __init__(self, config: InternLM2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.tok_embeddings = nn.Embedding(config.vocab_size,
                                           config.hidden_size,
                                           self.padding_idx)
        self.layers = nn.ModuleList([
            InternLM2DecoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = InternLM2RMSNorm(config.hidden_size,
                                     eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.tok_embeddings

    def set_input_embeddings(self, value):
        self.tok_embeddings = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
                                        inputs_embeds,
                                        past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype,
                tgt_len=input_shape[-1]).to(inputs_embeds.device)
            combined_attention_mask = (expanded_attn_mask
                                       if combined_attention_mask is None else
                                       expanded_attn_mask +
                                       combined_attention_mask)

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = (use_cache
                     if use_cache is not None else self.config.use_cache)

        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                'You cannot specify both input_ids and inputs_embeds at the same time'
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError(
                'You have to specify either input_ids or inputs_embeds')

        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = (input_ids.device
                      if input_ids is not None else inputs_embeds.device)
            position_ids = torch.arange(past_key_values_length,
                                        seq_length + past_key_values_length,
                                        dtype=torch.long,
                                        device=device)
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.tok_embeddings(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past),
                                        dtype=torch.bool,
                                        device=inputs_embeds.device)
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds,
            past_key_values_length)

        # embed positions
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            past_key_value = past_key_values[
                idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):

                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (
                    layer_outputs[2 if output_attentions else 1], )

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v for v in
                [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
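
A hedged smoke test of the bare decoder: instantiate `InternLM2Model` from a deliberately tiny configuration and run one forward pass. It assumes the accompanying `configuration_internlm2.InternLM2Config` accepts these field names with sensible defaults for the rest; every size below is made up and far smaller than any released checkpoint:

```python
import torch

# Hypothetical miniature configuration; field names follow InternLM2Config.
tiny_cfg = InternLM2Config(vocab_size=128,
                           hidden_size=64,
                           intermediate_size=128,
                           num_hidden_layers=2,
                           num_attention_heads=4,
                           num_key_value_heads=2,
                           max_position_embeddings=64)
model = InternLM2Model(tiny_cfg).eval()

input_ids = torch.randint(0, 128, (1, 10))
with torch.no_grad():
    out = model(input_ids=input_ids)
print(out.last_hidden_state.shape)  # torch.Size([1, 10, 64])
```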
class InternLM2ForCausalLM(InternLM2PreTrainedModel):
    _auto_class = 'AutoModelForCausalLM'

    _tied_weights_keys = ['output.weight']

    def __init__(self, config):
        super().__init__(config)
        self.model = InternLM2Model(config)
        self.vocab_size = config.vocab_size
        self.output = nn.Linear(config.hidden_size,
                                config.vocab_size,
                                bias=False)

        # Initialize weights and apply final processing
        self.post_init()
        convert_to_qmodules(self)

    def get_input_embeddings(self):
        return self.model.tok_embeddings

    def set_input_embeddings(self, value):
        self.model.tok_embeddings = value

    def get_output_embeddings(self):
        return self.output

    def set_output_embeddings(self, new_embeddings):
        self.output = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast,
                               config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, InternLM2ForCausalLM

        >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.output(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      past_key_values=None,
                                      attention_mask=None,
                                      inputs_embeds=None,
                                      **kwargs):
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        position_ids = kwargs.get('position_ids', None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            model_inputs = {'input_ids': input_ids}

        model_inputs.update({
            'position_ids': position_ids,
            'past_key_values': past_key_values,
            'use_cache': kwargs.get('use_cache'),
            'attention_mask': attention_mask,
        })
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past), )
        return reordered_past

    def build_inputs(self,
                     tokenizer,
                     query: str,
                     history: List[Tuple[str, str]] = []):
        prompt = ''
        for record in history:
            prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
        prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
        return tokenizer([prompt], return_tensors='pt')

    @torch.no_grad()
    def chat(
        self,
        tokenizer,
        query: str,
        history: List[Tuple[str, str]] = [],
        streamer: Optional[BaseStreamer] = None,
        max_new_tokens: int = 1024,
        do_sample: bool = True,
        temperature: float = 0.8,
        top_p: float = 0.8,
        **kwargs,
    ):
        inputs = self.build_inputs(tokenizer, query, history)
        inputs = {
            k: v.to(self.device)
            for k, v in inputs.items() if torch.is_tensor(v)
        }
        outputs = self.generate(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            **kwargs,
        )
        outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):]
        response = tokenizer.decode(outputs, skip_special_tokens=True)
        response = response.split('<eoa>')[0]
        history = history + [(query, response)]
        return response, history
    @torch.no_grad()
    def stream_chat(
        self,
        tokenizer,
        query: str,
        history: List[Tuple[str, str]] = [],
        max_new_tokens: int = 1024,
        do_sample: bool = True,
        temperature: float = 0.8,
        top_p: float = 0.8,
        **kwargs,
    ):
        """Return a generator in format: (response, history) Eg.

        ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
        ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
        """
        if BaseStreamer is None:
            raise ModuleNotFoundError(
                'The version of `transformers` is too low. Please make sure '
                'that you have installed `transformers>=4.28.0`.')

        response_queue = queue.Queue(maxsize=20)

        class ChatStreamer(BaseStreamer):

            def __init__(self, tokenizer) -> None:
                super().__init__()
                self.tokenizer = tokenizer
                self.queue = response_queue
                self.query = query
                self.history = history
                self.response = ''
                self.received_inputs = False
                self.queue.put(
                    (self.response, history + [(self.query, self.response)]))

            def put(self, value):
                if len(value.shape) > 1 and value.shape[0] > 1:
                    raise ValueError('ChatStreamer only supports batch size 1')
                elif len(value.shape) > 1:
                    value = value[0]

                if not self.received_inputs:
                    # The first received value is input_ids, ignore here
                    self.received_inputs = True
                    return

                token = self.tokenizer.decode([value[-1]],
                                              skip_special_tokens=True)
                if token.strip() != '<eoa>':
                    self.response = self.response + token
                    history = self.history + [(self.query, self.response)]
                    self.queue.put((self.response, history))

            def end(self):
                self.queue.put(None)

        def stream_producer():
            return self.chat(
                tokenizer=tokenizer,
                query=query,
                streamer=ChatStreamer(tokenizer=tokenizer),
                history=history,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                **kwargs,
            )

        def consumer():
            producer = threading.Thread(target=stream_producer)
            producer.start()
            while True:
                res = response_queue.get()
                if res is None:
                    return
                yield res

        return consumer()
@add_start_docstrings(
    """
    The InternLM2 Model transformer with a sequence classification head on top (linear layer).

    [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification,
    as other causal models (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    InternLM2_START_DOCSTRING,
)
class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = InternLM2Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.tok_embeddings

    def set_input_embeddings(self, value):
        self.model.tok_embeddings = value

    @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError(
                'Cannot handle batch sizes > 1 if no padding token is defined.')
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.eq(
                    input_ids, self.config.pad_token_id).int().argmax(-1) -
                                    1).to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device),
                               sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = 'regression'
                elif self.num_labels > 1 and (labels.dtype == torch.long
                                              or labels.dtype == torch.int):
                    self.config.problem_type = 'single_label_classification'
                else:
                    self.config.problem_type = 'multi_label_classification'

            if self.config.problem_type == 'regression':
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == 'single_label_classification':
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels),
                                labels.view(-1))
            elif self.config.problem_type == 'multi_label_classification':
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits, ) + transformer_outputs[1:]
            return ((loss, ) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
lmdeploy/pytorch/modeling/modeling_llama.py 0 → 100644

# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                            CausalLMOutputWithPast,
                                            SequenceClassifierOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.utils import (add_start_docstrings,
                                add_start_docstrings_to_model_forward,
                                replace_return_docstrings)

from lmdeploy.pytorch.modeling.convert_to_qmodules import convert_to_qmodules
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')

_CONFIG_FOR_DOC = 'LlamaConfig'


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size,
                      dtype: torch.dtype,
                      device: torch.device,
                      past_key_values_length: int = 0):
    """Make causal mask used for bi-directional self-attention."""
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len),
                      torch.finfo(dtype).min,
                      device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([
            torch.zeros(
                tgt_len, past_key_values_length, dtype=dtype, device=device),
            mask
        ],
                         dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len,
                                         tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor,
                 dtype: torch.dtype,
                 tgt_len: Optional[int] = None):
    """Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len,
    src_seq_len]`."""
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len,
                                                  src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool),
                                     torch.finfo(dtype).min)
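# Illustrative note (not part of the original file): for a batch of 2 prompts
# of length 4 with no cache, `_make_causal_mask((2, 4), torch.float16, 'cpu')`
# returns a [2, 1, 4, 4] tensor whose entries above the diagonal are the dtype
# minimum, and `_expand_mask` turns a [2, 4] padding mask into the same
# [2, 1, 4, 4] layout so the two masks can simply be added together.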
class LlamaRMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        """LlamaRMSNorm is equivalent to T5LayerNorm."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance +
                                                    self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
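# Illustrative note (not part of the original file): the forward above computes
# weight * x / sqrt(mean(x**2, dim=-1) + eps) in float32 before casting back.
# For example, with unit weights and x = [3.0, 4.0], mean(x**2) = 12.5, so the
# normalized output is roughly [0.8485, 1.1314].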
class LlamaRotaryEmbedding(torch.nn.Module):
    """RotaryEmbedding for Llama Model.

    This module generates sine and cosine positional encodings based on
    the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding".
    The purpose of this class is to provide positional embeddings to the
    input tensors. It utilizes a cache mechanism to store precomputed
    sine and cosine values for speedup.

    Args:
        dim (int): The dimensionality of the embeddings.
        max_position_embeddings (int, optional): The maximum number of
            position embeddings. Default is 2048.
        base (int, optional): The base value for the inverse frequency
            calculation. Default is 10000.
        device (str, optional): The device to run operations on.
            If None, defaults to the device of the model.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000,
                 device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base
                          **(torch.arange(0, self.dim, 2).float().to(device) /
                             self.dim))
        self.register_buffer('inv_freq', inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(seq_len=max_position_embeddings,
                                device=self.inv_freq.device,
                                dtype=torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Sets the cached sine and cosine values for the specified sequence
        length.

        Args:
            seq_len (int): The sequence length for which to set the cache.
            device (str): The device to use for computation.
            dtype (torch.dtype): The data type to be used for tensors.
        """
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached,
                         device=device,
                         dtype=self.inv_freq.dtype)

        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached',
                             emb.cos()[None, None, :, :].to(dtype),
                             persistent=False)
        self.register_buffer('sin_cached',
                             emb.sin()[None, None, :, :].to(dtype),
                             persistent=False)

    def forward(self, x, seq_len=None):
        """Forward propagation method for the embedding layer. Generates
        positional embeddings for the given input tensor.

        If the sequence length is larger than the cache, it resets the cache.

        Args:
            x (torch.Tensor): Input tensor of shape
                [batch_size, num_attention_heads, seq_len, head_size].
            seq_len (int, optional): Sequence length. If None, it is obtained
                from `x`.

        Returns:
            tuple: Tuple containing cosine and sine positional embeddings.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len,
                                    device=x.device,
                                    dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """This class extends the `LlamaRotaryEmbedding` with linear scaling.

    It provides a mechanism for adjusting the scale of the positional
    embeddings by dividing the tensor generated by the range of sequence length
    with a scaling factor. This is useful when dealing with sequences of
    varying lengths.

    Credits to Reddit User /u/kaiokendev for this extension.

    Args:
        dim (int): The dimensionality of the embeddings.
        max_position_embeddings (int, optional): The maximum number of
            position embeddings. Default is 2048.
        base (int, optional): The base value for the inverse frequency
            calculation. Default is 10000.
        device (str, optional): The device to run operations on. If None,
            defaults to the device of the model.
        scaling_factor (float, optional): Scaling factor used in adjusting
            the scale of positional embeddings. Default is 1.0.
    """

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None,
                 scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Sets the cached sine and cosine values for the specified sequence
        length.

        Args:
            seq_len (int): The sequence length for which to set the cache.
            device (str): The device to use for computation.
            dtype (torch.dtype): The data type to use for tensors.
        """
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached,
                         device=device,
                         dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor

        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached',
                             emb.cos()[None, None, :, :].to(dtype),
                             persistent=False)
        self.register_buffer('sin_cached',
                             emb.sin()[None, None, :, :].to(dtype),
                             persistent=False)


class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None,
                 scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * ((self.scaling_factor * seq_len /
                                 self.max_position_embeddings) -
                                (self.scaling_factor - 1))**(self.dim /
                                                             (self.dim - 2))
            inv_freq = 1.0 / (
                base**(torch.arange(0, self.dim, 2).float().to(device) /
                       self.dim))
            self.register_buffer('inv_freq', inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached,
                         device=device,
                         dtype=self.inv_freq.dtype)

        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached',
                             emb.cos()[None, None, :, :].to(dtype),
                             persistent=False)
        self.register_buffer('sin_cached',
                             emb.sin()[None, None, :, :].to(dtype),
                             persistent=False)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary positional embeddings to query and key tensors.

    This function applies the cosine and sine positional embeddings on the
    input query (q) and key (k) tensors using element-wise multiplication and
    addition.
    """
    # The first two dimensions of cos and sin are always 1,
    # so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
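# Illustrative note (not part of the original file): with head_dim=128 and a
# prompt of length 16, `LlamaRotaryEmbedding(128)(v, seq_len=16)` returns cos
# and sin caches of shape [1, 1, 16, 128]; `apply_rotary_pos_emb` then rotates
# the [bs, num_heads, 16, 128] query/key tensors in place of additive position
# embeddings, leaving their shapes unchanged.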
class LlamaMLP(nn.Module):
    """MLP for Llama Model."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size,
                                   self.intermediate_size,
                                   bias=False)
        self.up_proj = nn.Linear(self.hidden_size,
                                 self.intermediate_size,
                                 bias=False)
        self.down_proj = nn.Linear(self.intermediate_size,
                                   self.hidden_size,
                                   bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        if self.config.pretraining_tp > 1:
            slice = self.intermediate_size // self.config.pretraining_tp
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)

            gate_proj = torch.cat([
                F.linear(x, gate_proj_slices[i])
                for i in range(self.config.pretraining_tp)
            ],
                                  dim=-1)
            up_proj = torch.cat([
                F.linear(x, up_proj_slices[i])
                for i in range(self.config.pretraining_tp)
            ],
                                dim=-1)

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(
                slice, dim=2)
            down_proj = [
                F.linear(intermediate_states[i], down_proj_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            down_proj = sum(down_proj)
        else:
            down_proj = self.down_proj(
                self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """This is the equivalent of torch.repeat_interleave(x, dim=1,
    repeats=n_rep).

    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
                                 head_dim)
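# Illustrative note (not part of the original file): for grouped-query
# attention with 32 query heads and 8 key/value heads, `repeat_kv` is called
# with n_rep = 32 // 8 = 4, turning a [bs, 8, seqlen, head_dim] key/value
# tensor into [bs, 32, seqlen, head_dim] so it lines up with the query heads.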
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError('hidden_size must be divisible by num_heads '
                             f'(got `hidden_size`: {self.hidden_size}'
                             f' and `num_heads`: {self.num_heads}).')
        self.q_proj = nn.Linear(self.hidden_size,
                                self.num_heads * self.head_dim,
                                bias=False)
        self.k_proj = nn.Linear(self.hidden_size,
                                self.num_key_value_heads * self.head_dim,
                                bias=False)
        self.v_proj = nn.Linear(self.hidden_size,
                                self.num_key_value_heads * self.head_dim,
                                bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim,
                                self.hidden_size,
                                bias=False)
        self._init_rope()

    def _init_rope(self):
        """Initialize the Rotary Embedding Module."""
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling['type']
            scaling_factor = self.config.rope_scaling['factor']
            if scaling_type == 'linear':
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == 'dynamic':
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f'Unknown RoPE scaling type {scaling_type}')

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Forward propagation method for the attention layer."""
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads *
                                 self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp,
                dim=0)
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [
                F.linear(hidden_states, query_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [
                F.linear(hidden_states, key_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [
                F.linear(hidden_states, value_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads,
                                         self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                         self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(
            2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                'Attention weights should be of size '
                f'{(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
                f' {attn_weights.size()}')

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError('Attention mask should be of size '
                                 f'{(bsz, 1, q_len, kv_seq_len)}, '
                                 f'but is {attention_mask.size()}')
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1,
                                             dtype=torch.float32).to(
                                                 query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                '`attn_output` should be of size '
                f'{(bsz, self.num_heads, q_len, self.head_dim)}, but is'
                f' {attn_output.size()}')

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(
                self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(
                self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([
                F.linear(attn_output[i], o_proj_slices[i])
                for i in range(self.config.pretraining_tp)
            ])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class LlamaDecoderLayer(nn.Module):
    """Decoder layer for Llama Model."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size,
                                                     eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
                                                 torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape
                `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask
                of size `(batch, 1, tgt_len, src_len)` where padding elements
                are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all
                attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are
                returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached
                past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, )

        if output_attentions:
            outputs += (self_attn_weights, )

        if use_cache:
            outputs += (present_key_value, )

        return outputs


LLAMA_START_DOCSTRING = r""" # noqa: E501
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`LlamaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    'The bare LLaMA Model outputting raw hidden-states without any specific head on top.',  # noqa: E501
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    config_class = LlamaConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['LlamaDecoderLayer']
    _skip_keys_device_placement = 'past_key_values'

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, LlamaModel):
            module.gradient_checkpointing = value


LLAMA_INPUTS_DOCSTRING = r""" # noqa: E501
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    'The bare LLaMA Model outputting raw hidden-states without any specific head on top.',  # noqa: E501
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """Transformer decoder consisting of *config.num_hidden_layers* layers.

    Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask # noqa
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
                                        inputs_embeds,
                                        past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype,
                tgt_len=input_shape[-1]).to(inputs_embeds.device)
            combined_attention_mask = (expanded_attn_mask
                                       if combined_attention_mask is None else
                                       expanded_attn_mask +
                                       combined_attention_mask)

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (output_attentions if output_attentions
                             is not None else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = (use_cache
                     if use_cache is not None else self.config.use_cache)

        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both decoder_input_ids'
                             'and decoder_inputs_embeds at the same time')
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError('You have to specify either decoder_input_ids'
                             'or decoder_inputs_embeds')

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = (seq_length_with_past +
                                    past_key_values_length)

        if position_ids is None:
            device = (input_ids.device
                      if input_ids is not None else inputs_embeds.device)
            position_ids = torch.arange(past_key_values_length,
                                        seq_length + past_key_values_length,
                                        dtype=torch.long,
                                        device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past),
                                        dtype=torch.bool,
                                        device=inputs_embeds.device)
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds,
            past_key_values_length)

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    '`use_cache=True` is incompatible with gradient'
                    ' checkpointing. Setting `use_cache=False`...')
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            past_key_value = past_key_values[
                idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):

                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, past_key_value,
                                      output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (
                    layer_outputs[2 if output_attentions else 1], )

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v for v in
                [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class LlamaForCausalLM(LlamaPreTrainedModel):
    """This class extends the `LlamaPreTrainedModel` to enable causal language
    modeling.

    It wraps the basic Llama model (`LlamaModel`) and includes a linear layer
    as a language model head (`lm_head`). The purpose is to predict token
    probabilities, given the previous tokens in the sequence.
    """
    _tied_weights_keys = ['lm_head.weight']

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size,
                                 config.vocab_size,
                                 bias=False)

        # Initialize weights and apply final processing
        self.post_init()
        convert_to_qmodules(self)

    def get_input_embeddings(self):
        """Get the token embedding layer."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Set the token embedding layer."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Get the output embedding layer."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Set the output embedding layer."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Set the decoder model."""
        self.model = decoder

    def get_decoder(self):
        """Get the decoder model."""
        return self.model

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast,
                               config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r""" # noqa: E501
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = (output_attentions if output_attentions
                             is not None else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) # noqa: E501
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(
                self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [
                F.linear(hidden_states, lm_head_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      past_key_values=None,
                                      attention_mask=None,
                                      inputs_embeds=None,
                                      **kwargs):
        """Prepare inputs for generating sequences using the model.

        Args:
            input_ids (torch.Tensor): Input token ids.
            past_key_values (list[torch.Tensor], optional): List of past key
                and value states.
            attention_mask (torch.Tensor, optional): Mask indicating which
                tokens should be attended to.
            inputs_embeds (torch.FloatTensor, optional): Optionally,
                the input embeddings instead of token ids.

        Returns:
            dict: Dictionary containing prepared inputs for model generation.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get('position_ids', None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them
        # in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            model_inputs = {'input_ids': input_ids}

        model_inputs.update({
            'position_ids': position_ids,
            'past_key_values': past_key_values,
            'use_cache': kwargs.get('use_cache'),
            'attention_mask': attention_mask,
        })
        return model_inputs
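    # Illustrative note (not part of the original file): for an attention mask
    # of [[0, 1, 1, 1]] the cumsum trick above yields position_ids
    # [[-1, 0, 1, 2]], which masked_fill_ turns into [[1, 0, 1, 2]]; once
    # `past_key_values` is populated only the final column is kept.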
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder cached past key-values during generation using beam search.

        This function reorders the cached past key-values according to the
        given indices. It's useful in beam search where the order of hypotheses
        can change from one time-step to another.
        """
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past), )
        return reordered_past


@add_start_docstrings(
    """ # noqa: E501
    The LLaMa Model transformer with a sequence classification head on top (linear layer).

    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassification(LlamaPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r""" # noqa: E501
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError(
                'Cannot handle batch sizes > 1 if no padding token is defined.')
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.eq(
                    input_ids, self.config.pad_token_id).long().argmax(-1) -
                                    1).to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device),
                               sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = 'regression'
                elif self.num_labels > 1 and (labels.dtype == torch.long
                                              or labels.dtype == torch.int):
                    self.config.problem_type = 'single_label_classification'
                else:
                    self.config.problem_type = 'multi_label_classification'

            if self.config.problem_type == 'regression':
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == 'single_label_classification':
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels),
                                labels.view(-1))
            elif self.config.problem_type == 'multi_label_classification':
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits, ) + transformer_outputs[1:]
            return ((loss, ) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
lmdeploy/pytorch/models/__init__.py 0 → 100644

# Copyright (c) OpenMMLab. All rights reserved.
from .patch import patch
from .q_modules import QLinear, QRMSNorm

__all__ = ['patch', 'QLinear', 'QRMSNorm']
lmdeploy/pytorch/models/baichuan.py 0 → 100644

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
from transformers.modeling_outputs import BaseModelOutputWithPast

from ..dist_utils import (colwise_parallelize_linear_fn,
                          rowwise_parallelize_linear_fn, try_to_local)
from ..kernels import apply_rotary_pos_emb
from ..kernels.alibi_pagedattention import alibi_paged_attention_fwd
from ..kernels.fill_kv_cache import fill_kv_cache
from ..kernels.pagedattention import paged_attention_fwd


class PatchedRMSNorm(nn.Module):
    """Rewrite RMSNorm."""

    def forward(self, hidden_states):
        """forward."""
        from ..kernels import rms_norm
        epsilon = getattr(self, 'epsilon', None)
        if epsilon is None:
            epsilon = getattr(self, 'variance_epsilon', 1e-10)
        ret = rms_norm(hidden_states, self.weight, epsilon)

        return ret


def _attention_partition_fn(mod_name: str, mod: nn.Module,
                            device_mesh: DeviceMesh):
    """A function for attention partition."""

    def __w_pack_linear_fn(mod: nn.Module):
        """fn for w pack linear."""
        for name, param in mod.named_parameters():
            param = param.unflatten(0, (3, -1))
            dist_tensor = distribute_tensor(param, device_mesh, [Shard(1)])
            dist_tensor = try_to_local(dist_tensor)
            dist_tensor = dist_tensor.flatten(0, 1)
            dist_param = torch.nn.Parameter(dist_tensor)
            mod.register_parameter(name, dist_param)

    def __w_pack_lora_linear_fn(mod: nn.Module):
        """fn for w pack lora linear."""
        mod._tp_mode = 'colwise'
        base_layer = mod.base_layer
        __w_pack_linear_fn(base_layer)

        for lora_a_mod in mod.lora_A.values():
            colwise_parallelize_linear_fn(lora_a_mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

        for lora_b_mod in mod.lora_B.values():
            __w_pack_linear_fn(lora_b_mod)

    if mod_name in ['W_pack']:
        from peft.tuners.lora import Linear as LoraLinear
        if isinstance(mod, LoraLinear):
            __w_pack_lora_linear_fn(mod)
        else:
            __w_pack_linear_fn(mod)
    elif mod_name in ['o_proj']:
        rowwise_parallelize_linear_fn(mod,
                                      device_mesh=device_mesh,
                                      to_local=True)


class Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper."""

    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        return _attention_partition_fn(mod_name, mod, device_mesh)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs[0])
        return outputs

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite of Attention.forward."""
        world_size = 1
        if dist.is_initialized():
            world_size = dist.get_world_size()
        return self._contiguous_batching_forward(
            hidden_states,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            world_size=world_size,
        )

    def _contiguous_batching_forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        world_size: int = 1,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite implementation of Attention.forward.

        Add continuous batching support. Add paged attention support. TP
        support.
        """
        assert not output_attentions
        context = self.context.context
        max_kv_seq_length = context.max_kv_seq_length
        kv_seq_length = context.kv_seq_length
        q_seq_length = context.q_seq_length
        q_start_loc = context.q_start_loc
        block_offsets = context.block_offsets
        max_q_seq_length = context.max_q_seq_length

        num_heads = self.num_heads // world_size
        num_kv_heads = self.num_heads // world_size
        head_dim = self.head_dim

        def _qkv_proj(hidden_states):
            """qkv proj."""
            proj = self.W_pack(hidden_states)
            return proj.chunk(3, -1)

        def _rotary_emb_fn(query_states, key_states, value_states):
            if hasattr(self, 'rotary_emb'):
                cos, sin = self.rotary_emb(value_states,
                                           seq_len=max_kv_seq_length)
                query_states, key_states = apply_rotary_pos_emb(
                    query_states, key_states, cos, sin, position_ids,
                    context.position_ids_1d)
            return query_states, key_states, value_states

        query_states, key_states, value_states = _qkv_proj(hidden_states)

        query_states = query_states.view(-1, num_heads, head_dim)
        key_states = key_states.view(-1, num_kv_heads, head_dim)
        value_states = value_states.view(-1, num_kv_heads, head_dim)

        query_states, key_states, value_states = _rotary_emb_fn(
            query_states, key_states, value_states)

        fill_kv_cache(
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            q_start_loc,
            q_seq_length,
            kv_seq_length=kv_seq_length,
            max_q_seq_length=max_q_seq_length,
            block_offsets=block_offsets,
        )

        attn_output = query_states
        paged_attention_fwd(
            query_states,
            past_key_value[0],
            past_key_value[1],
            attn_output,
            block_offsets,
            q_start_loc=q_start_loc,
            q_seqlens=q_seq_length,
            kv_seqlens=kv_seq_length,
            max_seqlen=max_q_seq_length,
        )
        hidden_size = num_heads * head_dim
        attn_output = attn_output.reshape(*hidden_states.shape[:-1],
                                          hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


class BaichuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper."""

    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        return _attention_partition_fn(mod_name, mod, device_mesh)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs[0])
        return outputs

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite of BaichuanAttention.forward."""
        world_size = 1
        if dist.is_initialized():
            world_size = dist.get_world_size()
        return self._contiguous_batching_forward(
            hidden_states,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            world_size=world_size,
        )

    def _contiguous_batching_forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        world_size: int = 1,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite implementation of BaichuanAttention.forward.

        Add continuous batching support. Add paged attention support. TP
        support.
        """
        assert not output_attentions
        context = self.context.context
        kv_seq_length = context.kv_seq_length
        q_seq_length = context.q_seq_length
        q_start_loc = context.q_start_loc
        block_offsets = context.block_offsets
        max_q_seq_length = context.max_q_seq_length

        num_heads = self.num_heads // world_size
        num_kv_heads = self.num_heads // world_size
        head_dim = self.head_dim

        def _qkv_proj(hidden_states):
            proj = self.W_pack(hidden_states)
            return proj.chunk(3, -1)

        query_states, key_states, value_states = _qkv_proj(hidden_states)

        query_states = query_states.view(-1, num_heads, head_dim)
        key_states = key_states.view(-1, num_kv_heads, head_dim)
        value_states = value_states.view(-1, num_kv_heads, head_dim)

        fill_kv_cache(
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            q_start_loc,
            q_seq_length,
            kv_seq_length=kv_seq_length,
            max_q_seq_length=max_q_seq_length,
            block_offsets=block_offsets,
        )

        attn_output = query_states

        num_heads_full = num_heads
        head_offset = 0
        if dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()
            num_heads_full = num_heads * world_size
            head_offset = num_heads * rank
        alibi_paged_attention_fwd(query_states,
                                  past_key_value[0],
                                  past_key_value[1],
                                  attn_output,
                                  block_offsets,
                                  b_start_loc=q_start_loc,
                                  b_seq_len=q_seq_length,
                                  b_kv_seq_len=kv_seq_length,
                                  max_input_len=max_q_seq_length,
                                  head_offset=head_offset,
                                  num_heads=num_heads_full)

        hidden_size = num_heads * head_dim
        attn_output = attn_output.reshape(*hidden_states.shape[:-1],
                                          hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


class BaichuanModel(nn.Module):

    def _continuous_batching_forward_7b(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Rewrite implementation of 7b BaichuanModel.forward."""
output_attentions
=
False
use_cache
=
True
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
# Attention mask is not necessary in continuous batching
attention_mask
=
None
hidden_states
=
inputs_embeds
# decoder layers
for
idx
,
decoder_layer
in
enumerate
(
self
.
layers
):
past_key_value
=
past_key_values
[
idx
]
layer_outputs
=
decoder_layer
(
hidden_states
,
attention_mask
=
attention_mask
,
position_ids
=
position_ids
,
past_key_value
=
past_key_value
,
output_attentions
=
output_attentions
,
use_cache
=
use_cache
,
)
hidden_states
=
layer_outputs
[
0
]
hidden_states
=
self
.
norm
(
hidden_states
)
return
BaseModelOutputWithPast
(
last_hidden_state
=
hidden_states
,
past_key_values
=
past_key_values
,
hidden_states
=
None
,
attentions
=
None
,
)
def
_continuous_batching_forward
(
self
,
input_ids
:
torch
.
LongTensor
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
past_key_values
:
Optional
[
List
[
torch
.
FloatTensor
]]
=
None
,
inputs_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
)
->
Union
[
Tuple
,
BaseModelOutputWithPast
]:
"""Rewrite implementation of BaichuanModel.forward.
Add continuous batching support. Add paged attention support. TP
support.
"""
use_cache
=
False
output_attentions
=
False
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
# Attention mask is not necessary in continuous batching
attention_mask
=
None
hidden_states
=
inputs_embeds
# decoder layers
for
idx
,
decoder_layer
in
enumerate
(
self
.
layers
):
past_key_value
=
past_key_values
[
idx
]
layer_outputs
=
decoder_layer
(
hidden_states
,
attention_mask
=
attention_mask
,
past_key_value
=
past_key_value
,
output_attentions
=
output_attentions
,
use_cache
=
use_cache
,
)
hidden_states
=
layer_outputs
[
0
]
hidden_states
=
self
.
norm
(
hidden_states
)
return
BaseModelOutputWithPast
(
last_hidden_state
=
hidden_states
,
past_key_values
=
past_key_values
,
hidden_states
=
None
,
attentions
=
None
,
)
def
forward
(
self
,
input_ids
:
torch
.
LongTensor
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
past_key_values
:
Optional
[
List
[
torch
.
FloatTensor
]]
=
None
,
inputs_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
use_cache
:
Optional
[
bool
]
=
False
,
output_attentions
:
Optional
[
bool
]
=
False
,
output_hidden_states
:
Optional
[
bool
]
=
False
,
return_dict
:
Optional
[
bool
]
=
True
,
):
"""Rewrite of BaichuanModel.forward."""
if
position_ids
is
not
None
:
return
self
.
_continuous_batching_forward_7b
(
input_ids
,
attention_mask
,
position_ids
,
past_key_values
,
inputs_embeds
,
)
else
:
return
self
.
_continuous_batching_forward
(
input_ids
,
attention_mask
,
past_key_values
,
inputs_embeds
,
)
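To make the fused-projection split above easier to follow, here is a minimal, self-contained sketch of the `_qkv_proj` pattern: a single packed linear layer produces Q, K and V in one matmul, which are then chunked and flattened to the `(num_tokens, num_heads, head_dim)` layout the paged-attention kernels expect. The sizes and the `w_pack` stand-in layer are illustrative only, not the real Baichuan configuration.

# Illustrative sketch (not part of the committed file): splitting a packed
# QKV projection into per-head query states, mirroring _qkv_proj above.
# hidden_size / num_heads / w_pack are toy stand-ins for the demo.
import torch
import torch.nn as nn

hidden_size, num_heads = 64, 4
head_dim = hidden_size // num_heads
w_pack = nn.Linear(hidden_size, 3 * hidden_size, bias=False)  # stand-in for W_pack

hidden_states = torch.randn(2, 5, hidden_size)        # (batch, seq, hidden)
q, k, v = w_pack(hidden_states).chunk(3, -1)          # each (batch, seq, hidden)
q = q.view(-1, num_heads, head_dim)                   # (batch*seq, heads, head_dim)
print(q.shape)                                        # torch.Size([10, 4, 16])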
lmdeploy/pytorch/models/chatglm2.py
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py # noqa: E501
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
from transformers.modeling_outputs import BaseModelOutputWithPast

from ..dist_utils import (colwise_parallelize_linear,
                          rowwise_parallelize_linear_fn, try_to_local)
from ..kernels import paged_attention_fwd
from .functional import fill_kv_cache


class PatchedRMSNorm(nn.Module):
    """Rewrite RMSNorm."""

    def forward(self, hidden_states):
        """forward."""
        # torch.nn.functional.normalize based implementation might lead
        # to wrong output
        from ..kernels import rms_norm
        ret = rms_norm(hidden_states.permute(1, 0, 2), self.weight, self.eps)
        return ret.permute(1, 0, 2)


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
    """Split a tensor along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor
        contiguous_split_chunks: If True, make each chunk contiguous
            in memory.

    Returns:
        A list of Tensors
    """
    tensor_list = tensor.chunk(num_partitions, dim=-1)
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list


def apply_rotary_pos_emb(x: torch.Tensor,
                         rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    sq, hn = x.size(0), x.size(-1)
    xslice = x[..., :hn // 2]
    rope_cache = rope_cache[:sq]
    xshaped = xslice.unflatten(-1, (-1, 2))
    rope_cache = rope_cache.unsqueeze(2)

    # inplace
    torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] -
            xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] +
            xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
        out=xshaped,
    )
    return x


class PatchedSelfAttention(nn.Module):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h] and returns output of
    the same size.
    """

    def _distribute_qkv_linear(self, mod: nn.Module, device_mesh: DeviceMesh):
        """distribute qkv linear."""
        sections = [
            self.num_attention_heads_per_partition *
            self.hidden_size_per_attention_head,
            self.num_multi_query_groups_per_partition *
            self.hidden_size_per_attention_head,
            self.num_multi_query_groups_per_partition *
            self.hidden_size_per_attention_head,
        ]
        for name, param in mod.named_parameters():
            splited_param = param.split(sections, dim=0)
            updated_param = []
            for p in splited_param:
                dist_tensor = distribute_tensor(p, device_mesh, [Shard(0)])
                dist_tensor = try_to_local(dist_tensor)
                updated_param.append(dist_tensor)
            param = torch.cat(updated_param)
            dist_param = torch.nn.Parameter(param)
            mod.register_parameter(name, dist_param)

    def _distribute_qkv_lora_linear(self, module: nn.Module,
                                    device_mesh: DeviceMesh):
        """distribute qkv lora linear."""
        to_local = True
        self._distribute_qkv_linear(
            module.base_layer,
            device_mesh=device_mesh,
        )

        for mod in module.lora_A.values():
            colwise_parallelize_linear(mod,
                                       device_mesh=device_mesh,
                                       to_local=to_local)
        for mod in module.lora_B.values():
            self._distribute_qkv_linear(
                mod,
                device_mesh=device_mesh,
            )
        module._tp_mode = 'colwise'

    def _distribute_partition_fn(self, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['query_key_value']:
            from peft.tuners.lora import Linear as LoraLinear
            if isinstance(mod, LoraLinear):
                self._distribute_qkv_lora_linear(mod, device_mesh)
            else:
                self._distribute_qkv_linear(mod, device_mesh)
        elif mod_name in ['dense']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs[0])
        return outputs

    def _contiguous_batching_forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        kv_cache: Optional[Tuple[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        world_size = 1
        if dist.is_initialized():
            world_size = dist.get_world_size()
        context = self.context.context
        max_q_seq_length = context.max_q_seq_length
        q_start_loc = context.q_start_loc
        q_seq_length = context.q_seq_length
        kv_seq_length = context.kv_seq_length
        block_offsets = context.block_offsets

        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition *
                    self.hidden_size_per_attention_head // world_size,
                    self.num_multi_query_groups_per_partition *
                    self.hidden_size_per_attention_head // world_size,
                    self.num_multi_query_groups_per_partition *
                    self.hidden_size_per_attention_head // world_size,
                ],
                dim=-1,
            )
            query_layer = query_layer.unflatten(
                -1, (-1, self.hidden_size_per_attention_head))
            key_layer = key_layer.unflatten(
                -1, (-1, self.hidden_size_per_attention_head))
            value_layer = value_layer.unflatten(
                -1, (-1, self.hidden_size_per_attention_head))
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_attention_heads_per_partition // world_size,
                3 * self.hidden_size_per_attention_head,
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer, key_layer,
             value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        # apply relative positional encoding (rotary embedding)
        query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
        key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # [b, sq, np, hn]
        query_layer, key_layer, value_layer = [
            k.transpose(0, 1) for k in [query_layer, key_layer, value_layer]
        ]

        # adjust key and value for inference
        cache_k, cache_v = kv_cache
        fill_kv_cache(
            key_layer[0],
            value_layer[0],
            cache_k,
            cache_v,
            q_start_loc,
            q_seq_length,
            kv_seq_length=kv_seq_length,
            max_q_seq_length=max_q_seq_length,
            block_offsets=block_offsets,
        )

        # ==================================
        # core attention computation
        # ==================================
        context_layer = query_layer
        paged_attention_fwd(query_layer,
                            cache_k,
                            cache_v,
                            context_layer,
                            block_offsets,
                            q_start_loc=q_start_loc,
                            q_seqlens=q_seq_length,
                            kv_seqlens=kv_seq_length,
                            max_seqlen=max_q_seq_length)
        context_layer = context_layer.transpose(1, 0).flatten(-2)

        # =================
        # Output. [sq, b, h]
        # =================
        output = self.dense(context_layer)

        return output, kv_cache

    def forward(
        self,
        hidden_states,
        attention_mask,
        rotary_pos_emb,
        kv_cache=None,
        use_cache=True,
        output_attentions=False,
    ):
        return self._contiguous_batching_forward(
            hidden_states,
            rotary_pos_emb,
            kv_cache,
        )


class MLP(nn.Module):

    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['dense_h_to_4h']:
            for name, param in mod.named_parameters():
                dist_tensor = distribute_tensor(param.unflatten(0, (2, -1)),
                                                device_mesh, [Shard(1)])
                dist_tensor = try_to_local(dist_tensor)
                dist_param = torch.nn.Parameter(dist_tensor.flatten(0, 1))
                mod.register_parameter(name, dist_param)
        elif mod_name in ['dense_4h_to_h']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs)
        return outputs


class PatchedChatGLMModel(nn.Module):

    def _contiguous_batching_forward(
            self,
            input_ids,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.BoolTensor] = None,
            full_attention_mask: Optional[torch.BoolTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor,
                                                  torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.Tensor] = None):
        output_hidden_states = False
        use_cache = True

        batch_size, seq_length = input_ids.shape

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        if self.pre_seq_len is not None:
            if past_key_values is None:
                past_key_values = self.get_prompt(batch_size=batch_size,
                                                  device=input_ids.device,
                                                  dtype=inputs_embeds.dtype)

        # Rotary positional embeddings
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        if position_ids is not None:
            context = self.context.context
            position_ids_1d = context.position_ids_1d
            rotary_pos_emb = rotary_pos_emb[position_ids_1d[None]]
        else:
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        # Run encoder.
        (hidden_states, presents, all_hidden_states,
         all_self_attentions) = self.encoder(
             inputs_embeds,
             full_attention_mask,
             rotary_pos_emb=rotary_pos_emb,
             kv_caches=past_key_values,
             use_cache=use_cache,
             output_hidden_states=output_hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def forward(
        self,
        input_ids,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        full_attention_mask: Optional[torch.BoolTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor],
                                        ...]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        return self._contiguous_batching_forward(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            full_attention_mask=full_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
        )
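The multi-query split above carves one fused `query_key_value` output into a large query block and two much smaller key/value blocks, using the same section sizes that `_distribute_qkv_linear` computes from the head counts. The sketch below repeats that arithmetic on toy head counts (not the real chatglm2 configuration) so the shapes are easy to check.

# Illustrative sketch (not part of the committed file): splitting a fused
# QKV output when key/value groups are fewer than query heads (MQA/GQA).
# num_q_heads / num_kv_groups / head_dim are toy values for the demo.
import torch

num_q_heads, num_kv_groups, head_dim = 8, 2, 16
sections = [num_q_heads * head_dim,
            num_kv_groups * head_dim,
            num_kv_groups * head_dim]

mixed = torch.randn(5, 1, sum(sections))              # (seq, batch, fused_dim)
query, key, value = mixed.split(sections, dim=-1)
query = query.unflatten(-1, (-1, head_dim))           # (seq, batch, 8, 16)
key = key.unflatten(-1, (-1, head_dim))               # (seq, batch, 2, 16)
print(query.shape, key.shape)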
lmdeploy/pytorch/models/deepseek.py
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Optional, Tuple

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._tensor import DeviceMesh

from ..dist_utils import (colwise_parallelize_linear_fn,
                          rowwise_parallelize_linear_fn)
from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd


class PatchedDeepseekAttention(nn.Module):

    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['q_proj', 'k_proj', 'v_proj']:
            colwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)
        elif mod_name in ['o_proj']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs[0])
        return outputs

    def _contiguous_batching_forward_impl(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        world_size: int = 1,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite implementation of forward.

        Add continuous batching support. Add paged attention support. TP
        support.
        """
        context = self.context.context
        kv_seq_length = context.kv_seq_length
        q_seq_length = context.q_seq_length
        q_start_loc = context.q_start_loc
        block_offsets = context.block_offsets
        max_q_seq_length = context.max_q_seq_length
        max_kv_seq_length = context.max_kv_seq_length

        num_heads = self.num_heads // world_size
        num_kv_heads = self.num_key_value_heads // world_size
        head_dim = self.head_dim
        hidden_size = num_heads * head_dim

        def __qkv_proj(hidden_states):
            """qkv proj."""
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

            return query_states, key_states, value_states

        def __rotary_emb_fn(query_states, key_states, value_states):
            if hasattr(self, 'rotary_emb'):
                cos, sin = self.rotary_emb(value_states,
                                           seq_len=max_kv_seq_length)
                query_states, key_states = apply_rotary_pos_emb(
                    query_states, key_states, cos, sin, position_ids,
                    context.position_ids_1d)
            return query_states, key_states, value_states

        query_states, key_states, value_states = __qkv_proj(hidden_states)

        query_states = query_states.view(-1, num_heads, head_dim)
        key_states = key_states.view(-1, num_kv_heads, head_dim)
        value_states = value_states.view(-1, num_kv_heads, head_dim)

        query_states, key_states, value_states = __rotary_emb_fn(
            query_states, key_states, value_states)

        fill_kv_cache(
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            q_start_loc,
            q_seq_length,
            kv_seq_length=kv_seq_length,
            max_q_seq_length=max_q_seq_length,
            block_offsets=block_offsets,
        )

        attn_output = query_states
        paged_attention_fwd(
            query_states,
            past_key_value[0],
            past_key_value[1],
            attn_output,
            block_offsets,
            q_start_loc=q_start_loc,
            q_seqlens=q_seq_length,
            kv_seqlens=kv_seq_length,
            max_seqlen=max_q_seq_length,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1],
                                          hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """forward."""
        world_size = 1
        if dist.is_initialized():
            world_size = dist.get_world_size()
        return self._contiguous_batching_forward_impl(
            hidden_states,
            position_ids,
            past_key_value,
            output_attentions,
            world_size=world_size,
        )
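The `apply_rotary_pos_emb` used above is a fused kernel, so the math is easy to miss when reading the patched forward. As a reference, this sketch writes out the standard rotary-embedding formula (q * cos + rotate_half(q) * sin) with plain torch ops on toy tensors; the shapes and the base frequency 10000 are illustrative, not the kernel's exact interface.

# Illustrative sketch (not part of the committed file): the standard RoPE
# formula applied to a (tokens, heads, head_dim) query tensor.
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

seq_len, num_heads, head_dim = 6, 2, 8
q = torch.randn(seq_len, num_heads, head_dim)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
pos = torch.arange(seq_len).float()
freqs = torch.einsum('s,d->sd', pos, inv_freq)         # (seq, head_dim/2)
emb = torch.cat((freqs, freqs), dim=-1)                # (seq, head_dim)
cos, sin = emb.cos()[:, None, :], emb.sin()[:, None, :]  # broadcast over heads
q_rot = q * cos + rotate_half(q) * sin
print(q_rot.shape)                                     # torch.Size([6, 2, 8])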
lmdeploy/pytorch/models/falcon.py
# Copyright (c) OpenMMLab. All rights reserved.

# Adapted from:
# https://huggingface.co/tiiuae/falcon-7b-instruct
# https://github.com/huggingface/transformers/blob/v4.33-release/src/transformers/models/falcon/modeling_falcon.py # noqa

from typing import Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from torch.distributed._tensor import DeviceMesh
from transformers.modeling_outputs import \
    BaseModelOutputWithPastAndCrossAttentions

from ..dist_utils import (colwise_parallelize_linear_fn,
                          rowwise_parallelize_linear_fn)
from ..kernels import (alibi_paged_attention_fwd, apply_rotary_pos_emb,
                       fill_kv_cache, fused_rotary_emb, paged_attention_fwd)


class PatchedFalconAttention(nn.Module):

    # @classmethod
    def _distribute_partition_fn(self, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""

        world_size = dist.get_world_size()

        if mod_name in ['query_key_value']:
            if self.new_decoder_architecture:
                # e.g. 40b-instruct, GQA
                # split qkv across groups
                # no finer-grained partitioning
                mod.weight.data = mod.weight.reshape(
                    -1,  # num groups
                    (self.num_heads + self.num_kv_heads * 2) * self.head_dim,
                    self.hidden_size,
                )
            elif self.multi_query:
                # e.g. 7b-instruct, MQA
                # split to q, copy kv
                weight = mod.weight.unflatten(0, (-1, self.head_dim))
                q_weight = weight[:self.num_heads]
                kv_weight = weight[-2:]
                q_weight_shards = q_weight.chunk(world_size, 0)
                weight_shards = []
                for q in q_weight_shards:
                    # only shard q heads but
                    # copy single k/v head to all ranks
                    weight_shards.append(q)
                    weight_shards.append(kv_weight)
                mod.weight.data = torch.cat(weight_shards, dim=0)
                # here we keep the weight to be 3D,
                # so that column parallel will split it
                # into integer-numbered heads

            # no bias for 7b-instruct and 40b-instruct
            colwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

            if self.new_decoder_architecture or self.multi_query:
                # return to 2D for later matmul
                mod.weight.data = mod.weight.data.reshape(
                    -1, self.hidden_size)

        elif mod_name in ['dense']:
            if self.new_decoder_architecture:
                # e.g. 40b-instruct, GQA
                mod.weight.data = mod.weight.reshape(
                    self.hidden_size,
                    -1,  # num groups
                    self.num_heads * self.head_dim,
                )
            elif self.multi_query:
                # e.g. 7b-instruct, MQA
                mod.weight.data = mod.weight.reshape(self.hidden_size, -1,
                                                     self.head_dim)
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

            if self.new_decoder_architecture or self.multi_query:
                mod.weight.data = mod.weight.reshape(self.hidden_size, -1)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs[0])
        return outputs

    def _split_heads(
            self, fused_qkv: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the last dimension into (num_heads, head_dim), results share
        same memory storage as `fused_qkv`

        Args:
            fused_qkv (`torch.tensor`, *required*):
                [batch_size, seq_length, num_heads * 3 * head_dim]

        Returns:
            query: [batch_size, seq_length, num_heads, head_dim]
            key: [batch_size, seq_length, num_heads, head_dim]
            value: [batch_size, seq_length, num_heads, head_dim]
        """
        if self.new_decoder_architecture:
            # e.g. 40b-instruct model
            batch, seq_len, _ = fused_qkv.shape
            qkv = fused_qkv.view(batch, seq_len, -1,
                                 self.num_heads // self.num_kv_heads + 2,
                                 self.head_dim)
            query = qkv[:, :, :, :-2]
            key = qkv[:, :, :, [-2]]
            value = qkv[:, :, :, [-1]]
            # because cache_engine & kernel
            # already handled grouped attention
            # removing broadcast makes it faster and more memory-saving
            # key = torch.broadcast_to(key, query.shape)
            # value = torch.broadcast_to(value, query.shape)

            query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
            return query, key, value
        elif not self.multi_query:
            # e.g. rw-1b model
            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
            fused_qkv = fused_qkv.view(batch_size, seq_length,
                                       self.num_heads //
                                       dist.get_world_size(), 3,
                                       self.head_dim)
            return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[...,
                                                                         2, :]
        else:
            # e.g. 7b-instruct model
            fused_qkv = fused_qkv.unflatten(-1, (-1, self.head_dim))
            split_shape = (fused_qkv.size(-2) - 2, 1, 1)
            return fused_qkv.split(split_shape, dim=-2)

    def _contiguous_batching_forward(
        self,
        hidden_states: torch.Tensor,
        alibi: Optional[torch.Tensor],
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False,
    ):
        # prepare inputs for continuous batch forwarding
        context = self.context.context
        q_start_loc = context.q_start_loc
        q_seq_length = context.q_seq_length
        kv_seq_length = context.kv_seq_length
        max_q_seq_length = context.max_q_seq_length
        block_offsets = context.block_offsets
        position_ids_1d = context.position_ids_1d
        max_kv_seq_length = context.max_kv_seq_length

        def __maybe_rotary_fn(query_states, key_states, value_states):
            scaling_factor = 1.0
            inv_freq = self.maybe_rotary.inv_freq
            query_states, key_states = fused_rotary_emb(
                query_states[None],
                key_states[None],
                position_ids_1d[None],
                inv_freq=inv_freq,
                scaling_factor=scaling_factor,
                out_q=query_states[None],
                out_k=key_states[None])
            return query_states[0], key_states[0], value_states

        def __rotary_emb_fn(query_states, key_states, value_states):
            """rotary embedding func."""
            cos, sin = self.rotary_emb(value_states.transpose(0, 1),
                                       max_kv_seq_length)
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin, context.position_ids,
                position_ids_1d)
            return query_states, key_states, value_states

        fused_qkv = self.query_key_value(hidden_states)

        # 3 x [batch_size, seq_length, num_heads, head_dim]
        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

        query_layer = query_layer.flatten(0, 1)
        key_layer = key_layer.flatten(0, 1)
        value_layer = value_layer.flatten(0, 1)
        if hasattr(self, 'maybe_rotary'):
            query_layer, key_layer, value_layer = __maybe_rotary_fn(
                query_layer, key_layer, value_layer)
        elif hasattr(self, 'rotary_emb'):
            query_layer, key_layer, value_layer = __rotary_emb_fn(
                query_layer, key_layer, value_layer)

        past_key, past_value = layer_past
        fill_kv_cache(
            key_layer.contiguous(),
            value_layer.contiguous(),
            past_key,
            past_value,
            q_start_loc,
            q_seq_length,
            kv_seq_length=kv_seq_length,
            max_q_seq_length=max_q_seq_length,
            block_offsets=block_offsets,
        )

        attn_output = query_layer

        if not alibi:
            paged_attention_fwd(q=query_layer,
                                k=past_key,
                                v=past_value,
                                o=attn_output,
                                block_offsets=block_offsets,
                                q_start_loc=q_start_loc,
                                q_seqlens=q_seq_length,
                                kv_seqlens=kv_seq_length,
                                max_seqlen=max_q_seq_length)
        else:
            num_heads_full = self.num_heads
            head_offset = 0
            if dist.is_initialized():
                world_size = dist.get_world_size()
                rank = dist.get_rank()
                head_offset = self.num_heads // world_size * rank
            alibi_paged_attention_fwd(q=query_layer,
                                      k=past_key,
                                      v=past_value,
                                      o=attn_output,
                                      block_offsets=block_offsets,
                                      b_start_loc=q_start_loc,
                                      b_seq_len=q_seq_length,
                                      b_kv_seq_len=kv_seq_length,
                                      max_input_len=max_q_seq_length,
                                      head_offset=head_offset,
                                      num_heads=num_heads_full,
                                      alibi_scale=self.inv_norm_factor)

        attn_output = attn_output[None].flatten(-2, -1)

        output_tensor = self.dense(attn_output)

        if output_attentions:
            return output_tensor, layer_past, None
        else:
            return output_tensor, layer_past

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: Optional[torch.Tensor],
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        return self._contiguous_batching_forward(hidden_states, alibi,
                                                 layer_past)


class PatchedFalconMLP(nn.Module):

    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['dense_h_to_4h']:
            colwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)
        elif mod_name in ['dense_4h_to_h']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs)
        return outputs


class PatchedFalconModel(nn.Module):

    def _contiguous_batching_forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor],
                                        ...]] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor, ...],
               BaseModelOutputWithPastAndCrossAttentions]:

        output_attentions = False
        use_cache = True
        use_alibi = getattr(self, 'use_alibi', getattr(self, 'alibi', False))

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        head_mask = self.get_head_mask(head_mask,
                                       self.config.num_hidden_layers)
        hidden_states = inputs_embeds

        # Compute alibi tensor: check build_alibi_tensor documentation
        for i, (block,
                layer_past) in enumerate(zip(self.h, past_key_values)):

            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=None,
                head_mask=head_mask[i],
                use_cache=use_cache,
                output_attentions=output_attentions,
                alibi=use_alibi,
            )

            hidden_states = outputs[0]

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor],
                                        ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...],
               BaseModelOutputWithPastAndCrossAttentions]:
        return self._contiguous_batching_forward(
            input_ids=input_ids, past_key_values=past_key_values)


class PatchedFalconForCausalLM(nn.Module):

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor],
                                        ...]] = None,
        return_dict: Optional[bool] = True,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        use_origin: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor, ...],
               BaseModelOutputWithPastAndCrossAttentions]:
        """Forward function, patched to ignore position_ids."""
        outputs = self.origin_mod(input_ids=input_ids,
                                  past_key_values=past_key_values,
                                  attention_mask=attention_mask,
                                  output_attentions=output_attentions,
                                  output_hidden_states=output_hidden_states,
                                  return_dict=return_dict)
        return outputs
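The multi-query branch of `PatchedFalconAttention._distribute_partition_fn` above shards only the query heads across ranks while copying the single shared key/value head to every rank. The sketch below reproduces that weight rearrangement on a toy fused QKV weight (sizes and `world_size` are made up), so you can see how each rank ends up with its own query slice plus the full shared k/v rows.

# Illustrative sketch (not part of the committed file): MQA weight sharding --
# split q heads across ranks, replicate the shared k/v head on every rank.
import torch

num_q_heads, head_dim, hidden_size, world_size = 4, 8, 32, 2
# fused qkv weight: (num_q_heads + 2) * head_dim output rows
weight = torch.randn((num_q_heads + 2) * head_dim, hidden_size)
weight = weight.unflatten(0, (-1, head_dim))            # (heads, head_dim, hidden)
q_weight, kv_weight = weight[:num_q_heads], weight[-2:]

per_rank = []
for q_shard in q_weight.chunk(world_size, 0):
    shard = torch.cat([q_shard, kv_weight], dim=0)      # rank-local q heads + shared kv
    per_rank.append(shard.flatten(0, 1))                # back to 2D for the matmul
print([w.shape for w in per_rank])                      # 2 x torch.Size([32, 32])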
lmdeploy/pytorch/models/functional.py
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Any, Callable, Optional, Sequence, Tuple

import numpy as np
import torch
# import torch.nn.functional as F
from torch import Tensor

from ..kernels import (apply_rotary_pos_emb, fill_kv_cache,
                       rerope_attention_fwd)

__all__ = ['apply_rotary_pos_emb']


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """This is the equivalent of torch.repeat_interleave(x, dim=1,
    repeats=n_rep).

    The hidden states go from (num_key_value_heads, seqlen, head_dim) to
    (num_attention_heads, seqlen, head_dim)
    """
    if n_rep == 1:
        return hidden_states
    num_key_value_heads, slen, head_dim = hidden_states.shape
    hidden_states = hidden_states[:, None, :, :].expand(
        num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(num_key_value_heads * n_rep, slen, head_dim)


def generate_batched_mask(q_lens,
                          k_lens,
                          max_q_len: int = None,
                          max_k_len: int = None,
                          device='cuda'):
    """Generate batched mask."""
    if max_q_len is None:
        max_q_len = max(q_lens)

    if max_k_len is None:
        max_k_len = max(k_lens)

    q_range = torch.arange(max_q_len).to(device)
    k_range = torch.arange(max_k_len).to(device)

    cross = k_range.unsqueeze(0) - q_range.unsqueeze(1)
    cross = cross.unsqueeze(0)

    threshold = (k_lens - q_lens).view(-1, 1, 1)
    mask = torch.where(cross <= threshold, 1, 0).to(device)
    for idx, q_len in enumerate(q_lens):
        mask[idx, q_len:, :] = 0
    return mask


def get_slopes(n: int):
    """Get alibi slopes."""

    def _get_interleave_power_of_2(n):
        start = 2**(-(2**-(math.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    if math.log2(n).is_integer():
        return _get_interleave_power_of_2(n)
    else:
        closest_power_of_2 = 2**math.floor(math.log2(n))
        return (_get_interleave_power_of_2(closest_power_of_2) +
                get_slopes(2 * closest_power_of_2)[0::2][:n -
                                                         closest_power_of_2])


@torch.no_grad()
def get_alibi_biases(n_heads: int, mask: torch.Tensor):
    """Get alibi bias."""
    m = torch.tensor(get_slopes(n_heads)).to(mask.device)
    distance = mask.cumsum(dim=-1) - 1
    return distance * m[None, :, None, None]


def quant_kv(key: torch.Tensor, value: torch.Tensor, out_type: torch.dtype):
    """Quantize key and value of attention to `out_type`.

    Args:
        key (torch.Tensor): Attention key.
        value (torch.Tensor): Attention value.
        out_type (torch.dtype): Output data type.
    """
    assert out_type is torch.int8
    # quantize key and value
    _min = torch.min(key, axis=-1).values
    _max = torch.max(key, axis=-1).values
    key_zp = (_min + _max) / 2
    key_scale = (_max - key_zp) / 127
    key_int8 = torch.round(
        (key - key_zp[:, :, None]) / key_scale[:, :, None]).to(out_type)

    _min = torch.min(value, axis=-1).values
    _max = torch.max(value, axis=-1).values
    value_zp = (_min + _max) / 2
    value_scale = (_max - value_zp) / 127
    value_int8 = torch.round(
        (value - value_zp[:, :, None]) / value_scale[:, :, None]).to(out_type)

    # wrap zp and scale to qparams
    qparams = {
        'key_zp': key_zp,
        'key_scale': key_scale,
        'value_zp': value_zp,
        'value_scale': value_scale,
    }
    return key_int8, value_int8, qparams


def dequant_kv(context: Any, layer_id: str, key_int8: torch.Tensor,
               value_int8: torch.Tensor, out_type: torch.dtype):
    """Dequantize key and value of attention to `out_type`.

    Args:
        context (Any): StepContext during inference.
        layer_id (str): Layer object id.
        key (torch.Tensor): Quantized attention key.
        value (torch.Tensor): Quantized attention value.
        out_type (torch.dtype): output data type.
    """
    qparams = context.get_output(layer_id)

    key_scale = qparams['key_scale']
    key_zp = qparams['key_zp']
    key_float = (key_int8 * key_scale[:, :, None] +
                 key_zp[:, :, None]).to(out_type)

    value_scale = qparams['value_scale']
    value_zp = qparams['value_zp']
    value_float = (value_int8 * value_scale[:, :, None] +
                   value_zp[:, :, None]).to(out_type)
    return key_float, value_float


def sync_qparam_to_context(context: Any, layer_id: str, qparams: dict):
    """Merge quantization param to context.

    Args:
        context (Any): StepContext during inference.
        layer_id (str): Layer object id.
        qparams (dict): Quantization param of current step.
    """
    if context.inputs.meta is not None:
        last_qparam = context.inputs.meta[layer_id]
        for _k in last_qparam.keys():
            _v = torch.concat([last_qparam[_k], qparams[_k]], axis=0)
            last_qparam[_k] = _v
        context.set_output(layer_id, last_qparam)
    else:
        context.set_output(layer_id, qparams)


@torch.no_grad()
def attention_forward_with_rerope(
        hidden_states: Tensor,
        history_lengths: Sequence,
        block_offsets: Tensor,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        position_ids: torch.LongTensor,
        past_key_value: Tuple[Tensor],
        attention_mask: Tensor,
        context: Any = None,
        q_proj: Optional[Callable] = None,
        k_proj: Optional[Callable] = None,
        v_proj: Optional[Callable] = None,
        qkv_proj: Optional[Callable] = None,
        o_proj: Optional[Callable] = None,
        rotary_emb_context_fn: Optional[Callable] = None,
        rotary_emb_generate_fn: Optional[Callable] = None,
        bias_type: str = 'default',
        training_length=4096,
        window=512,
        layer_id: str = None) -> Tensor:
    """Attention module forward with ReRoPE.

    Args:
        hidden_states (Tensor): Input of attention layer.
        history_lengths (Sequence): Cache lengths of each data in batch.
        block_offsets (Tensor): Block table of the key/value caches,
            used by paged attention.
        num_heads (int): numbers of query heads.
        num_kv_heads (int): numbers of key/value heads.
        head_dim (int): Feature dimension of heads.
        position_ids (LongTensor): position ids of the input.
        past_key_value (Tuple[Tensor]): key value cache.
        q_proj (Callable): query project module/function.
        k_proj (Callable): key project module/function.
        v_proj (Callable): value project module/function.
        qkv_proj (Callable): query/key/value project module/function.
        o_proj (Callable): output project module/function.
        rotary_emb_context_fn (Callable): rotary embedding context callback.
        rotary_emb_generate_fn (Callable): rotary embedding generate callback.
        bias_type (str): type of attention bias. support ['default'].
        training_length (int): model sequence length during training.
        window (int): ReRoPE window size, default value is 512.
    """
    hidden_size = -1
    if qkv_proj is not None:
        assert q_proj is None
        assert k_proj is None
        assert v_proj is None
        query_states, key_states, value_states = qkv_proj(hidden_states)
    else:
        assert qkv_proj is None
        assert q_proj is not None
        assert k_proj is not None
        assert v_proj is not None
        query_states = q_proj(hidden_states)
        key_states = k_proj(hidden_states)
        value_states = v_proj(hidden_states)
        hidden_size = num_heads * head_dim

    query_states = query_states.view(-1, num_heads, head_dim)
    key_states = key_states.view(-1, num_kv_heads, head_dim)
    value_states = value_states.view(-1, num_kv_heads, head_dim)

    query_states *= ((position_ids.flatten() + 1)[:, None, None].log() /
                     np.log(training_length)).clip(1).to(query_states.dtype)

    kv_seq_length = (position_ids[..., -1] + 1).item()
    q_seq_length = getattr(context, 'q_seq_length', None)
    if q_seq_length is None:
        q_seq_length = kv_seq_length - kv_seq_length.new_tensor(
            history_lengths)
    q_start_loc = getattr(context, 'q_start_loc', None)
    if q_start_loc is None:
        q_start_loc = q_seq_length.cumsum(0)
        q_start_loc = torch.cat([q_start_loc.new_zeros(1), q_start_loc[:-1]])

    quant = False
    if past_key_value[0].dtype != hidden_states.dtype:
        # dynamic quantize hidden_states to kv_cache and save
        quant = True
        qkey, qvalue, qparams = quant_kv(key_states, value_states,
                                         past_key_value[0].dtype)
        sync_qparam_to_context(context=context,
                               layer_id=layer_id,
                               qparams=qparams)
        fill_kv_cache(qkey,
                      qvalue,
                      past_key_value[0],
                      past_key_value[1],
                      q_start_loc,
                      q_seq_length,
                      block_offsets=block_offsets,
                      history_lengths=history_lengths,
                      context=context)
    else:
        fill_kv_cache(key_states,
                      value_states,
                      past_key_value[0],
                      past_key_value[1],
                      q_start_loc,
                      q_seq_length,
                      block_offsets=block_offsets,
                      history_lengths=history_lengths,
                      context=context)

    bsz, q_len, _ = hidden_states.size()

    if bias_type.lower() == 'default':
        if q_len == 1:
            key_states = past_key_value[0][block_offsets].view(
                -1, num_heads, head_dim)[0:history_lengths[-1] + 1]
            value_states = past_key_value[1][block_offsets].view(
                -1, num_heads, head_dim)[0:history_lengths[-1] + 1]
            if quant:
                # dequant int8 tensor to hidden_states.dtype
                key_states, value_states = dequant_kv(
                    context=context,
                    layer_id=layer_id,
                    key_int8=key_states,
                    value_int8=value_states,
                    out_type=hidden_states.dtype)
            full_position_ids = torch.arange(
                position_ids.item() + 1,
                device=position_ids.device).unsqueeze(0)
            key_states, value_states = rotary_emb_generate_fn(
                key_states, value_states, full_position_ids, window)
            attn_weights = torch.matmul(
                query_states.transpose(0, 1), key_states.permute(
                    1, 2, 0)) / math.sqrt(head_dim)

            if attention_mask is not None:
                attn_weights = attn_weights + attention_mask

            # upcast attention to fp32
            attn_weights = torch.nn.functional.softmax(
                attn_weights, dim=-1,
                dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights,
                                       value_states.transpose(0, 1))
        else:
            query_states1, query_states2, key_states1, key_states2, value_states = rotary_emb_context_fn(  # noqa: E501
                query_states, key_states, value_states, position_ids, window)
            sm_scale = 1.0 / math.sqrt(head_dim)
            PADDING_UNIT = past_key_value[0].shape[1]
            assert PADDING_UNIT in {16, 32, 64, 128, 256}
            # padding_len = -query_states1.shape[2] % PADDING_UNIT
            # query_states1 = F.pad(query_states1,
            #                       (0, 0, 0, padding_len)).contiguous()
            # query_states2 = F.pad(query_states2,
            #                       (0, 0, 0, padding_len)).contiguous()
            # key_states1 = F.pad(key_states1,
            #                     (0, 0, 0, padding_len)).contiguous()
            # key_states2 = F.pad(key_states2,
            #                     (0, 0, 0, padding_len)).contiguous()
            # value_states = F.pad(value_states,
            #                      (0, 0, 0, padding_len)).contiguous()
            query_states1 = query_states1.contiguous()
            query_states2 = query_states2.contiguous()
            key_states1 = key_states1.contiguous()
            key_states2 = key_states2.contiguous()
            value_states = value_states.contiguous()
            attn_output = rerope_attention_fwd(query_states1,
                                               query_states2,
                                               key_states1,
                                               key_states2,
                                               value_states,
                                               True,
                                               sm_scale,
                                               window,
                                               BLOCK_M=PADDING_UNIT).squeeze(0)
            # attn_output = attn_output[:, 0:q_len]

        if attn_output.size() != (num_heads, q_len, head_dim):
            raise ValueError(
                f'`attn_output` should be of size '
                f'{(bsz, num_heads, q_len, head_dim)}, but is'  # noqa: E501
                f' {attn_output.size()}')

        attn_output = attn_output.transpose(0, 1).reshape(
            bsz, q_len, hidden_size).contiguous()
    else:
        raise ValueError(f'Unknown bias type: {bias_type}')
    if o_proj is not None:
        attn_output = o_proj(attn_output)
    return attn_output
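`quant_kv` and `dequant_kv` above implement a simple per-row symmetric-range int8 scheme: the zero point is the midpoint of each row's min/max and the scale maps the remaining range onto [-127, 127]. The sketch below repeats that math on a toy tensor and checks the round-trip error, which is an easy sanity check to keep in mind when reading the ReRoPE path that stores int8 KV cache.

# Illustrative sketch (not part of the committed file): per-row zero-point /
# scale int8 quantization round trip, matching the quant_kv / dequant_kv math.
import torch

key = torch.randn(4, 2, 16)                            # (tokens, heads, head_dim)
_min = key.min(dim=-1).values
_max = key.max(dim=-1).values
zp = (_min + _max) / 2
scale = (_max - zp) / 127

key_int8 = torch.round((key - zp[:, :, None]) / scale[:, :, None]).to(torch.int8)
key_back = key_int8 * scale[:, :, None] + zp[:, :, None]
print((key - key_back).abs().max())                    # small quantization error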
lmdeploy/pytorch/models/gemma.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._tensor import DeviceMesh
from transformers.modeling_outputs import BaseModelOutputWithPast

from ..dist_utils import (colwise_parallelize_linear_fn,
                          rowwise_parallelize_linear_fn)
from ..kernels import fill_kv_cache, fused_rotary_emb, paged_attention_fwd


class PatchedGemmaRMSNorm(nn.Module):
    """Rewrite RMSNorm."""

    def forward(self, x):
        """forward."""
        # torch.nn.functional.normalize based implementation might lead
        # to wrong output
        from ..kernels import rms_norm
        ret = rms_norm(x.contiguous(), self.weight + 1, self.eps)

        return ret


def _make_inv_freq(self, device: torch.device):
    if self.inv_freq is None:
        self.inv_freq = 1.0 / (self.base**(torch.arange(
            0, self.dim, 2, dtype=torch.int64, device=device).float() /
                                           self.dim))


class PatchedGemmaAttention(nn.Module):

    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['q_proj', 'k_proj', 'v_proj']:
            colwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)
        elif mod_name in ['o_proj']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs[0])
        return outputs

    def _contiguous_batching_forward_impl(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        attention_mask: Optional[torch.Tensor] = None,
        world_size: int = 1,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite implementation of forward.

        Add continuous batching support. Add paged attention support. TP
        support.
        """
        context = self.context.context
        kv_seq_length = context.kv_seq_length
        q_seq_length = context.q_seq_length
        q_start_loc = context.q_start_loc
        block_offsets = context.block_offsets
        max_q_seq_length = context.max_q_seq_length

        num_heads = self.num_heads // world_size
        num_kv_heads = self.num_key_value_heads // world_size
        head_dim = self.head_dim
        hidden_size = num_heads * head_dim

        def __qkv_proj(hidden_states):
            """qkv proj."""
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

            return query_states, key_states, value_states

        def __rotary_emb_fn(query_states, key_states, value_states):
            scaling_factor = 1.0
            _make_inv_freq(self.rotary_emb, query_states.device)
            inv_freq = self.rotary_emb.inv_freq
            query_states, key_states = fused_rotary_emb(
                query_states[None],
                key_states[None],
                context.position_ids_1d[None],
                inv_freq=inv_freq,
                scaling_factor=scaling_factor,
                out_q=query_states[None],
                out_k=key_states[None])
            return query_states[0], key_states[0], value_states

        query_states, key_states, value_states = __qkv_proj(hidden_states)

        query_states = query_states.view(-1, num_heads, head_dim)
        key_states = key_states.view(-1, num_kv_heads, head_dim)
        value_states = value_states.view(-1, num_kv_heads, head_dim)

        query_states, key_states, value_states = __rotary_emb_fn(
            query_states, key_states, value_states)

        fill_kv_cache(
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            q_start_loc,
            q_seq_length,
            kv_seq_length=kv_seq_length,
            max_q_seq_length=max_q_seq_length,
            block_offsets=block_offsets,
        )

        attn_output = query_states
        paged_attention_fwd(
            query_states,
            past_key_value[0],
            past_key_value[1],
            attn_output,
            block_offsets,
            q_start_loc=q_start_loc,
            q_seqlens=q_seq_length,
            kv_seqlens=kv_seq_length,
            max_seqlen=max_q_seq_length,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1],
                                          hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite of forward."""
        world_size = 1
        if dist.is_initialized():
            world_size = dist.get_world_size()
        return self._contiguous_batching_forward_impl(
            hidden_states,
            position_ids,
            past_key_value,
            output_attentions,
            attention_mask=attention_mask,
            world_size=world_size,
        )


class PatchedGemmaModel(nn.Module):

    def _continuous_batching_forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Rewrite implementation of LlamaModel.forward."""
        output_attentions = False
        use_cache = True

        # Attention mask is not necessary in continuous batching
        attention_mask = None

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # This is Gemma only!
        hidden_states = hidden_states * (self.config.hidden_size**0.5)

        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = (past_key_values[idx]
                              if past_key_values is not None else None)
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Rewrite of LlamaModel.forward."""
        return self._continuous_batching_forward(
            input_ids,
            position_ids,
            past_key_values,
            inputs_embeds,
        )
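Two Gemma-specific details in the patched modules above are easy to overlook: the RMSNorm weight is stored as an offset from 1 (hence `self.weight + 1`), and the token embeddings are scaled by `hidden_size ** 0.5` before the decoder layers. The sketch below reproduces both with plain torch ops on toy shapes; the `rms_norm` kernel itself is replaced by its straightforward tensor equivalent here.

# Illustrative sketch (not part of the committed file): Gemma's (weight + 1)
# RMSNorm and its hidden_size ** 0.5 embedding scaling, on toy tensors.
import torch

hidden_size, eps = 8, 1e-6
x = torch.randn(3, hidden_size)
weight = torch.zeros(hidden_size)                      # Gemma stores weight as an offset from 1

variance = x.pow(2).mean(-1, keepdim=True)
normed = x * torch.rsqrt(variance + eps) * (weight + 1)  # plain-tensor rms_norm(x, weight + 1, eps)

embeds = torch.randn(3, hidden_size)
hidden_states = embeds * (hidden_size ** 0.5)          # Gemma-only input scaling
print(normed.shape, hidden_states.shape)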
lmdeploy/pytorch/models/internlm.py
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Optional, Tuple

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._tensor import DeviceMesh

from ..dist_utils import (colwise_parallelize_linear_fn,
                          rowwise_parallelize_linear_fn)
from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd


class PatchedInternLMAttention(nn.Module):

    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['q_proj', 'k_proj', 'v_proj']:
            colwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)
        elif mod_name in ['o_proj']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs[0])
        return outputs

    def _contiguous_batching_forward_impl(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        world_size: int = 1,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite implementation of LlamaAttention.forward.

        Add continuous batching support. Add paged attention support. TP
        support.
        """
        context = self.context.context
        q_start_loc = context.q_start_loc
        kv_seq_length = context.kv_seq_length
        q_seq_length = context.q_seq_length
        block_offsets = context.block_offsets
        max_q_seq_length = context.max_q_seq_length
        max_kv_seq_length = context.max_kv_seq_length

        num_heads = self.num_heads // world_size
        num_kv_heads = num_heads
        head_dim = self.head_dim
        hidden_size = num_heads * head_dim

        def __qkv_proj(hidden_states):
            """qkv proj."""
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

            return query_states, key_states, value_states

        def __rotary_emb_fn(query_states, key_states, value_states):
            """rotary embedding func."""
            cos, sin = self.rotary_emb(value_states,
                                       seq_len=max_kv_seq_length)
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin, position_ids,
                context.position_ids_1d)
            return query_states, key_states, value_states

        query_states, key_states, value_states = __qkv_proj(hidden_states)

        query_states = query_states.view(-1, num_heads, head_dim)
        key_states = key_states.view(-1, num_kv_heads, head_dim)
        value_states = value_states.view(-1, num_kv_heads, head_dim)

        query_states, key_states, value_states = __rotary_emb_fn(
            query_states, key_states, value_states)

        fill_kv_cache(
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            q_start_loc,
            q_seq_length,
            kv_seq_length=kv_seq_length,
            max_q_seq_length=max_q_seq_length,
            block_offsets=block_offsets,
        )

        attn_output = query_states
        paged_attention_fwd(
            query_states,
            past_key_value[0],
            past_key_value[1],
            attn_output,
            block_offsets,
            q_start_loc=q_start_loc,
            q_seqlens=q_seq_length,
            kv_seqlens=kv_seq_length,
            max_seqlen=max_q_seq_length,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1],
                                          hidden_size)

        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """forward."""
        world_size = 1
        if dist.is_initialized():
            world_size = dist.get_world_size()
        return self._contiguous_batching_forward_impl(
            hidden_states,
            position_ids,
            past_key_value,
            world_size=world_size,
        )
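A pattern shared by all of these patched attention modules is worth spelling out once: each rank keeps `num_heads // world_size` heads, the attention output is produced in the flattened `(num_tokens, heads, head_dim)` layout, and it is then reshaped back onto the original hidden-state shape before the output projection. The sketch below reproduces just that bookkeeping with a hard-coded `world_size` and toy sizes (the real code reads the world size from `torch.distributed`).

# Illustrative sketch (not part of the committed file): per-rank head count
# and the reshape from (tokens, heads, head_dim) back to the hidden shape.
import torch

total_heads, head_dim, world_size = 8, 16, 2
num_heads = total_heads // world_size                  # heads owned by this rank
hidden_size = num_heads * head_dim

hidden_states = torch.randn(1, 5, 128)                 # (batch, seq, model_hidden)
attn_output = torch.randn(5, num_heads, head_dim)      # (tokens, heads, head_dim)
attn_output = attn_output.reshape(*hidden_states.shape[:-1], hidden_size)
print(attn_output.shape)                               # torch.Size([1, 5, 64])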
lmdeploy/pytorch/models/internlm2.py
# Copyright (c) OpenMMLab. All rights reserved.

from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from einops import rearrange
from torch import nn
from torch.distributed._tensor import DeviceMesh
from transformers.modeling_outputs import BaseModelOutputWithPast

from ..dist_utils import (colwise_parallelize_linear_fn,
                          rowwise_parallelize_linear_fn)
from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd


class PatchedInternLM2Attention(nn.Module):

    def _distribute_partition_fn(self, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['wqkv']:
            colwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)
        elif mod_name in ['wo']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs[0])
        return outputs

    def _contiguous_batching_forward_impl(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        world_size: int = 1,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite implementation of LlamaAttention.forward.

        Add continuous batching support. Add paged attention support. TP
        support.
        """
        context = self.context.context
        q_start_loc = context.q_start_loc
        q_seq_length = context.q_seq_length
        kv_seq_length = context.kv_seq_length
        block_offsets = context.block_offsets
        max_q_seq_length = context.max_q_seq_length
        max_kv_seq_length = context.max_kv_seq_length

        def __qkv_proj(hidden_states):
            """qkv_proj."""
            qkv_states = self.wqkv(hidden_states)
            qkv_states = rearrange(
                qkv_states,
                'b q (h gs d) -> (b q) h gs d',
                gs=2 + self.num_key_value_groups,
                d=self.head_dim,
            )
            query_states = qkv_states[..., :self.num_key_value_groups, :]
            query_states = query_states.flatten(1, 2)
            key_states = qkv_states[..., -2, :]
            value_states = qkv_states[..., -1, :]
            return query_states, key_states, value_states

        def __rotary_emb_fn(query_states, key_states, value_states):
            """rotary embedding func."""
            cos, sin = self.rotary_emb(value_states.transpose(0, 1),
                                       seq_len=max_kv_seq_length)
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin, position_ids,
                context.position_ids_1d)
            return query_states, key_states, value_states

        query_states, key_states, value_states = __qkv_proj(hidden_states)

        query_states, key_states, value_states = __rotary_emb_fn(
            query_states, key_states, value_states)

        fill_kv_cache(
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            q_start_loc,
            q_seq_length,
            kv_seq_length=kv_seq_length,
            max_q_seq_length=max_q_seq_length,
            block_offsets=block_offsets,
        )

        attn_output = query_states
        paged_attention_fwd(
            query_states,
            past_key_value[0],
            past_key_value[1],
            attn_output,
            block_offsets,
            q_start_loc=q_start_loc,
            q_seqlens=q_seq_length,
            kv_seqlens=kv_seq_length,
            max_seqlen=max_q_seq_length,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)

        attn_output = self.wo(attn_output)

        return attn_output, None, past_key_value

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """forward."""
        world_size = 1
        if dist.is_initialized():
            world_size = dist.get_world_size()
        return self._contiguous_batching_forward_impl(
            hidden_states,
            position_ids,
            past_key_value,
            output_attentions,
            world_size=world_size,
        )


class PatchedInternLM2MLP(nn.Module):

    @classmethod
    def _distribute_partition_fn(cls, mod_name: str
,
mod
:
nn
.
Module
,
device_mesh
:
DeviceMesh
):
"""Distribution partition callback."""
if
mod_name
in
[
'w1'
,
'w3'
]:
colwise_parallelize_linear_fn
(
mod
,
device_mesh
=
device_mesh
,
to_local
=
True
)
elif
mod_name
in
[
'w2'
]:
rowwise_parallelize_linear_fn
(
mod
,
device_mesh
=
device_mesh
,
to_local
=
True
)
@
classmethod
def
_distribute_output_fn
(
cls
,
outputs
,
device_mesh
:
DeviceMesh
):
"""Distribution output hook."""
dist
.
all_reduce
(
outputs
)
return
outputs
class
PatchedInternLM2Model
(
nn
.
Module
):
def
_continuous_batching_forward
(
self
,
input_ids
:
torch
.
LongTensor
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
past_key_values
:
Optional
[
List
[
torch
.
FloatTensor
]]
=
None
,
inputs_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
use_cache
:
Optional
[
bool
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
Tuple
,
BaseModelOutputWithPast
]:
"""Rewrite implementation of LlamaModel.forward."""
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
tok_embeddings
(
input_ids
)
# Attention mask is not necessary in continuous batching
attention_mask
=
None
hidden_states
=
inputs_embeds
# decoder layers
for
idx
,
decoder_layer
in
enumerate
(
self
.
layers
):
past_key_value
=
(
past_key_values
[
idx
]
if
past_key_values
is
not
None
else
None
)
layer_outputs
=
decoder_layer
(
hidden_states
,
attention_mask
=
attention_mask
,
position_ids
=
position_ids
,
past_key_value
=
past_key_value
,
output_attentions
=
output_attentions
,
use_cache
=
use_cache
,
)
hidden_states
=
layer_outputs
[
0
]
hidden_states
=
self
.
norm
(
hidden_states
)
return
BaseModelOutputWithPast
(
last_hidden_state
=
hidden_states
,
past_key_values
=
past_key_value
,
hidden_states
=
None
,
attentions
=
None
,
)
def
forward
(
self
,
input_ids
:
torch
.
LongTensor
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
past_key_values
:
Optional
[
List
[
torch
.
FloatTensor
]]
=
None
,
inputs_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
use_cache
:
Optional
[
bool
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
Tuple
,
BaseModelOutputWithPast
]:
"""Rewrite of LlamaModel.forward."""
return
self
.
_continuous_batching_forward
(
input_ids
,
attention_mask
,
position_ids
,
past_key_values
,
inputs_embeds
,
use_cache
,
output_attentions
,
output_hidden_states
,
return_dict
,
)
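Editor's note: the `__qkv_proj` helper above assumes the packed `wqkv` weight lays out, per kv-head, `num_key_value_groups` query heads followed by one key head and one value head. A minimal sketch of that split on a dummy tensor (the shapes below are illustrative assumptions, not taken from a real checkpoint):

# Hypothetical shapes: 2 kv-heads, groups=2 query heads per kv-head, head_dim=8.
import torch
from einops import rearrange

b, q, kv_heads, groups, head_dim = 1, 3, 2, 2, 8
qkv = torch.randn(b, q, kv_heads * (groups + 2) * head_dim)
qkv = rearrange(qkv, 'b q (h gs d) -> (b q) h gs d', gs=groups + 2, d=head_dim)
query = qkv[..., :groups, :].flatten(1, 2)   # (b*q, kv_heads*groups, head_dim)
key = qkv[..., -2, :]                        # (b*q, kv_heads, head_dim)
value = qkv[..., -1, :]                      # (b*q, kv_heads, head_dim)
print(query.shape, key.shape, value.shape)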
lmdeploy/pytorch/models/llama.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.

from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import transformers
from packaging import version
from torch import nn
from torch.distributed._tensor import DeviceMesh
from transformers.modeling_outputs import BaseModelOutputWithPast

from ..dist_utils import (colwise_parallelize_linear_fn,
                          rowwise_parallelize_linear_fn)
from ..kernels import apply_rotary_pos_emb as apply_rotary_pos_emb_old
from ..kernels import fill_kv_cache, fused_rotary_emb, paged_attention_fwd
from .functional import attention_forward_with_rerope, repeat_kv

TRANSFORMERS_VERSION = version.parse(transformers.__version__)


class LlamaRMSNorm(nn.Module):
    """Rewrite RMSNorm."""

    def forward(self, hidden_states):
        """forward."""
        # torch.nn.functional.normalize based implementation might leads
        # to wrong output
        from ..kernels import rms_norm
        ret = rms_norm(hidden_states, self.weight, self.variance_epsilon)
        return ret


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors."""
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class LlamaAttention(nn.Module):
    """Rewrite module of LlamaAttention."""

    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['q_proj', 'k_proj', 'v_proj']:
            colwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)
        elif mod_name in ['o_proj']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs[0])
        return outputs

    def _contiguous_batching_forward_rerope_impl(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        attention_mask: Optional[torch.Tensor] = None,
        world_size: int = 1
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """rerope rewrite."""
        context = self.context.context
        history_lengths = context.history_lengths

        def apply_rotary_pos_emb_rerope(q, k, cos, sin, position_ids):
            assert 1 == position_ids.shape[0]
            _, seq_len = position_ids.shape
            _, dim = cos.shape
            cos = cos[position_ids].reshape(
                seq_len, 1, dim)  # [bs, seq_len, dim] to [seq_len, 1, dim]
            sin = sin[position_ids].reshape(
                seq_len, 1, dim)  # [bs, seq_len, dim] to [seq_len, 1, dim]
            q_embed = ((q * cos[-q.shape[0]:]) +
                       (rotate_half(q) * sin[-q.shape[0]:])
                       ) if q is not None else None
            k_embed = ((k * cos) +
                       (rotate_half(k) * sin)) if k is not None else None
            return q_embed, k_embed

        def _rotary_emb_context_rerope_fn(query_states, key_states,
                                          value_states, position_ids, window):
            kv_seq_len, num_dim, dim = key_states.shape
            cos, sin = self.rotary_emb(value_states,
                                       seq_len=max(kv_seq_len, window + 1))
            query_states1, key_states1 = apply_rotary_pos_emb_rerope(
                query_states, key_states, cos, sin, position_ids)
            query_states2, _ = apply_rotary_pos_emb_rerope(
                query_states, None, cos, sin, position_ids * 0 + window)

            # repeat k/v heads if n_kv_heads < n_heads
            if self.num_key_value_groups > 1:
                key_states1 = repeat_kv(key_states1,
                                        self.num_key_value_groups)
                key_states2 = repeat_kv(key_states, self.num_key_value_groups)
                value_states = repeat_kv(value_states,
                                         self.num_key_value_groups)
            else:
                key_states2 = key_states

            query_states1 = query_states1.transpose(0, 1).reshape(
                1, num_dim, kv_seq_len, dim)
            query_states2 = query_states2.transpose(0, 1).reshape(
                1, num_dim, kv_seq_len, dim)
            key_states1 = key_states1.transpose(0, 1).reshape(
                1, num_dim, kv_seq_len, dim)
            key_states2 = key_states2.transpose(0, 1).reshape(
                1, num_dim, kv_seq_len, dim)
            value_states = value_states.transpose(0, 1).reshape(
                1, num_dim, kv_seq_len, dim)
            return query_states1, query_states2, key_states1, key_states2, value_states  # noqa: E501

        def _rotary_emb_generate_rerope_fn(key_states, value_states,
                                           position_ids, window):
            kv_seq_len = key_states.shape[0]
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
            position_ids = (position_ids[:, -1] -
                            position_ids).clip(max=window)
            _, key_states = apply_rotary_pos_emb_rerope(
                None, key_states, cos, -sin, position_ids)
            if self.num_key_value_groups > 1:
                key_states = repeat_kv(key_states, self.num_key_value_groups)
                value_states = repeat_kv(value_states,
                                         self.num_key_value_groups)
            return key_states, value_states

        attn_output = attention_forward_with_rerope(
            hidden_states,
            history_lengths=history_lengths,
            block_offsets=context.block_offsets,
            num_heads=self.num_heads // world_size,
            num_kv_heads=self.num_key_value_heads // world_size,
            head_dim=self.head_dim,
            position_ids=position_ids,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            context=context,
            q_proj=self.q_proj,
            k_proj=self.k_proj,
            v_proj=self.v_proj,
            o_proj=self.o_proj,
            rotary_emb_context_fn=_rotary_emb_context_rerope_fn,
            rotary_emb_generate_fn=_rotary_emb_generate_rerope_fn,
            layer_id=id(self))

        return attn_output, None, past_key_value

    def _contiguous_batching_forward_default_impl(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        attention_mask: Optional[torch.Tensor] = None,
        world_size: int = 1,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """default rewrite."""
        context = self.context.context
        kv_seq_length = context.kv_seq_length
        q_seq_length = context.q_seq_length
        q_start_loc = context.q_start_loc
        block_offsets = context.block_offsets
        max_q_seq_length = context.max_q_seq_length
        max_kv_seq_length = context.max_kv_seq_length

        num_heads = self.num_heads // world_size
        num_kv_heads = self.num_key_value_heads // world_size
        head_dim = self.head_dim
        hidden_size = num_heads * head_dim

        def __qkv_proj(hidden_states):
            """qkv proj."""
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)
            return query_states, key_states, value_states

        def __rotary_emb_fn_old(query_states, key_states, value_states):
            """rotary embedding old."""
            if max_kv_seq_length >= self.rotary_emb.max_seq_len_cached:
                # create larger cache
                cos, sin = self.rotary_emb(value_states,
                                           seq_len=max_kv_seq_length + 128)
            cos = self.rotary_emb.cos_cached
            sin = self.rotary_emb.sin_cached
            query_states, key_states = apply_rotary_pos_emb_old(
                query_states,
                key_states,
                cos,
                sin,
                position_ids,
                context.position_ids_1d,
                q_embed=query_states,
                k_embed=key_states)
            return query_states, key_states, value_states

        def __rotary_emb_fn_438_naive(query_states, key_states, value_states):
            """rotary embedding transformers>4.38."""
            cos, sin = self.rotary_emb(value_states,
                                       context.position_ids_1d[None])
            cos = cos[0]
            sin = sin[0]
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin)
            return query_states, key_states, value_states

        def __rotary_emb_fn_438_fused(query_states, key_states, value_states):
            scaling_factor = getattr(self.rotary_emb, 'scaling_factor', 1.0)
            inv_freq = self.rotary_emb.inv_freq
            query_states, key_states = fused_rotary_emb(
                query_states[None],
                key_states[None],
                context.position_ids_1d[None],
                inv_freq=inv_freq,
                scaling_factor=scaling_factor,
                out_q=query_states[None],
                out_k=key_states[None])
            return query_states[0], key_states[0], value_states

        def __rotary_emb_fn_438(query_states, key_states, value_states):
            rotary_name = type(self.rotary_emb).__name__
            if rotary_name in [
                    'LlamaRotaryEmbedding',
                    'LlamaLinearScalingRotaryEmbedding'
            ]:
                return __rotary_emb_fn_438_fused(query_states, key_states,
                                                 value_states)
            else:
                return __rotary_emb_fn_438_naive(query_states, key_states,
                                                 value_states)

        def __rotary_emb_fn(query_states, key_states, value_states):
            """rotary embedding."""
            if TRANSFORMERS_VERSION >= version.parse('4.38.0'):
                return __rotary_emb_fn_438(query_states, key_states,
                                           value_states)
            else:
                return __rotary_emb_fn_old(query_states, key_states,
                                           value_states)

        query_states, key_states, value_states = __qkv_proj(hidden_states)

        query_states = query_states.view(-1, num_heads, head_dim)
        key_states = key_states.view(-1, num_kv_heads, head_dim)
        value_states = value_states.view(-1, num_kv_heads, head_dim)

        query_states, key_states, value_states = __rotary_emb_fn(
            query_states, key_states, value_states)

        fill_kv_cache(
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            q_start_loc,
            q_seq_length,
            kv_seq_length=kv_seq_length,
            max_q_seq_length=max_q_seq_length,
            block_offsets=block_offsets,
        )

        attn_output = query_states
        paged_attention_fwd(
            query_states,
            past_key_value[0],
            past_key_value[1],
            attn_output,
            block_offsets,
            q_start_loc=q_start_loc,
            q_seqlens=q_seq_length,
            kv_seqlens=kv_seq_length,
            max_seqlen=max_q_seq_length,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1],
                                          hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

    def _contiguous_batching_forward_impl(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        attention_mask: Optional[torch.Tensor] = None,
        world_size: int = 1,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite implementation of LlamaAttention.forward.

        Add continuous batching support. Add paged attention support. TP
        support.
        """
        assert not output_attentions
        json_config = self.context.context.json_config
        use_rerope = False
        if json_config is not None:
            use_rerope = json_config.get('rerope', False)
        if use_rerope:
            return self._contiguous_batching_forward_rerope_impl(
                hidden_states,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                attention_mask=attention_mask,
                world_size=world_size)
        else:
            return self._contiguous_batching_forward_default_impl(
                hidden_states,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                attention_mask=attention_mask,
                world_size=world_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite of LlamaAttention.forward."""
        world_size = 1
        if dist.is_initialized():
            world_size = dist.get_world_size()
        return self._contiguous_batching_forward_impl(
            hidden_states,
            position_ids,
            past_key_value,
            output_attentions,
            attention_mask=attention_mask,
            world_size=world_size,
        )


class LlamaMLP(nn.Module):

    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['gate_proj', 'up_proj']:
            colwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)
        elif mod_name in ['down_proj']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs)
        return outputs


class LlamaModel(nn.Module):

    def _continuous_batching_forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Rewrite implementation of LlamaModel.forward."""
        output_attentions = False
        use_cache = True

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Attention mask is not necessary in continuous batching
        attention_mask = None
        hidden_states = inputs_embeds

        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx]
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Rewrite of LlamaModel.forward."""
        return self._continuous_batching_forward(
            input_ids,
            position_ids,
            past_key_values,
            inputs_embeds,
        )
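Editor's note: `rotate_half` and `apply_rotary_pos_emb` above encode positions by rotating channel pairs of q and k with per-position cos/sin tables. A self-contained check of that rule on tiny tensors (the dimensions and frequency base are illustrative assumptions):

# Standalone sketch of the rotate_half-style RoPE used above.
import torch

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

dim, pos = 4, torch.arange(3, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(pos, inv_freq)                # (seq, dim/2)
cos = torch.cat((freqs, freqs), dim=-1).cos()     # (seq, dim)
sin = torch.cat((freqs, freqs), dim=-1).sin()
q = torch.randn(3, dim)
q_embed = q * cos + rotate_half(q) * sin          # same rule as apply_rotary_pos_emb
print(q_embed.shape)  # torch.Size([3, 4])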
lmdeploy/pytorch/models/mistral.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Optional, Tuple

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._tensor import DeviceMesh

from ..dist_utils import (colwise_parallelize_linear_fn,
                          rowwise_parallelize_linear_fn)
from ..kernels import apply_rotary_pos_emb
from ..kernels.fill_kv_cache import fill_kv_cache
from ..kernels.pagedattention import paged_attention_fwd


class MistralFlashAttention2(nn.Module):

    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['q_proj', 'k_proj', 'v_proj']:
            colwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)
        elif mod_name in ['o_proj']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs[0])
        return outputs

    def _contiguous_batching_forward_impl(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        attention_mask: Optional[torch.Tensor] = None,
        world_size: int = 1,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite implementation of forward.

        Add continuous batching support. Add paged attention support. TP
        support.
        """
        context = self.context.context
        kv_seq_length = context.kv_seq_length
        q_seq_length = context.q_seq_length
        q_start_loc = context.q_start_loc
        block_offsets = context.block_offsets
        max_q_seq_length = context.max_q_seq_length
        max_kv_seq_length = context.max_kv_seq_length

        num_heads = self.num_heads // world_size
        num_kv_heads = self.num_key_value_heads // world_size
        head_dim = self.head_dim
        hidden_size = num_heads * head_dim

        def __qkv_proj(hidden_states):
            """qkv proj."""
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)
            return query_states, key_states, value_states

        def __rotary_emb_fn(query_states, key_states, value_states):
            if hasattr(self, 'rotary_emb'):
                cos, sin = self.rotary_emb(value_states,
                                           seq_len=max_kv_seq_length)
                query_states, key_states = apply_rotary_pos_emb(
                    query_states, key_states, cos, sin, position_ids,
                    context.position_ids_1d)
            return query_states, key_states, value_states

        query_states, key_states, value_states = __qkv_proj(hidden_states)

        query_states = query_states.view(-1, num_heads, head_dim)
        key_states = key_states.view(-1, num_kv_heads, head_dim)
        value_states = value_states.view(-1, num_kv_heads, head_dim)

        query_states, key_states, value_states = __rotary_emb_fn(
            query_states, key_states, value_states)

        fill_kv_cache(
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            q_start_loc,
            q_seq_length,
            kv_seq_length=kv_seq_length,
            max_q_seq_length=max_q_seq_length,
            block_offsets=block_offsets,
        )

        attn_output = query_states
        window_size = self.config.sliding_window
        paged_attention_fwd(
            query_states,
            past_key_value[0],
            past_key_value[1],
            attn_output,
            block_offsets,
            q_start_loc=q_start_loc,
            q_seqlens=q_seq_length,
            kv_seqlens=kv_seq_length,
            max_seqlen=max_q_seq_length,
            window_size=window_size,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1],
                                          hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite of forward."""
        world_size = 1
        if dist.is_initialized():
            world_size = dist.get_world_size()
        return self._contiguous_batching_forward_impl(
            hidden_states,
            position_ids,
            past_key_value,
            output_attentions,
            attention_mask=attention_mask,
            world_size=world_size,
        )
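Editor's note: the `window_size` passed to `paged_attention_fwd` above restricts each query to the most recent tokens. A toy mask makes the semantics concrete (a sketch only, the kernel itself never materializes this matrix):

# Illustrative sliding-window causal mask (window = 3): each query token may
# attend to itself and the previous two positions only. Sizes are made up.
import torch

seq_len, window = 6, 3
i = torch.arange(seq_len)[:, None]
j = torch.arange(seq_len)[None, :]
mask = (j <= i) & (j > i - window)
print(mask.int())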
lmdeploy/pytorch/models/mixtral.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.

from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._tensor import DeviceMesh
from transformers.modeling_outputs import BaseModelOutputWithPast

from ..dist_utils import (colwise_parallelize_linear_fn,
                          rowwise_parallelize_linear_fn)
from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd


class PatchedMixtralAttention(nn.Module):
    """Rewrite module of MixtralAttention."""

    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['q_proj', 'k_proj', 'v_proj']:
            colwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)
        elif mod_name in ['o_proj']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs[0])
        return outputs

    def _contiguous_batching_forward_impl(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        attention_mask: Optional[torch.Tensor] = None,
        world_size: int = 1,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """default rewrite."""
        context = self.context.context
        kv_seq_length = context.kv_seq_length
        q_seq_length = context.q_seq_length
        q_start_loc = context.q_start_loc
        block_offsets = context.block_offsets
        max_q_seq_length = context.max_q_seq_length
        max_kv_seq_length = context.max_kv_seq_length

        num_heads = self.num_heads // world_size
        num_kv_heads = self.num_key_value_heads // world_size
        hidden_size = num_heads * self.head_dim

        def __qkv_proj(hidden_states):
            """qkv proj."""
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)
            return query_states, key_states, value_states

        def __rotary_emb_fn(query_states, key_states, value_states):
            if hasattr(self, 'rotary_emb'):
                cos, sin = self.rotary_emb(value_states,
                                           seq_len=max_kv_seq_length)
                query_states, key_states = apply_rotary_pos_emb(
                    query_states, key_states, cos, sin, position_ids,
                    getattr(context, 'position_ids_1d', None))
            return query_states, key_states, value_states

        query_states, key_states, value_states = __qkv_proj(hidden_states)

        query_states = query_states.view(-1, num_heads, self.head_dim)
        key_states = key_states.view(-1, num_kv_heads, self.head_dim)
        value_states = value_states.view(-1, num_kv_heads, self.head_dim)

        query_states, key_states, value_states = __rotary_emb_fn(
            query_states, key_states, value_states)

        # fill kv cache
        fill_kv_cache(
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            q_start_loc,
            q_seq_length,
            kv_seq_length=kv_seq_length,
            max_q_seq_length=max_q_seq_length,
            block_offsets=block_offsets,
        )

        # page attention
        attn_output = query_states
        window_size = self.config.sliding_window or -1
        paged_attention_fwd(
            query_states,
            past_key_value[0],
            past_key_value[1],
            attn_output,
            block_offsets,
            q_start_loc=q_start_loc,
            q_seqlens=q_seq_length,
            kv_seqlens=kv_seq_length,
            max_seqlen=max_q_seq_length,
            window_size=window_size,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1],
                                          hidden_size)

        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite of MistralAttention.forward."""
        world_size = 1
        if dist.is_initialized():
            world_size = dist.get_world_size()
        return self._contiguous_batching_forward_impl(
            hidden_states,
            position_ids,
            past_key_value,
            output_attentions,
            attention_mask=attention_mask,
            world_size=world_size,
        )


class PatchedMixtralBLockSparseTop2MLP(nn.Module):

    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['w1', 'w3']:
            colwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)
        elif mod_name in ['w2']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs)
        return outputs


class PatchedMixtralModel(nn.Module):

    def _continuous_batching_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Rewrite implementation of LlamaModel.forward."""
        from transformers.modeling_outputs import MoeModelOutputWithPast
        output_attentions = (output_attentions if output_attentions
                             is not None else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        if use_cache is None:
            use_cache = self.config.use_cache
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        assert (
            position_ids is not None
        ), 'position_ids can not be none when using continuous batching mode.'
        assert position_ids.dim() == 2

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Attention mask is not necessary in continuous batching
        attention_mask = None
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states, )
            past_key_value = (past_key_values[idx]
                              if past_key_values is not None else None)
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (
                    layer_outputs[2 if output_attentions else 1], )
            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v for v in
                [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None)
        return MoeModelOutputWithPast(last_hidden_state=hidden_states,
                                      past_key_values=next_cache,
                                      hidden_states=all_hidden_states,
                                      attentions=all_self_attns,
                                      router_logits='')

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Rewrite of LlamaModel.forward."""
        return self._continuous_batching_forward(
            input_ids,
            attention_mask,
            position_ids,
            past_key_values,
            inputs_embeds,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
        )
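Editor's note: PatchedMixtralBLockSparseTop2MLP above only re-shards the expert projections (w1/w3 column-wise, w2 row-wise); the top-2 routing itself is left to the original model. For orientation, a generic top-2 gating computation looks roughly like this (a sketch with arbitrary sizes, not Mixtral's exact code path):

# Generic top-2 expert routing sketch; hidden size and expert count are made up.
import torch

num_experts, hidden = 8, 16
tokens = torch.randn(5, hidden)
router = torch.nn.Linear(hidden, num_experts, bias=False)
logits = router(tokens)                                  # (tokens, experts)
weights, experts = torch.topk(torch.softmax(logits, -1), k=2, dim=-1)
weights = weights / weights.sum(-1, keepdim=True)        # renormalize the two weights
print(experts.shape, weights.shape)                      # both (5, 2)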
lmdeploy/pytorch/models/module_map.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
LMDEPLOY_PYTORCH_MODEL_PATH = 'lmdeploy.pytorch.models'

# llama
MODULE_MAP = {
    'transformers.models.llama.modeling_llama.LlamaFlashAttention2':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention',
    'transformers.models.llama.modeling_llama.LlamaSdpaAttention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention',
    'transformers.models.llama.modeling_llama.LlamaAttention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention',
    'transformers.models.llama.modeling_llama.LlamaModel':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
    'transformers.models.llama.modeling_llama.LlamaMLP':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
    'transformers.models.llama.modeling_llama.LlamaRMSNorm':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm',
    # support modeling rewritten in lmdeploy
    'modeling_llama.LlamaAttention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention',
    'modeling_llama.LlamaModel':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
    'modeling_llama.LlamaMLP':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
}

# Falcon Models in transformer / on hub
MODULE_MAP.update({
    'modeling_falcon.FalconAttention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconAttention',
    'modeling_falcon.FalconModel':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconModel',
    'modeling_falcon.FalconMLP':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconMLP',
    'modeling_falcon.FalconForCausalLM':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconForCausalLM',
    # for old implementations on hub
    'modelling_RW.Attention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconAttention',
    'modelling_RW.MLP':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconMLP',
    'modelling_RW.RWModel':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconModel',
    'modelling_RW.RotaryEmbedding':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconRotaryEmbedding',
})

# baichuan
MODULE_MAP.update({
    'modeling_baichuan.Model':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',  # noqa
    'modeling_baichuan.BaichuanModel':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanModel',  # noqa
    'modeling_baichuan.Attention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.Attention',  # noqa
    'modeling_baichuan.BaichuanAttention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanAttention',  # noqa
    'modeling_baichuan.MLP':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',  # noqa
    'modeling_baichuan.RMSNorm':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.PatchedRMSNorm',
})

# chatglm2
MODULE_MAP.update({
    'modeling_chatglm.SelfAttention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedSelfAttention',
    'modeling_chatglm.ChatGLMModel':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedChatGLMModel',
    'modeling_chatglm.MLP':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.MLP',
    'modeling_chatglm.RMSNorm':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedRMSNorm',
})

# internlm
MODULE_MAP.update({
    'modeling_internlm.InternLMAttention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm.PatchedInternLMAttention',
    'modeling_internlm.InternLMModel':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
    'modeling_internlm.InternLMMLP':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
    'modeling_internlm.InternLMRMSNorm':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm',
})

# internlm2
MODULE_MAP.update({
    'modeling_internlm2.InternLM2Attention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2Attention',
    'modeling_internlm2.InternLM2Model':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2Model',
    'modeling_internlm2.InternLM2MLP':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2MLP',
    'modeling_internlm2.InternLM2RMSNorm':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm',
})

# mistral
MODULE_MAP.update({
    'transformers.models.mistral.modeling_mistral.MistralAttention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.MistralFlashAttention2',
    'transformers.models.mistral.modeling_mistral.MistralFlashAttention2':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.MistralFlashAttention2',
    'transformers.models.mistral.modeling_mistral.MistralSdpaAttention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.MistralFlashAttention2',
    'transformers.models.mistral.modeling_mistral.MistralModel':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
    'transformers.models.mistral.modeling_mistral.MistralMLP':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
    'transformers.models.mistral.modeling_mistral.MistralRMSNorm':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm',
})

# gemma
MODULE_MAP.update({
    'transformers.models.gemma.modeling_gemma.GemmaAttention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
    'transformers.models.gemma.modeling_gemma.GemmaFlashAttention2':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
    'transformers.models.gemma.modeling_gemma.GemmaSdpaAttention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
    'transformers.models.gemma.modeling_gemma.GemmaModel':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaModel',
    'transformers.models.gemma.modeling_gemma.modeling_mistral.GemmaMLP':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
    'transformers.models.gemma.modeling_gemma.GemmaRMSNorm':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaRMSNorm',
})

# deepseek
MODULE_MAP.update({
    'modeling_deepseek.DeepseekAttention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek.PatchedDeepseekAttention',
    'modeling_deepseek.DeepseekFlashAttention2':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek.PatchedDeepseekAttention',
    'modeling_deepseek.DeepseekSdpaAttention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek.PatchedDeepseekAttention',
    'modeling_deepseek.DeepseekModel':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
    'modeling_deepseek.DeepseekMLP':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
    'modeling_deepseek.DeepseekRMSNorm':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm',
})

# qwen1.5
MODULE_MAP.update({
    'transformers.models.qwen2.modeling_qwen2.Qwen2Attention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2Attention',
    'transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2Attention',
    'transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2Attention',
    'transformers.models.qwen2.modeling_qwen2.Qwen2Model':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
    'transformers.models.qwen2.modeling_qwen2.Qwen2MLP':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
    'transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm',
})

# peft
MODULE_MAP.update({
    'peft.tuners.lora.layer.Linear':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.peft.LoRALinear'
})

# mixtral
MODULE_MAP.update({
    'transformers.models.mixtral.modeling_mixtral.MixtralAttention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttention',
    'transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttention',
    'transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttention',
    'transformers.models.mixtral.modeling_mixtral.MixtralModel':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralModel',
    'transformers.models.mixtral.modeling_mixtral.MixtralBLockSparseTop2MLP':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralBLockSparseTop2MLP',
    'transformers.models.mixtral.modeling_mixtral.MixtralBlockSparseTop2MLP':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralBLockSparseTop2MLP',
    'transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm',
})
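Editor's note: the table above keys on the qualified class name of the original transformers (or hub remote-code) module and points at the lmdeploy rewrite. A hypothetical lookup that mirrors how patch.py (below) resolves a rewrite class from this table:

# Hypothetical lookup mirroring the resolution done in patch.py; the qualname
# used here is one of the keys defined above.
import importlib
from lmdeploy.pytorch.models.module_map import MODULE_MAP

def resolve_rewrite(qualname: str):
    target = MODULE_MAP.get(qualname)
    if target is None:
        return None
    modname, _, clsname = target.rpartition('.')
    return getattr(importlib.import_module(modname), clsname)

cls = resolve_rewrite(
    'transformers.models.llama.modeling_llama.LlamaAttention')
print(cls)  # -> the lmdeploy.pytorch.models.llama.LlamaAttention rewrite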
lmdeploy/pytorch/models/patch.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import importlib
import inspect
import re
from copy import copy
from typing import Any, Dict, Sequence

import torch
from addict import Addict
from torch.distributed._tensor import DeviceMesh

from lmdeploy.utils import get_logger

from ..dist_utils import partition_module, replicate_module
from .module_map import MODULE_MAP

logger = get_logger('lmdeploy')


def _get_rewrite_qualname(origin_qualname: str) -> str:
    """get rewrite module from origin module name.

    Args:
        origin_qualname (str): The origin qualname of the module.

    Returns:
        str: The rewrite qualname.
    """
    if origin_qualname in MODULE_MAP:
        return MODULE_MAP[origin_qualname]
    for key, value in MODULE_MAP.items():
        if re.search(key, origin_qualname):
            return value
    return None


def _class_from_qualname(qualname: str) -> Any:
    """Import class with qualname.

    Args:
        qualname (str): Qualname of the class

    Returns:
        Any: class or builder of the class
    """
    last_dot = qualname.rfind('.')
    modname = qualname[:last_dot]
    clsname = qualname[last_dot + 1:]

    # get class at runtime
    mod = importlib.import_module(modname)
    assert mod is not None, f'failed to import module: {modname}'
    cls_type = getattr(mod, clsname)
    return cls_type


def _find_rewrite_module_qualname(model):
    """find rewrite module."""
    module_name = inspect.getmodule(model).__name__
    class_name = model.__class__.__name__

    def _find_fullname():
        origin_qualname = f'{module_name}.{class_name}'
        rewrite_qualname = _get_rewrite_qualname(origin_qualname)
        return rewrite_qualname

    def _find_classname():
        origin_qualname = class_name
        rewrite_qualname = _get_rewrite_qualname(origin_qualname)
        return rewrite_qualname

    def _find_submodulename():
        # name with first module
        mod_name = module_name[module_name.rfind('.') + 1:]
        origin_qualname = f'{mod_name}.{class_name}'
        rewrite_qualname = _get_rewrite_qualname(origin_qualname)
        return rewrite_qualname

    rewrite_qualname = _find_fullname()
    if rewrite_qualname is None:
        rewrite_qualname = _find_classname()
    if rewrite_qualname is None:
        rewrite_qualname = _find_submodulename()

    origin_qualname = f'{module_name}.{class_name}'
    if rewrite_qualname is not None:
        logger.debug('Find rewrite of module\n'
                     f'{origin_qualname} <=> {rewrite_qualname}')
    return rewrite_qualname


def _update_module_type(model: Any, cls_type: type, custom_attrs: dict = None):
    """Update class type of model."""
    # directly return origin model is not cool
    # origin model would be registered as a submodule
    old_type = type(model)

    @property
    def get_origin_mod(self):
        origin_mod = copy(self)
        origin_mod.__class__ = old_type
        return origin_mod

    attrs = dict(cls_type.__dict__)
    custom_attrs = custom_attrs or dict()
    custom_attrs['origin_mod'] = get_origin_mod
    attrs.update(custom_attrs)
    new_type = type(cls_type.__name__, (type(model), ), attrs)
    model = copy(model)
    model.__class__ = new_type

    return model


def _patch(model: torch.nn.Module, context: Addict) -> torch.nn.Module:
    """patch the model with rewrite module.

    Args:
        model (Module): model to be patched.
        context (Addict): The environment info to patched in model

    Returns:
        Module: The patched model
    """

    def _recursive_children(context, named_children):
        """recursive children."""
        for name, child in named_children:
            patched_child = _patch(child, context)
            if patched_child != child:
                model.register_module(name, patched_child)

    _recursive_children(context, model.named_children())

    rewrite_qualname = _find_rewrite_module_qualname(model)

    if rewrite_qualname is not None:
        cls_type = _class_from_qualname(rewrite_qualname)
        model = _update_module_type(model, cls_type, dict(context=context))

    return model


def _update_model(model: torch.nn.Module):
    """Update model after patch and load.

    Args:
        model (Module): The model to be updated.
    """
    # recursive over children
    for _, child in model.named_children():
        _update_model(child)

    if hasattr(model, '_update_model_fn'):
        model._update_model_fn()


def _dist_model(model: torch.nn.Module,
                rank: int = 0,
                device_mesh: DeviceMesh = None):
    """distribute model parameters."""

    def _init_params():
        """init params."""
        device = torch.device(f'cuda:{rank}')
        for name, param in model.named_parameters(recurse=False):
            if device != param.device:
                if rank == 0:
                    new_param = param.to(device)
                    model.register_parameter(
                        name,
                        torch.nn.Parameter(new_param, requires_grad=False))
                else:
                    new_param = torch.empty_like(param, device=device)
                    model.register_parameter(
                        name,
                        torch.nn.Parameter(new_param, requires_grad=False))
        for name, param in model.named_buffers(recurse=False):
            if device != param.device:
                if rank == 0:
                    new_param = param.to(device)
                    model.register_buffer(name, new_param)
                else:
                    new_param = torch.empty_like(param, device=device)
                    model.register_buffer(name, new_param)

    def _dist_params():
        """dist params."""
        if hasattr(model, '_distribute_partition_fn'):
            partition_module(
                model,
                device_mesh=device_mesh,
                func=model._distribute_partition_fn,
                to_local=True,
            )
        else:
            replicate_module(model, device_mesh=device_mesh)
        torch.cuda.empty_cache()

    def _register_hooks():
        """register hooks."""
        if hasattr(model, '_distribute_input_fn'):
            input_fn = model._distribute_input_fn
            model.register_forward_pre_hook(
                lambda _, inputs, inputs_dict: input_fn(
                    inputs, inputs_dict, device_mesh),
                with_kwargs=True,
            )

        if hasattr(model, '_distribute_output_fn'):
            output_fn = model._distribute_output_fn
            model.register_forward_hook(
                lambda mod, inputs, outputs: output_fn(outputs, device_mesh))

    for name, child in model.named_children():
        if rank == 0:
            logger.debug(f'Distribute module: <{name}>')
        new_child = _dist_model(child, rank, device_mesh)
        if new_child != child:
            model.register_module(name, child)

    _init_params()
    _dist_params()
    _register_hooks()

    return model


class PatchedForward:
    """patched forward."""

    def __init__(self, model, context, extra_args):
        self._model = model
        self._patch_context: Dict = context
        self._extra_args: list = extra_args

    def __call__(self, *args, **kwargs):
        for arg_name in self._extra_args:
            extra_arg = kwargs.pop(arg_name, None)
            self._patch_context[arg_name] = extra_arg

        output = self._model(*args, **kwargs)

        self._patch_context.clear()

        return output


def patch(
    model: torch.nn.Module,
    extra_args: Sequence[str] = None,
    rank: int = 0,
    world_size: int = 1,
):
    """Patch the model with rewrite modules.

    Extra arguments will be patched in forward of model, weights on each rank
    will be partitioned.

    Args:
        model (Module): Model to be patched.
        extra_args (Sequence[str]): Extra arguments of model forward.
        rank (int): Distribution rank.
        world_size (int): Distribution world size.

    Returns:
        Module: The patched model.
    """
    if extra_args is None:
        extra_args = []

    _patch_context = Addict()

    model = _patch(model, _patch_context)

    if world_size > 1:
        if rank == 0:
            logger.info('distribute model parameters.')
        device_mesh = DeviceMesh('cuda', list(range(world_size)))
        model = _dist_model(model, rank, device_mesh=device_mesh)

    _update_model(model)

    patched_forward = PatchedForward(model,
                                     _patch_context,
                                     extra_args=extra_args)
    model.patched_forward = patched_forward

    return model
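Editor's note: a hedged usage sketch of `patch()`. The checkpoint name and the `'context'` extra argument below are illustrative assumptions; only the `patch` entry point itself comes from the file above.

# Hypothetical usage of patch(); the model id and extra argument are assumptions.
import torch
from transformers import AutoModelForCausalLM
from lmdeploy.pytorch.models.patch import patch

model = AutoModelForCausalLM.from_pretrained('internlm/internlm2-chat-7b',
                                             torch_dtype=torch.float16,
                                             trust_remote_code=True)
model = patch(model, extra_args=['context'], rank=0, world_size=1)
# Extra args passed to model.patched_forward(...) are stashed into the patch
# context so the rewritten attention modules can read them during forward.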
lmdeploy/pytorch/models/peft.py
0 → 100644
View file @
fe851fbc
# Copyright (c) OpenMMLab. All rights reserved.
from
dataclasses
import
dataclass
import
torch
import
torch.distributed
as
dist
from
..kernels.mbgmm
import
mbgmm_a
,
mbgmm_b
from
..kernels.mbgmv
import
mbgmv_a
,
mbgmv_b
from
..kernels.rearange_all_gather
import
rearange_all_gather
@
dataclass
class
PackedLoRAInput
:
x
:
torch
.
Tensor
a_cache
:
torch
.
Tensor
b_cache
:
torch
.
Tensor
q_start_loc
:
torch
.
Tensor
q_seqlens
:
torch
.
Tensor
adapter_ids
:
torch
.
Tensor
scaling
:
torch
.
Tensor
rank_page_table
:
torch
.
Tensor
rank_page_start
:
torch
.
Tensor
ranks
:
torch
.
Tensor
max_seq_len
:
int
max_rank
:
int
is_decoding
:
bool
class
LoRALinear
(
torch
.
nn
.
Module
):
def
_make_packed_lora_input
(
self
,
x
):
context
=
self
.
context
.
context
# adapter cache
global_adapter_ids
=
context
.
global_adapter_ids
layer_idx
=
self
.
layer_idx
ranks
=
self
.
ranks
[
global_adapter_ids
]
block_starts
=
self
.
block_starts
[
global_adapter_ids
]
scaling
=
self
.
scaling
[
global_adapter_ids
]
k_cache
,
v_cache
=
context
.
kv_caches
[
layer_idx
]
cache_len
=
k_cache
.
size
(
0
)
a_cache
=
k_cache
.
view
(
cache_len
,
-
1
)
b_cache
=
v_cache
.
view
(
cache_len
,
-
1
)
return
PackedLoRAInput
(
x
=
x
.
flatten
(
0
,
-
2
).
contiguous
(),
a_cache
=
a_cache
,
b_cache
=
b_cache
,
q_start_loc
=
context
.
q_start_loc
,
q_seqlens
=
context
.
q_seq_length
,
adapter_ids
=
context
.
local_adapter_ids
,
scaling
=
scaling
,
rank_page_table
=
context
.
adapter_offsets
,
rank_page_start
=
block_starts
,
ranks
=
ranks
,
max_seq_len
=
context
.
max_q_seq_length
,
max_rank
=
context
.
max_rank
,
is_decoding
=
context
.
is_decoding
)
def
_lora_forward_local
(
self
,
x
):
"""lora forward no tp."""
lora_input
=
self
.
_make_packed_lora_input
(
x
)
out_size
=
self
.
base_layer
.
weight
.
size
(
0
)
if
not
lora_input
.
is_decoding
:
xa
=
mbgmm_a
(
lora_input
.
x
,
lora_input
.
a_cache
,
q_start_loc
=
lora_input
.
q_start_loc
,
q_seqlens
=
lora_input
.
q_seqlens
,
adapter_ids
=
lora_input
.
adapter_ids
,
rank_page_table
=
lora_input
.
rank_page_table
,
rank_page_start
=
lora_input
.
rank_page_start
,
ranks
=
lora_input
.
ranks
,
max_seq_len
=
lora_input
.
max_seq_len
,
max_rank
=
lora_input
.
max_rank
)
lora_out
=
mbgmm_b
(
xa
,
lora_input
.
b_cache
,
q_start_loc
=
lora_input
.
q_start_loc
,
q_seqlens
=
lora_input
.
q_seqlens
,
adapter_ids
=
lora_input
.
adapter_ids
,
scaling
=
lora_input
.
scaling
,
rank_page_table
=
lora_input
.
rank_page_table
,
rank_page_start
=
lora_input
.
rank_page_start
,
ranks
=
lora_input
.
ranks
,
max_seq_len
=
lora_input
.
max_seq_len
,
max_rank
=
lora_input
.
max_rank
,
out_size
=
out_size
)
else
:
xa
=
mbgmv_a
(
lora_input
.
x
,
lora_input
.
a_cache
,
adapter_ids
=
lora_input
.
adapter_ids
,
rank_page_table
=
lora_input
.
rank_page_table
,
rank_page_start
=
lora_input
.
rank_page_start
,
ranks
=
lora_input
.
ranks
,
max_rank
=
lora_input
.
max_rank
)
lora_out
=
mbgmv_b
(
xa
,
lora_input
.
b_cache
,
adapter_ids
=
lora_input
.
adapter_ids
,
scaling
=
lora_input
.
scaling
,
rank_page_table
=
lora_input
.
rank_page_table
,
rank_page_start
=
lora_input
.
rank_page_start
,
ranks
=
lora_input
.
ranks
,
max_rank
=
lora_input
.
max_rank
,
out_size
=
out_size
)
base_out
=
self
.
base_layer
(
x
)
lora_out
=
lora_out
.
reshape
(
base_out
.
shape
)
output
=
base_out
+
lora_out
return
output
def
_lora_forward_tp_rowwise
(
self
,
x
):
"""lora forward tp rowwise."""
lora_input
=
self
.
_make_packed_lora_input
(
x
)
rank
=
dist
.
get_rank
()
world_size
=
dist
.
get_world_size
()
out_size
=
self
.
base_layer
.
weight
.
size
(
0
)
//
world_size
if
not
lora_input
.
is_decoding
:
xa
=
mbgmm_a
(
lora_input
.
x
,
lora_input
.
a_cache
,
q_start_loc
=
lora_input
.
q_start_loc
,
q_seqlens
=
lora_input
.
q_seqlens
,
adapter_ids
=
lora_input
.
adapter_ids
,
rank_page_table
=
lora_input
.
rank_page_table
,
rank_page_start
=
lora_input
.
rank_page_start
,
ranks
=
lora_input
.
ranks
,
max_seq_len
=
lora_input
.
max_seq_len
,
max_rank
=
lora_input
.
max_rank
)
lora_out
=
mbgmm_b
(
xa
,
lora_input
.
b_cache
,
q_start_loc
=
lora_input
.
q_start_loc
,
q_seqlens
=
lora_input
.
q_seqlens
,
adapter_ids
=
lora_input
.
adapter_ids
,
scaling
=
lora_input
.
scaling
,
rank_page_table
=
lora_input
.
rank_page_table
,
rank_page_start
=
lora_input
.
rank_page_start
,
ranks
=
lora_input
.
ranks
,
max_seq_len
=
lora_input
.
max_seq_len
,
max_rank
=
lora_input
.
max_rank
,
out_size
=
out_size
)
else
:
xa
=
mbgmv_a
(
lora_input
.
x
,
lora_input
.
a_cache
,
adapter_ids
=
lora_input
.
adapter_ids
,
rank_page_table
=
lora_input
.
rank_page_table
,
rank_page_start
=
lora_input
.
rank_page_start
,
ranks
=
lora_input
.
ranks
,
max_rank
=
lora_input
.
max_rank
)
lora_out
=
mbgmv_b
(
xa
,
lora_input
.
b_cache
,
adapter_ids
=
lora_input
.
adapter_ids
,
scaling
=
lora_input
.
scaling
,
rank_page_table
=
lora_input
.
rank_page_table
,
rank_page_start
=
lora_input
.
rank_page_start
,
ranks
=
lora_input
.
ranks
,
max_rank
=
lora_input
.
max_rank
,
out_size
=
out_size
)
base_out
=
self
.
base_layer
(
x
)
out_shape
=
base_out
.
shape
base_out
=
base_out
.
flatten
(
0
,
-
2
)
slice_start
=
rank
*
out_size
slice_end
=
slice_start
+
out_size
base_out
[:,
slice_start
:
slice_end
]
+=
lora_out
base_out
=
base_out
.
reshape
(
out_shape
)
return
base_out
    def _lora_forward_tp_colwise(self, x):
        """lora forward tp colwise."""

        def __gather_xa(xa):
            """gather xa."""
            gathered_xa = xa.new_empty(world_size, xa.size(0), xa.size(1))
            dist.all_gather_into_tensor(gathered_xa, xa)
            # TODO: gather would fail when adapters have different ranks.
            gathered_xa = gathered_xa.permute(1, 0, 2).flatten(-2, -1)
            return gathered_xa

        lora_input = self._make_packed_lora_input(x)
        world_size = dist.get_world_size()
        out_size = self.base_layer.weight.size(0)
        if not lora_input.is_decoding:
            xa = mbgmm_a(lora_input.x,
                         lora_input.a_cache,
                         q_start_loc=lora_input.q_start_loc,
                         q_seqlens=lora_input.q_seqlens,
                         adapter_ids=lora_input.adapter_ids,
                         rank_page_table=lora_input.rank_page_table,
                         rank_page_start=lora_input.rank_page_start,
                         ranks=lora_input.ranks,
                         max_seq_len=lora_input.max_seq_len,
                         max_rank=lora_input.max_rank,
                         rank_step=world_size)
            gathered_xa = __gather_xa(xa)
            if len(lora_input.ranks) > 1:
                gathered_xa = rearange_all_gather(
                    gathered_xa,
                    b_start_loc=lora_input.q_start_loc,
                    b_seq_lens=lora_input.q_seqlens,
                    adapter_ids=lora_input.adapter_ids,
                    ranks=lora_input.ranks,
                    world_size=world_size,
                    max_seq_len=lora_input.max_seq_len,
                    output=gathered_xa)
            lora_out = mbgmm_b(gathered_xa,
                               lora_input.b_cache,
                               q_start_loc=lora_input.q_start_loc,
                               q_seqlens=lora_input.q_seqlens,
                               adapter_ids=lora_input.adapter_ids,
                               scaling=lora_input.scaling,
                               rank_page_table=lora_input.rank_page_table,
                               rank_page_start=lora_input.rank_page_start,
                               ranks=lora_input.ranks,
                               max_seq_len=lora_input.max_seq_len,
                               max_rank=lora_input.max_rank,
                               out_size=out_size)
        else:
            xa = mbgmv_a(lora_input.x,
                         lora_input.a_cache,
                         adapter_ids=lora_input.adapter_ids,
                         rank_page_table=lora_input.rank_page_table,
                         rank_page_start=lora_input.rank_page_start,
                         ranks=lora_input.ranks,
                         max_rank=lora_input.max_rank,
                         rank_step=world_size)
            gathered_xa = __gather_xa(xa)
            if len(lora_input.ranks) > 1:
                gathered_xa = rearange_all_gather(
                    gathered_xa,
                    b_start_loc=lora_input.q_start_loc,
                    b_seq_lens=lora_input.q_seqlens,
                    adapter_ids=lora_input.adapter_ids,
                    ranks=lora_input.ranks,
                    world_size=world_size,
                    max_seq_len=lora_input.max_seq_len,
                    output=gathered_xa)
            lora_out = mbgmv_b(gathered_xa,
                               lora_input.b_cache,
                               adapter_ids=lora_input.adapter_ids,
                               scaling=lora_input.scaling,
                               rank_page_table=lora_input.rank_page_table,
                               rank_page_start=lora_input.rank_page_start,
                               ranks=lora_input.ranks,
                               max_rank=lora_input.max_rank,
                               out_size=out_size)

        base_out = self.base_layer(x)
        lora_out = lora_out.reshape(base_out.shape)
        output = base_out + lora_out

        return output
    def _lora_forward_tp(self, x):
        """lora forward tp."""
        tp_mode = getattr(self, '_tp_mode', None)
        if tp_mode == 'rowwise':
            return self._lora_forward_tp_rowwise(x)
        elif tp_mode == 'colwise':
            return self._lora_forward_tp_colwise(x)
        else:
            assert tp_mode is None, 'tp_mode == None failed.'
            return self._lora_forward_local(x)
    def _lora_forward(self, x):
        """lora forward."""
        if dist.is_initialized():
            return self._lora_forward_tp(x)
        else:
            return self._lora_forward_local(x)
    def forward(self, x):
        """forward."""
        context = self.context.context
        max_rank = context.max_rank

        if max_rank == 0:
            return self.origin_mod.forward(x)
        else:
            return self._lora_forward(x)
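Note: the rowwise and colwise paths above distribute the same underlying LoRA computation, y = base(x) + scaling * ((x @ A.T) @ B.T), across tensor-parallel ranks. The rowwise path adds the adapter output only to the local out_size slice of the base result; the colwise path computes a rank-strided share of x @ A.T, all-gathers it, and applies the full B projection. A minimal single-device sketch of that math, for orientation only (the mbgmm_*/mbgmv_* Triton kernels operate on packed adapter caches, paged rank tables and per-request adapter ids, none of which appear in this toy version):

import torch
import torch.nn.functional as F

def lora_linear_reference(x, base_weight, lora_a, lora_b, scaling):
    """Plain LoRA forward: y = x @ W.T + scaling * ((x @ A.T) @ B.T)."""
    base_out = F.linear(x, base_weight)   # [seq, out_features]
    xa = x @ lora_a.t()                   # [seq, rank]
    lora_out = xa @ lora_b.t()            # [seq, out_features]
    return base_out + scaling * lora_out

x = torch.randn(4, 16)    # [seq_len, in_features]
w = torch.randn(32, 16)   # frozen base weight
a = torch.randn(8, 16)    # LoRA A, rank 8
b = torch.randn(32, 8)    # LoRA B
print(lora_linear_reference(x, w, a, b, scaling=0.5).shape)  # torch.Size([4, 32])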
lmdeploy/pytorch/models/q_modules.py
0 → 100644
View file @
fe851fbc
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass

import torch
import torch.nn as nn

from ..kernels.w8a8_triton_kernels import (matmul_kernel_dynamic_quant,
                                           per_channel_quant,
                                           per_token_quant_int8,
                                           rms_norm_dynamic_quant)


@dataclass
class QTensor:
    """A data class representing a quantized tensor.

    This class wraps around a regular PyTorch tensor and adds quantization-
    specific parameters.
    """
    tensor: torch.Tensor
    scale: torch.Tensor
    zero_point: torch.Tensor = None

    def __getattr__(self, name: str):
        """Allows attribute access to be forwarded to the wrapped tensor when
        the attribute doesn't exist in QTensor."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.tensor, name)
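A small usage sketch of the attribute forwarding above (hand-built values for illustration; QTensor instances are normally produced by QRMSNorm or per-token quantization rather than constructed directly). Lookups that miss on the dataclass fields fall through to the wrapped int8 tensor:

import torch

q = QTensor(tensor=torch.zeros(4, 8, dtype=torch.int8),
            scale=torch.ones(4, 1, dtype=torch.float32))
print(q.scale.shape)   # field on QTensor itself: torch.Size([4, 1])
print(q.shape)         # forwarded to q.tensor: torch.Size([4, 8])
print(q.dtype)         # forwarded to q.tensor: torch.int8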
class QRMSNorm(nn.Module):
    """Performs RMS normalization and then quantizes the output to 8-bit
    integers."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    @classmethod
    def from_float(cls, mod: nn.Module, initialization: bool = True):
        """Class method to create a QRMSNorm instance from a floating-point
        module.

        `initialization = True` for real init.
        `initialization = False` for dummy init.
        """
        hidden_size = mod.weight.shape[0]
        eps = mod.variance_epsilon
        q_mod = cls(hidden_size, eps)
        if initialization:
            q_mod.weight = nn.Parameter(mod.weight.detach())
        return q_mod

    def forward(self, hidden_states):
        """Defines the computation performed at every call.

        Performs RMS normalization followed by dynamic quantization on
        hidden_states. Returns a QTensor which wraps the quantized tensor
        along with its scale factor.
        """
        hidden_states_quant, rms_scale = rms_norm_dynamic_quant(
            hidden_states, self.weight, self.variance_epsilon)
        return QTensor(hidden_states_quant, rms_scale)
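rms_norm_dynamic_quant is a fused Triton kernel, so its body is not shown here. A rough PyTorch sketch of the math it is expected to implement, assuming per-token symmetric int8 quantization (the kernel's exact rounding and clamping behaviour may differ):

import torch

def rms_norm_dynamic_quant_reference(x, weight, eps):
    # RMS normalization computed in float32.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    normed = x.float() * torch.rsqrt(variance + eps) * weight.float()
    # Per-token dynamic quantization to int8.
    scale = normed.abs().amax(dim=-1, keepdim=True) / 127.0
    quant = torch.round(normed / scale).clamp(-128, 127).to(torch.int8)
    return quant, scale

x = torch.randn(2, 5, 64, dtype=torch.float16)
w = torch.ones(64)
q, s = rms_norm_dynamic_quant_reference(x, w, eps=1e-6)
print(q.dtype, s.shape)   # torch.int8 torch.Size([2, 5, 1])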
class QLinear(nn.Module):
    """A Linear layer that operates on quantized inputs and weights.

    It performs matrix multiplication in 8-bit precision and dequantizes the
    results back to float.
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 device=None,
                 dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer(
            'weight',
            torch.empty((out_features, in_features),
                        device=device,
                        dtype=torch.int8))
        self.register_buffer(
            'scale',
            torch.empty((out_features, 1), device=device, dtype=torch.float32))
        if bias:
            self.register_buffer('bias',
                                 torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

    @classmethod
    def from_float(cls, mod: nn.Module, initialization: bool = True):
        """Class method to create a QLinear instance from a floating-point
        module.

        `initialization = True` for real init.
        `initialization = False` for dummy init.
        """
        q_mod = cls(mod.in_features,
                    mod.out_features,
                    mod.bias is not None,
                    device=mod.weight.device,
                    dtype=mod.weight.dtype)
        if initialization:
            weight_quant, scale = per_channel_quant(mod.weight.detach(), 8,
                                                    torch.int8)
            q_mod.weight.data = weight_quant
            q_mod.scale = scale

            if mod.bias is not None:
                q_mod.bias.data = mod.bias.detach()
        return q_mod

    def forward(self, input):
        """Defines the computation performed at every call.

        Quantizes the input if it is a plain tensor; otherwise it assumes the
        input is already quantized (an instance of QTensor). It then performs
        the linear transformation with dynamic quantization, producing an
        8-bit integer intermediate result, and finally dequantizes the result
        back to a floating-point tensor.
        """
        if isinstance(input, torch.Tensor):
            input_quant, input_scale = per_token_quant_int8(input, 1e-7)
        else:
            assert isinstance(input, QTensor)
            input_quant, input_scale = input.tensor, input.scale

        out = matmul_kernel_dynamic_quant(input_quant,
                                          self.weight,
                                          input_scale,
                                          self.scale,
                                          output_dtype=torch.float16,
                                          bias=self.bias)
        return out

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)
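For reference, a plain-PyTorch sketch of the dequantized matmul that matmul_kernel_dynamic_quant is expected to perform on the int8 operands (an illustrative assumption about the per-token and per-channel rescaling; the actual Triton kernel fuses the int8 GEMM, rescaling and bias addition):

import torch

def dynamic_quant_linear_reference(input_quant, input_scale, weight_quant,
                                   weight_scale, bias=None):
    # int8 x int8 matmul accumulated in float, then rescaled
    # per token (rows) and per output channel (columns).
    acc = input_quant.float() @ weight_quant.float().t()
    out = acc * input_scale * weight_scale.t()
    if bias is not None:
        out = out + bias
    return out.to(torch.float16)

x_q = torch.randint(-128, 128, (4, 16), dtype=torch.int8)
x_s = torch.rand(4, 1)        # per-token input scales
w_q = torch.randint(-128, 128, (32, 16), dtype=torch.int8)
w_s = torch.rand(32, 1)       # per-channel weight scales
y = dynamic_quant_linear_reference(x_q, x_s, w_q, w_s)
print(y.shape, y.dtype)       # torch.Size([4, 32]) torch.float16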