Commit fe851fbc authored by zhouxiang

Add the supplementary new files for version 0.2.6

parent e2d98ddc
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast)
from lmdeploy.pytorch.modeling.convert_to_qmodules import convert_to_qmodules
from lmdeploy.utils import get_logger
from .configuration_baichuan import BaiChuanConfig
logger = get_logger('lmdeploy')
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0):
"""Make causal mask used for bi-directional self-attention."""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len),
torch.tensor(torch.finfo(dtype).min, device=device),
device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device),
mask
],
dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len,
tgt_len + past_key_values_length)
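# Illustrative sketch (editorial, not part of the original file): for
# tgt_len=3 and past_key_values_length=0, the mask built above is
#   [[   0, -inf, -inf],
#    [   0,    0, -inf],
#    [   0,    0,    0]]
# where "-inf" stands for torch.finfo(dtype).min, broadcast to
# [bsz, 1, tgt_len, tgt_len + past_key_values_length] before being added to
# the attention scores.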
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor,
dtype: torch.dtype,
tgt_len: Optional[int] = None):
"""Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len,
src_seq_len]`."""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len,
src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool),
torch.finfo(dtype).min)
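# Illustrative sketch (editorial, not part of the original file): a padding
# mask row [1, 1, 0] (1 = attend, 0 = padding) is expanded to
# [bsz, 1, tgt_len, src_len] and inverted, so attended positions become 0.0
# and padded positions become torch.finfo(dtype).min, ready to be added to
# the raw attention scores.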
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""RMSNorm is equivalent to T5LayerNorm."""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1,
keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance +
self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
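# In formula form (editorial note), RMSNorm computes
#   y = weight * x / sqrt(mean(x ** 2, dim=-1) + eps)
# i.e. LayerNorm without mean subtraction or bias, which is why it is
# described above as equivalent to T5LayerNorm.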
class RotaryEmbedding(torch.nn.Module):
"""RotaryEmbedding for Baichuan Model.
This module generates sine and cosine positional encodings based on
the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding".
The purpose of this class is to provide positional embeddings to the
input tensors. It utilizes a cache mechanism to store precomputed
sine and cosine values for speedup.
Args:
dim (int): The dimensionality of the embeddings.
max_position_embeddings (int, optional): The maximum number of
position embeddings. Default is 2048.
base (int, optional): The base value for the inverse frequency
calculation. Default is 10000.
device (str, optional): The device to run operations on.
If None, defaults to the device of the model.
"""
def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None):
super().__init__()
index = (torch.arange(0, dim, 2).float().to(device) / dim)
inv_freq = 1.0 / (base**index)
self.register_buffer('inv_freq', inv_freq)
# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached,
device=self.inv_freq.device,
dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Different from paper, but it uses a different permutation in order
# to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer('cos_cached',
emb.cos()[None, None, :, :],
persistent=False)
self.register_buffer('sin_cached',
emb.sin()[None, None, :, :],
persistent=False)
def forward(self, x, seq_len=None):
"""Forward propagation method for the embedding layer.
Generates positional embeddings for the given input tensor.
"""
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in
# `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached,
device=x.device,
dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Different from paper, but it uses a different permutation in
# order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer('cos_cached',
emb.cos()[None, None, :, :],
persistent=False)
self.register_buffer('sin_cached',
emb.sin()[None, None, :, :],
persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
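# Usage sketch (illustrative only; a head_dim of 128 is assumed here):
#   rotary = RotaryEmbedding(dim=128, max_position_embeddings=2048)
#   cos, sin = rotary(value_states, seq_len=kv_seq_len)
# Both `cos` and `sin` have shape [1, 1, seq_len, dim] and are gathered per
# position in `apply_rotary_pos_emb` below.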
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
"""Apply rotary positional embeddings to query and key tensors.
This function applies the cosine and sine positional embeddings on the
input query (q) and key (k) tensors using element-wise multiplication and
addition.
"""
# The first two dimensions of cos and sin are always 1,
# so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
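# Editorial sketch of the rotation performed above: for index i in the first
# half of the head dimension, with per-position angle m * theta_i,
#   q'_i         = q_i * cos(m * theta_i) - q_{i+dim/2} * sin(m * theta_i)
#   q'_{i+dim/2} = q_{i+dim/2} * cos(m * theta_i) + q_i * sin(m * theta_i)
# which is exactly what `(q * cos) + (rotate_half(q) * sin)` computes with the
# half-and-half layout produced by `rotate_half`.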
class MLP(nn.Module):
"""MLP for Baichuan Model."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
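# Editorial note: this is the gated (SwiGLU-style) feed-forward used by
# LLaMA-family models, down_proj(act(gate_proj(x)) * up_proj(x)), with no
# biases on any of the three projections.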
class Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper."""
def __init__(self, config: BaiChuanConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.max_position_embeddings = config.max_position_embeddings
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError('hidden_size must be divisible by num_heads '
f'(got `hidden_size`: {self.hidden_size}'
f' and `num_heads`: {self.num_heads}).')
self.W_pack = nn.Linear(self.hidden_size,
3 * self.hidden_size,
bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim,
self.hidden_size,
bias=False)
self.rotary_emb = RotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Forward propagation method for the attention layer."""
bsz, q_len, _ = hidden_states.size()
proj = self.W_pack(hidden_states)
proj = proj.unflatten(-1,
(3, self.hidden_size)).unsqueeze(0).transpose(
0, -2).squeeze(-2)
        query_states = proj[0].view(
            bsz, q_len, self.num_heads,
            self.head_dim).transpose(1, 2)  # [bsz, num_heads, q_len, head_dim]
        key_states = proj[1].view(
            bsz, q_len, self.num_heads,
            self.head_dim).transpose(1, 2)  # [bsz, num_heads, q_len, head_dim]
        value_states = proj[2].view(
            bsz, q_len, self.num_heads,
            self.head_dim).transpose(1, 2)  # [bsz, num_heads, q_len, head_dim]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
attn_weights = torch.matmul(query_states, key_states.transpose(
2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
'Attention weights should be of size '
f'{(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
f' {attn_weights.size()}')
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError('Attention mask should be of size '
f'{(bsz, 1, q_len, kv_seq_len)},'
f' but is {attention_mask.size()}')
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights,
torch.tensor(torch.finfo(attn_weights.dtype).min))
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(
query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
'`attn_output` should be of size '
f'{(bsz, self.num_heads, q_len, self.head_dim)}, but is'
f' {attn_output.size()}')
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
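# Editorial note: `W_pack` fuses the query/key/value projections into a single
# matmul. The `unflatten`/`transpose`/`squeeze` sequence above is equivalent to
# the more familiar split (illustrative sketch only):
#   qkv = self.W_pack(hidden_states)      # [bsz, q_len, 3 * hidden_size]
#   q, k, v = qkv.chunk(3, dim=-1)        # three [bsz, q_len, hidden_size]
#   q = q.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
# and likewise for k and v.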
class DecoderLayer(nn.Module):
"""Decoder layer for Baichuan Model."""
def __init__(self, config: BaiChuanConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Attention(config=config)
self.mlp = MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
torch.FloatTensor]]]:
""" # noqa: E501
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, )
if output_attentions:
outputs += (self_attn_weights, )
if use_cache:
outputs += (present_key_value, )
return outputs
class PreTrainedModel(PreTrainedModel):
config_class = BaiChuanConfig
base_model_prefix = 'model'
supports_gradient_checkpointing = True
_no_split_modules = ['DecoderLayer']
_keys_to_ignore_on_load_unexpected = [r'decoder\.version']
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, Model):
module.gradient_checkpointing = value
class Model(PreTrainedModel):
"""Transformer decoder consisting of *config.num_hidden_layers* layers.
Each layer is a [`DecoderLayer`]
Args:
config: BaiChuanConfig
"""
def __init__(self, config: BaiChuanConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
self.padding_idx)
self.layers = nn.ModuleList(
[DecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder.
# prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1]).to(
inputs_embeds.device)
combined_attention_mask = (expanded_attn_mask
if combined_attention_mask is None else
expanded_attn_mask +
combined_attention_mask)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (output_attentions if output_attentions is not None
else self.config.output_attentions)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else
self.config.output_hidden_states)
use_cache = (use_cache
if use_cache is not None else self.config.use_cache)
return_dict = (return_dict if return_dict is not None else
self.config.use_return_dict)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError('You cannot specify both decoder_input_ids '
'and decoder_inputs_embeds at the same time')
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError('You have to specify either decoder_input_ids '
'or decoder_inputs_embeds')
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = (seq_length_with_past +
past_key_values_length)
if position_ids is None:
device = (input_ids.device
if input_ids is not None else inputs_embeds.device)
position_ids = torch.arange(past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds,
past_key_values_length)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
'`use_cache=True` is incompatible with gradient '
'checkpointing. Setting `use_cache=False`...')
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states, )
past_key_value = past_key_values[
idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (
layer_outputs[2 if output_attentions else 1], )
if output_attentions:
all_self_attns += (layer_outputs[1], )
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states, )
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v for v in
[hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class BaiChuanForCausalLM(PreTrainedModel):
"""This class extends the `PreTrainedModel` to enable causal language
modeling.
It wraps the basic Baichuan model (`Model`) and includes a linear layer as
a language model head (`lm_head`). The purpose is to predict token
probabilities, given the previous tokens in the sequence.
"""
def __init__(self, config):
super().__init__(config)
self.model = Model(config)
self.lm_head = nn.Linear(config.hidden_size,
config.vocab_size,
bias=False)
# Initialize weights and apply final processing
self.post_init()
convert_to_qmodules(self)
def get_input_embeddings(self):
"""Get the token embedding layer."""
return self.model.embed_tokens
def set_input_embeddings(self, value):
"""Set the token embedding layer."""
self.model.embed_tokens = value
def get_output_embeddings(self):
"""Get the output embedding layer."""
return self.lm_head
def set_output_embeddings(self, new_embeddings):
"""Set the output embedding layer."""
self.lm_head = new_embeddings
def set_decoder(self, decoder):
"""Set the decoder model."""
self.model = decoder
def get_decoder(self):
"""Get the decoder model."""
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r""" # noqa: E501
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
        >>> from transformers import AutoTokenizer
        >>> model = BaiChuanForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you consciours? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```"""
output_attentions = (output_attentions if output_attentions is not None
else self.config.output_attentions)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = (return_dict if return_dict is not None else
self.config.use_return_dict)
# decoder outputs consists of
# (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits, ) + outputs[1:]
return (loss, ) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
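    # Editorial note on the loss above: logits at position t are scored against
    # labels at position t + 1 ("shift so that tokens < n predict n"), e.g. for
    # labels [l0, l1, l2, l3] the pairs fed to CrossEntropyLoss are
    # (logits[0], l1), (logits[1], l2), (logits[2], l3); positions labelled
    # -100 are ignored.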
def prepare_inputs_for_generation(self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs):
"""Prepare inputs for generating sequences using the model.
Args:
input_ids (torch.Tensor): Input token ids.
past_key_values (list[torch.Tensor], optional): List of past key
and value states.
attention_mask (torch.Tensor, optional): Mask indicating which
tokens should be attended to.
inputs_embeds (torch.FloatTensor, optional): Optionally,
the input embeddings instead of token ids.
Returns:
dict: Dictionary containing prepared inputs for model generation.
"""
if past_key_values:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get('position_ids', None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if `inputs_embeds` are passed,
# we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
model_inputs = {'input_ids': input_ids}
model_inputs.update({
'position_ids': position_ids,
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
})
return model_inputs
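    # Illustrative sketch (editorial) of the on-the-fly position_ids above, for
    # a left-padded batch row with attention_mask = [0, 0, 1, 1]:
    #   cumsum(-1) - 1          -> [-1, -1, 0, 1]
    #   masked_fill_(mask == 0) -> [ 1,  1, 0, 1]
    # so real tokens get positions 0, 1, ... and padding gets a harmless
    # placeholder value of 1.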
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
"""Reorder cached past key-values during generation using beam search.
This function reorders the cached past key-values according to the
given indices. It's useful in beam search where the order of hypotheses
can change from one time-step to another.
"""
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(
past_state.index_select(0, beam_idx)
for past_state in layer_past), )
return reordered_past
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch InternLM model."""
import math
import queue
import threading
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.generation.streamers import BaseStreamer
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings)
from lmdeploy.pytorch.modeling.convert_to_qmodules import convert_to_qmodules
from lmdeploy.utils import get_logger
from .configuration_internlm import InternLMConfig
logger = get_logger('lmdeploy')
_CONFIG_FOR_DOC = 'InternLMConfig'
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0):
"""Make causal mask used for bi-directional self-attention."""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len),
torch.tensor(torch.finfo(dtype).min, device=device),
device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device),
mask
],
dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len,
tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor,
dtype: torch.dtype,
tgt_len: Optional[int] = None):
"""Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len,
src_seq_len]`."""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len,
src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool),
torch.finfo(dtype).min)
class InternLMRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""InternLMRMSNorm is equivalent to T5LayerNorm."""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1,
keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance +
self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
class InternLMRotaryEmbedding(torch.nn.Module):
"""RotaryEmbedding for InternLM Model.
This module generates sine and cosine positional encodings based on
the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding".
The purpose of this class is to provide positional embeddings to the
input tensors. It utilizes a cache mechanism to store precomputed
sine and cosine values for speedup.
Args:
dim (int): The dimensionality of the embeddings.
max_position_embeddings (int, optional): The maximum number of
position embeddings. Default is 2048.
base (int, optional): The base value for the inverse frequency
calculation. Default is 10000.
device (str, optional): The device to run operations on.
If None, defaults to the device of the model.
"""
def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None):
super().__init__()
index = (torch.arange(0, dim, 2).float().to(device) / dim)
inv_freq = 1.0 / (base**index)
self.register_buffer('inv_freq', inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached,
device=self.inv_freq.device,
dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Different from paper, but it uses a different permutation in order
# to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer('cos_cached',
emb.cos()[None, None, :, :],
persistent=False)
self.register_buffer('sin_cached',
emb.sin()[None, None, :, :],
persistent=False)
def forward(self, x, seq_len=None):
"""Forward propagation method for the embedding layer.
Generates positional embeddings for the given input tensor.
"""
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in
# `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached,
device=x.device,
dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Different from paper, but it uses a different permutation in
# order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer('cos_cached',
emb.cos()[None, None, :, :],
persistent=False)
self.register_buffer('sin_cached',
emb.sin()[None, None, :, :],
persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
"""Apply rotary positional embeddings to query and key tensors.
This function applies the cosine and sine positional embeddings on the
input query (q) and key (k) tensors using element-wise multiplication and
addition.
"""
# The first two dimensions of cos and sin are always 1, so we can
# `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class InternLMMLP(nn.Module):
"""MLP for InternLM Model."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class InternLMAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper."""
def __init__(self, config: InternLMConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.max_position_embeddings = config.max_position_embeddings
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError('hidden_size must be divisible by num_heads '
f'(got `hidden_size`: {self.hidden_size}'
f' and `num_heads`: {self.num_heads}).')
self.q_proj = nn.Linear(self.hidden_size,
self.num_heads * self.head_dim,
bias=config.bias)
self.k_proj = nn.Linear(self.hidden_size,
self.num_heads * self.head_dim,
bias=config.bias)
self.v_proj = nn.Linear(self.hidden_size,
self.num_heads * self.head_dim,
bias=config.bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim,
self.hidden_size,
bias=config.bias)
self.rotary_emb = InternLMRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Forward propagation method for the attention layer."""
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(
bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(
bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(
bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
attn_weights = torch.matmul(query_states, key_states.transpose(
2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
'Attention weights should be of size '
f'{(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
f' {attn_weights.size()}')
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError('Attention mask should be of size '
f'{(bsz, 1, q_len, kv_seq_len)}, '
f'but is {attention_mask.size()}')
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights,
torch.tensor(torch.finfo(attn_weights.dtype).min))
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(
query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
                '`attn_output` should be of size '
                f'{(bsz, self.num_heads, q_len, self.head_dim)}, but is'
f' {attn_output.size()}')
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class InternLMDecoderLayer(nn.Module):
"""Decoder layer for InternLM Model."""
def __init__(self, config: InternLMConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = InternLMAttention(config=config)
self.mlp = InternLMMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = InternLMRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = InternLMRMSNorm(
config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
torch.FloatTensor]]]:
""" # noqa: E501
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, )
if output_attentions:
outputs += (self_attn_weights, )
if use_cache:
outputs += (present_key_value, )
return outputs
INTERNLM_START_DOCSTRING = r""" # noqa: E501
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.
Parameters:
config ([`InternLMConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
'The bare InternLM Model outputting raw hidden-states without any specific head on top.', # noqa: E501
INTERNLM_START_DOCSTRING,
)
class InternLMPreTrainedModel(PreTrainedModel):
config_class = InternLMConfig
base_model_prefix = 'model'
supports_gradient_checkpointing = True
_no_split_modules = ['InternLMDecoderLayer']
_keys_to_ignore_on_load_unexpected = [r'decoder\.version']
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, InternLMModel):
module.gradient_checkpointing = value
INTERNLM_INPUTS_DOCSTRING = r""" # noqa: E501
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
'The bare InternLM Model outputting raw hidden-states without any specific head on top.', # noqa: E501
INTERNLM_START_DOCSTRING,
)
class InternLMModel(InternLMPreTrainedModel):
"""Transformer decoder consisting of *config.num_hidden_layers* layers.
Each layer is a [`InternLMDecoderLayer`]
Args:
config: InternLMConfig
"""
_auto_class = 'AutoModel'
def __init__(self, config: InternLMConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
self.padding_idx)
self.layers = nn.ModuleList([
InternLMDecoderLayer(config)
for _ in range(config.num_hidden_layers)
])
self.norm = InternLMRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder.
# prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1]).to(
inputs_embeds.device)
combined_attention_mask = (expanded_attn_mask
if combined_attention_mask is None else
expanded_attn_mask +
combined_attention_mask)
return combined_attention_mask
@add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = (use_cache
                     if use_cache is not None else self.config.use_cache)
return_dict = (return_dict if return_dict is not None else
self.config.use_return_dict)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError('You cannot specify both decoder_input_ids '
'and decoder_inputs_embeds at the same time')
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError('You have to specify either decoder_input_ids '
'or decoder_inputs_embeds')
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = (seq_length_with_past +
past_key_values_length)
if position_ids is None:
device = (input_ids.device
if input_ids is not None else inputs_embeds.device)
position_ids = torch.arange(past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds,
past_key_values_length)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
'`use_cache=True` is incompatible with gradient '
'checkpointing. Setting `use_cache=False`...')
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states, )
past_key_value = past_key_values[
idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (
layer_outputs[2 if output_attentions else 1], )
if output_attentions:
all_self_attns += (layer_outputs[1], )
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states, )
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v for v in
[hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class InternLMForCausalLM(InternLMPreTrainedModel):
"""This class extends the `InternLMPreTrainedModel` to enable causal
language modeling.
It wraps the basic InternLM model (`InternLMModel`) and includes a linear
layer as a language model head (`lm_head`). The purpose is to predict token
probabilities, given the previous tokens in the sequence.
"""
_auto_class = 'AutoModelForCausalLM'
def __init__(self, config):
super().__init__(config)
self.model = InternLMModel(config)
self.lm_head = nn.Linear(config.hidden_size,
config.vocab_size,
bias=False)
# Initialize weights and apply final processing
self.post_init()
convert_to_qmodules(self)
def get_input_embeddings(self):
"""Get the token embedding layer."""
return self.model.embed_tokens
def set_input_embeddings(self, value):
"""Set the token embedding layer."""
self.model.embed_tokens = value
def get_output_embeddings(self):
"""Get the output embedding layer."""
return self.lm_head
def set_output_embeddings(self, new_embeddings):
"""Set the output embedding layer."""
self.lm_head = new_embeddings
def set_decoder(self, decoder):
"""Set the decoder model."""
self.model = decoder
def get_decoder(self):
"""Get the decoder model."""
return self.model
@add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast,
config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r""" # noqa: E501
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, InternLMForCausalLM
>>> model = InternLMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you consciours? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```"""
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = (return_dict if return_dict is not None else
self.config.use_return_dict)
# decoder outputs consists of
# (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits, ) + outputs[1:]
return (loss, ) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs):
"""Prepare inputs for generating sequences using the model.
Args:
input_ids (torch.Tensor): Input token ids.
past_key_values (list[torch.Tensor], optional): List of past key
and value states.
attention_mask (torch.Tensor, optional): Mask indicating which
tokens should be attended to.
inputs_embeds (torch.FloatTensor, optional): Optionally,
the input embeddings instead of token ids.
Returns:
dict: Dictionary containing prepared inputs for model generation.
"""
if past_key_values:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get('position_ids', None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if `inputs_embeds` are passed,
# we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
model_inputs = {'input_ids': input_ids}
model_inputs.update({
'position_ids': position_ids,
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
})
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
"""Reorder cached past key-values during generation using beam search.
This function reorders the cached past key-values according to the
given indices. It's useful in beam search where the order of hypotheses
can change from one time-step to another.
"""
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(
past_state.index_select(0, beam_idx)
for past_state in layer_past), )
return reordered_past
def build_inputs(self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = []):
"""Builds the input for the model."""
prompt = ''
for record in history:
prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n""" # noqa: E501
prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
return tokenizer([prompt], return_tensors='pt')
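    # For example, build_inputs(tokenizer, 'How are you?',
    # history=[('Hi', 'Hello!')]) tokenizes the prompt
    # "<|User|>:Hi<eoh>\n<|Bot|>:Hello!<eoa>\n<|User|>:How are you?<eoh>\n<|Bot|>:"
    # so that generation continues right after the trailing "<|Bot|>:" marker.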
@torch.no_grad()
def chat(self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = [],
streamer: Optional[BaseStreamer] = None,
max_new_tokens: int = 1024,
do_sample: bool = True,
temperature: float = 0.8,
top_p: float = 0.8,
**kwargs):
"""Provides a chatting functionality for the model."""
inputs = self.build_inputs(tokenizer, query, history)
inputs = {
k: v.to(self.device)
for k, v in inputs.items() if torch.is_tensor(v)
}
outputs = self.generate(**inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
**kwargs)
outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):]
response = tokenizer.decode(outputs, skip_special_tokens=True)
response = response.split('<eoa>')[0]
history = history + [(query, response)]
return response, history
@torch.no_grad()
def stream_chat(self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = [],
max_new_tokens: int = 1024,
do_sample: bool = True,
temperature: float = 0.8,
top_p: float = 0.8,
**kwargs):
"""Return a generator in format: (response, history) Eg.
('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')]) ('你好,有什么可以帮助您的吗?', [('你好',
'你好,有什么可以帮助您的吗?')])
"""
response_queue = queue.Queue(maxsize=20)
class ChatStreamer(BaseStreamer):
def __init__(self, tokenizer) -> None:
super().__init__()
self.tokenizer = tokenizer
self.queue = response_queue
self.query = query
self.history = history
self.response = ''
self.received_inputs = False
self.queue.put(
(self.response, history + [(self.query, self.response)]))
def put(self, value):
if len(value.shape) > 1 and value.shape[0] > 1:
raise ValueError('ChatStreamer only supports batch size 1')
elif len(value.shape) > 1:
value = value[0]
if not self.received_inputs:
# The first received value is input_ids, ignore here
self.received_inputs = True
return
token = self.tokenizer.decode([value[-1]],
skip_special_tokens=True)
if token.strip() != '<eoa>':
self.response = self.response + token
history = self.history + [(self.query, self.response)]
self.queue.put((self.response, history))
def end(self):
self.queue.put(None)
def stream_producer():
return self.chat(tokenizer=tokenizer,
query=query,
streamer=ChatStreamer(tokenizer=tokenizer),
history=history,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
**kwargs)
def consumer():
producer = threading.Thread(target=stream_producer)
producer.start()
while True:
res = response_queue.get()
if res is None:
return
yield res
return consumer()
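    # Illustrative usage of `stream_chat` (the names below are hypothetical):
    #
    #     for response, history in model.stream_chat(tokenizer, 'Hi'):
    #         print(response)
    #
    # Each iteration yields the partial response accumulated so far together
    # with the updated conversation history.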
@add_start_docstrings(
""" # noqa: E501
The InternLM Model transformer with a sequence classification head on top (linear layer).
[`InternLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
    Since it does classification on the last token, it needs to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
INTERNLM_START_DOCSTRING,
)
class InternLMForSequenceClassification(InternLMPreTrainedModel):
_keys_to_ignore_on_load_missing = [r'lm_head.weight']
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = InternLMModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r""" # noqa: E501
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = (return_dict if return_dict is not None else
self.config.use_return_dict)
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError(
'Cannot handle batch sizes > 1 if no padding token is defined.'
)
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (
torch.ne(input_ids, self.config.pad_token_id).sum(-1) -
1).to(logits.device)
else:
sequence_lengths = -1
pooled_logits = logits[torch.arange(batch_size, device=logits.device),
sequence_lengths]
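        # `sequence_lengths` holds, for every row, the index of the last
        # non-padding token (or -1 when no pad token is configured), so
        # `pooled_logits` selects the classification logits at the final
        # position of each sequence.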
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = 'regression'
elif self.num_labels > 1 and (labels.dtype == torch.long
or labels.dtype == torch.int):
self.config.problem_type = 'single_label_classification'
else:
self.config.problem_type = 'multi_label_classification'
if self.config.problem_type == 'regression':
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == 'single_label_classification':
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels),
labels.view(-1))
elif self.config.problem_type == 'multi_label_classification':
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits, ) + transformer_outputs[1:]
return ((loss, ) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
# Copyright (c) InternLM. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch InternLM2 model."""
import math
import queue
import threading
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from einops import rearrange
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings)
from lmdeploy.pytorch.modeling.convert_to_qmodules import convert_to_qmodules
from lmdeploy.utils import get_logger
try:
from transformers.generation.streamers import BaseStreamer
except: # noqa # pylint: disable=bare-except
BaseStreamer = None
from .configuration_internlm2 import InternLM2Config
logger = get_logger('lmdeploy')
_CONFIG_FOR_DOC = 'InternLM2Config'
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0):
"""Make causal mask used for bi-directional self-attention."""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len),
torch.tensor(torch.finfo(dtype).min, device=device),
device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device),
mask
],
dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len,
tgt_len + past_key_values_length)
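# Example of the mask produced by `_make_causal_mask` above: for tgt_len=3 and
# past_key_values_length=1, each batch/head slice is
#     [[0, 0, m, m],
#      [0, 0, 0, m],
#      [0, 0, 0, 0]]
# with m = torch.finfo(dtype).min, i.e. every query position may attend to the
# cached past and to itself, but never to future positions.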
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor,
dtype: torch.dtype,
tgt_len: Optional[int] = None):
"""Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len,
src_seq_len]`."""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len,
src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool),
torch.finfo(dtype).min)
class InternLM2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""InternLM2RMSNorm is equivalent to T5LayerNorm."""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance +
self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class InternLM2RotaryEmbedding(nn.Module):
def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base**(
torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer('inv_freq', inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.get_default_dtype())
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached,
device=device,
dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer('cos_cached',
emb.cos().to(dtype),
persistent=False)
self.register_buffer('sin_cached',
emb.sin().to(dtype),
persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len,
device=x.device,
dtype=torch.float32)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
"""InternLM2RotaryEmbedding extended with linear scaling.
Credits to the Reddit user /u/kaiokendev
"""
def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached,
device=device,
dtype=self.inv_freq.dtype)
t = t / self.scaling_factor
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer('cos_cached',
emb.cos().to(dtype),
persistent=False)
self.register_buffer('sin_cached',
emb.sin().to(dtype),
persistent=False)
class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
"""InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
Credits to the Reddit users /u/bloc97 and /u/emozilla.
"""
def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * ((self.scaling_factor * seq_len /
self.max_position_embeddings) -
(self.scaling_factor - 1))**(self.dim /
(self.dim - 2))
inv_freq = 1.0 / (base**(
torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer('inv_freq', inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached,
device=device,
dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer('cos_cached',
emb.cos().to(dtype),
persistent=False)
self.register_buffer('sin_cached',
emb.sin().to(dtype),
persistent=False)
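# Note on the dynamic NTK variant above: once `seq_len` exceeds
# `max_position_embeddings`, the rotary base is enlarged to
#     base * (scaling_factor * seq_len / max_position_embeddings
#             - (scaling_factor - 1)) ** (dim / (dim - 2))
# which lowers the rotary frequencies so that sequences longer than the
# trained context can still be encoded without retraining.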
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
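# Illustrative sketch of the rotation applied by `apply_rotary_pos_emb`: with
# zero rotation angles (cos=1, sin=0) the queries and keys must come back
# unchanged, since q_embed = q * cos + rotate_half(q) * sin. The shapes below
# are toy values chosen only for this example; the helper is never called by
# the model itself.
def _rope_zero_angle_check():
    bs, n_heads, seq_len, head_dim = 1, 2, 4, 8
    q = torch.randn(bs, n_heads, seq_len, head_dim)
    k = torch.randn(bs, n_heads, seq_len, head_dim)
    cos = torch.ones(seq_len, head_dim)   # cached cos table, all ones here
    sin = torch.zeros(seq_len, head_dim)  # cached sin table, all zeros here
    position_ids = torch.arange(seq_len).unsqueeze(0)
    q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
    assert torch.equal(q_embed, q) and torch.equal(k_embed, k)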
class InternLM2MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.w1 = nn.Linear(self.hidden_size,
self.intermediate_size,
bias=False)
self.w3 = nn.Linear(self.hidden_size,
self.intermediate_size,
bias=False)
self.w2 = nn.Linear(self.intermediate_size,
self.hidden_size,
bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))
return down_proj
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""This is the equivalent of torch.repeat_interleave(x, dim=1,
repeats=n_rep).
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
(batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :,
None, :, :].expand(batch,
num_key_value_heads,
n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
head_dim)
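# Illustrative sketch of the grouped-query expansion performed by `repeat_kv`:
# two key/value heads are each repeated four times so that they line up with
# eight query heads. The tensor sizes are toy values; the helper is never
# called by the model itself.
def _repeat_kv_shape_check():
    kv = torch.zeros(1, 2, 5, 16)      # (batch, num_key_value_heads, seqlen, head_dim)
    expanded = repeat_kv(kv, n_rep=4)  # -> (batch, num_key_value_heads * n_rep, seqlen, head_dim)
    assert expanded.shape == (1, 8, 5, 16)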
class InternLM2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper."""
def __init__(self, config: InternLM2Config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
f' and `num_heads`: {self.num_heads}).')
self.wqkv = nn.Linear(
self.hidden_size,
(self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
bias=config.bias,
)
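        # The packed projection stores, for each of the `num_key_value_heads`
        # groups, `num_heads // num_key_value_heads` query heads followed by
        # one key head and one value head, each of width `head_dim`; the
        # forward pass unpacks this layout with `einops.rearrange`.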
self.wo = nn.Linear(self.num_heads * self.head_dim,
self.hidden_size,
bias=config.bias)
self._init_rope()
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = InternLM2RotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.config.rope_theta,
)
else:
scaling_type = self.config.rope_scaling['type']
scaling_factor = self.config.rope_scaling['factor']
if scaling_type == 'dynamic':
self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.config.rope_theta,
scaling_factor=scaling_factor)
elif scaling_type == 'linear':
self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.config.rope_theta,
scaling_factor=scaling_factor)
else:
raise ValueError(
"Currently we only support rotary embedding's type being 'dynamic' or 'linear'."
)
return self.rotary_emb
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
if 'padding_mask' in kwargs:
warnings.warn(
'Passing `padding_mask` is deprecated and will be removed in v4.37. '
                'Please make sure to use `attention_mask` instead.')
bsz, q_len, _ = hidden_states.size()
qkv_states = self.wqkv(hidden_states)
qkv_states = rearrange(
qkv_states,
'b q (h gs d) -> b q h gs d',
gs=2 + self.num_key_value_groups,
d=self.head_dim,
)
query_states = qkv_states[..., :self.num_key_value_groups, :]
query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
key_states = qkv_states[..., -2, :]
value_states = qkv_states[..., -1, :]
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(
2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
f' {attn_weights.size()}')
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(
query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
f' {attn_output.size()}')
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.wo(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class InternLM2FlashAttention2(InternLM2Attention):
"""InternLM2 flash attention module.
    This module inherits from `InternLM2Attention` as the weights of the module
    stay untouched. The only required change would be on the forward pass
where it needs to correctly call the public API of flash attention and deal
with padding tokens in case the input contains any of them.
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
# InternLM2FlashAttention2 attention does not support output_attentions
if 'padding_mask' in kwargs:
warnings.warn(
'Passing `padding_mask` is deprecated and will be removed in v4.37. '
                'Please make sure to use `attention_mask` instead.')
# overwrite attention_mask with padding_mask
attention_mask = kwargs.pop('padding_mask')
output_attentions = False
bsz, q_len, _ = hidden_states.size()
qkv_states = self.wqkv(hidden_states)
qkv_states = rearrange(
qkv_states,
'b q (h gs d) -> b q h gs d',
            gs=2 + self.num_key_value_groups,
d=self.head_dim,
q=q_len,
)
query_states = qkv_states[..., :self.num_key_value_groups, :]
query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
key_states = qkv_states[..., -2, :]
value_states = qkv_states[..., -1, :]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = 0.0 if not self.training else self.attention_dropout
        # In PEFT, the layer norms are usually cast to float32 for training
        # stability, so the input hidden states may be silently upcast to
        # float32. Cast them back to the expected dtype to be sure everything
        # works as intended. This may slow down training and inference, so it
        # is recommended not to cast the LayerNorms to fp32.
        # (InternLM2RMSNorm handles the dtype correctly.)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
# Handle the case where the model is quantized
if hasattr(self.config, '_pre_quantization_dtype'):
target_dtype = self.config._pre_quantization_dtype
else:
                target_dtype = self.wqkv.weight.dtype
logger.warning_once(
                'The input hidden states seem to have been silently cast to '
                'float32; this is probably because you have upcast embedding '
                'or layer norm layers to float32. We will cast the input back '
                f'to {target_dtype}.')
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward(query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate)
attn_output = attn_output.reshape(bsz, q_len,
self.hidden_size).contiguous()
attn_output = self.wo(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class InternLM2DecoderLayer(nn.Module):
def __init__(self, config: InternLM2Config):
super().__init__()
self.hidden_size = config.hidden_size
self.attention = (InternLM2Attention(config=config) if
not getattr(config, '_flash_attn_2_enabled', False)
else InternLM2FlashAttention2(config=config))
self.feed_forward = InternLM2MLP(config)
self.attention_norm = InternLM2RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.ffn_norm = InternLM2RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
if 'padding_mask' in kwargs:
warnings.warn(
'Passing `padding_mask` is deprecated and will be removed in v4.37. '
                'Please make sure to use `attention_mask` instead.')
residual = hidden_states
hidden_states = self.attention_norm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.attention(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.ffn_norm(hidden_states)
hidden_states = self.feed_forward(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, )
if output_attentions:
outputs += (self_attn_weights, )
if use_cache:
outputs += (present_key_value, )
return outputs
InternLM2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`InternLM2Config`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
InternLM2_START_DOCSTRING,
)
class InternLM2PreTrainedModel(PreTrainedModel):
config_class = InternLM2Config
base_model_prefix = 'model'
supports_gradient_checkpointing = True
_no_split_modules = ['InternLM2DecoderLayer']
_skip_keys_device_placement = 'past_key_values'
_supports_flash_attn_2 = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
InternLM2_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
InternLM2_START_DOCSTRING,
)
class InternLM2Model(InternLM2PreTrainedModel):
"""Transformer decoder consisting of *config.num_hidden_layers* layers.
Each layer is a [`InternLM2DecoderLayer`]
Args:
config: InternLM2Config
"""
_auto_class = 'AutoModel'
def __init__(self, config: InternLM2Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.tok_embeddings = nn.Embedding(config.vocab_size,
config.hidden_size,
self.padding_idx)
self.layers = nn.ModuleList([
InternLM2DecoderLayer(config)
for _ in range(config.num_hidden_layers)
])
self.norm = InternLM2RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.tok_embeddings
def set_input_embeddings(self, value):
self.tok_embeddings = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1]).to(
inputs_embeds.device)
combined_attention_mask = (expanded_attn_mask
if combined_attention_mask is None else
expanded_attn_mask +
combined_attention_mask)
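        # When both masks are present, the additive padding mask (large
        # negative values at padded key positions) is summed with the causal
        # mask, so a key position is masked if it is padding or in the future.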
return combined_attention_mask
@add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else
self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
'You cannot specify both input_ids and inputs_embeds at the same time'
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError(
'You have to specify either input_ids or inputs_embeds')
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device)
position_ids = position_ids.unsqueeze(0)
if inputs_embeds is None:
inputs_embeds = self.tok_embeddings(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds,
past_key_values_length)
# embed positions
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
'`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states, )
past_key_value = past_key_values[
idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (
layer_outputs[2 if output_attentions else 1], )
if output_attentions:
all_self_attns += (layer_outputs[1], )
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states, )
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v for v in
[hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class InternLM2ForCausalLM(InternLM2PreTrainedModel):
_auto_class = 'AutoModelForCausalLM'
_tied_weights_keys = ['output.weight']
def __init__(self, config):
super().__init__(config)
self.model = InternLM2Model(config)
self.vocab_size = config.vocab_size
self.output = nn.Linear(config.hidden_size,
config.vocab_size,
bias=False)
# Initialize weights and apply final processing
self.post_init()
convert_to_qmodules(self)
def get_input_embeddings(self):
return self.model.tok_embeddings
def set_input_embeddings(self, value):
self.model.tok_embeddings = value
def get_output_embeddings(self):
return self.output
def set_output_embeddings(self, new_embeddings):
self.output = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast,
config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, InternLM2ForCausalLM
>>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.output(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits, ) + outputs[1:]
return (loss, ) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
position_ids = kwargs.get('position_ids', None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1]:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
model_inputs = {'input_ids': input_ids}
model_inputs.update({
'position_ids': position_ids,
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
})
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past), )
return reordered_past
def build_inputs(self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = []):
prompt = ''
for record in history:
prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
return tokenizer([prompt], return_tensors='pt')
@torch.no_grad()
def chat(
self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = [],
streamer: Optional[BaseStreamer] = None,
max_new_tokens: int = 1024,
do_sample: bool = True,
temperature: float = 0.8,
top_p: float = 0.8,
**kwargs,
):
inputs = self.build_inputs(tokenizer, query, history)
inputs = {
k: v.to(self.device)
for k, v in inputs.items() if torch.is_tensor(v)
}
outputs = self.generate(
**inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
**kwargs,
)
outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):]
response = tokenizer.decode(outputs, skip_special_tokens=True)
response = response.split('<eoa>')[0]
history = history + [(query, response)]
return response, history
@torch.no_grad()
def stream_chat(
self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = [],
max_new_tokens: int = 1024,
do_sample: bool = True,
temperature: float = 0.8,
top_p: float = 0.8,
**kwargs,
):
"""Return a generator in format: (response, history) Eg.
('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')]) ('你好,有什么可以帮助您的吗?', [('你好',
'你好,有什么可以帮助您的吗?')])
"""
if BaseStreamer is None:
raise ModuleNotFoundError(
'The version of `transformers` is too low. Please make sure '
'that you have installed `transformers>=4.28.0`.')
response_queue = queue.Queue(maxsize=20)
class ChatStreamer(BaseStreamer):
def __init__(self, tokenizer) -> None:
super().__init__()
self.tokenizer = tokenizer
self.queue = response_queue
self.query = query
self.history = history
self.response = ''
self.received_inputs = False
self.queue.put(
(self.response, history + [(self.query, self.response)]))
def put(self, value):
if len(value.shape) > 1 and value.shape[0] > 1:
raise ValueError('ChatStreamer only supports batch size 1')
elif len(value.shape) > 1:
value = value[0]
if not self.received_inputs:
# The first received value is input_ids, ignore here
self.received_inputs = True
return
token = self.tokenizer.decode([value[-1]],
skip_special_tokens=True)
if token.strip() != '<eoa>':
self.response = self.response + token
history = self.history + [(self.query, self.response)]
self.queue.put((self.response, history))
def end(self):
self.queue.put(None)
def stream_producer():
return self.chat(
tokenizer=tokenizer,
query=query,
streamer=ChatStreamer(tokenizer=tokenizer),
history=history,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
**kwargs,
)
def consumer():
producer = threading.Thread(target=stream_producer)
producer.start()
while True:
res = response_queue.get()
if res is None:
return
yield res
return consumer()
@add_start_docstrings(
"""
The InternLM2 Model transformer with a sequence classification head on top (linear layer).
[`InternLM2ForSequenceClassification`] uses the last token in order to do the classification,
as other causal models (e.g. GPT-2) do.
    Since it does classification on the last token, it needs to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
InternLM2_START_DOCSTRING,
)
class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = InternLM2Model(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.tok_embeddings
def set_input_embeddings(self, value):
self.model.tok_embeddings = value
@add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError(
'Cannot handle batch sizes > 1 if no padding token is defined.'
)
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.eq(
input_ids, self.config.pad_token_id).int().argmax(-1) -
1).to(logits.device)
else:
sequence_lengths = -1
pooled_logits = logits[torch.arange(batch_size, device=logits.device),
sequence_lengths]
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = 'regression'
elif self.num_labels > 1 and (labels.dtype == torch.long
or labels.dtype == torch.int):
self.config.problem_type = 'single_label_classification'
else:
self.config.problem_type = 'multi_label_classification'
if self.config.problem_type == 'regression':
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == 'single_label_classification':
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels),
labels.view(-1))
elif self.config.problem_type == 'multi_label_classification':
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits, ) + transformer_outputs[1:]
return ((loss, ) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.utils import (add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings)
from lmdeploy.pytorch.modeling.convert_to_qmodules import convert_to_qmodules
from lmdeploy.utils import get_logger
logger = get_logger('lmdeploy')
_CONFIG_FOR_DOC = 'LlamaConfig'
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0):
"""Make causal mask used for bi-directional self-attention."""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len),
torch.finfo(dtype).min,
device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device),
mask
],
dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len,
tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor,
dtype: torch.dtype,
tgt_len: Optional[int] = None):
"""Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len,
src_seq_len]`."""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len,
src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool),
torch.finfo(dtype).min)
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""LlamaRMSNorm is equivalent to T5LayerNorm."""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance +
self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class LlamaRotaryEmbedding(torch.nn.Module):
"""RotaryEmbedding for Llama Model.
This module generates sine and cosine positional encodings based on
the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding".
The purpose of this class is to provide positional embeddings to the
input tensors. It utilizes a cache mechanism to store precomputed
sine and cosine values for speedup.
Args:
dim (int): The dimensionality of the embeddings.
max_position_embeddings (int, optional): The maximum number of
position embeddings. Default is 2048.
base (int, optional): The base value for the inverse frequency
calculation. Default is 10000.
device (str, optional): The device to run operations on.
If None, defaults to the device of the model.
"""
def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base**(
torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer('inv_freq', inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.get_default_dtype())
def _set_cos_sin_cache(self, seq_len, device, dtype):
"""Sets the cached sine and cosine values for the specified sequence
length.
Args:
seq_len (int): The sequence length for which to set the cache.
device (str): The device to use for computation.
dtype (torch.dtype): The data type to be used for tensors.
"""
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached,
device=device,
dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Different from paper, but it uses a different permutation in order
# to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer('cos_cached',
emb.cos()[None, None, :, :].to(dtype),
persistent=False)
self.register_buffer('sin_cached',
emb.sin()[None, None, :, :].to(dtype),
persistent=False)
def forward(self, x, seq_len=None):
"""Forward propagation method for the embedding layer. Generates
positional embeddings for the given input tensor.
If the sequence length is larger than the cache, it resets the cache.
Args:
x (torch.Tensor): Input tensor of shape
[batch_size, num_attention_heads, seq_len, head_size].
seq_len (int, optional): Sequence length. If None, it is obtained
from `x`.
Returns:
tuple: Tuple containing cosine and sine positional embeddings.
"""
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len,
device=x.device,
dtype=x.dtype)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
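# Usage sketch (illustrative only, small dims chosen for brevity):
#
#   rope = LlamaRotaryEmbedding(dim=4, max_position_embeddings=8)
#   cos, sin = rope(torch.zeros(1, 1, 6, 4), seq_len=6)
#   # cos.shape == sin.shape == (1, 1, 6, 4); the cache built for 8 positions
#   # is sliced to 6 and only rebuilt when seq_len exceeds max_seq_len_cached.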
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""This class extends the `LlamaRotaryEmbedding` with linear scaling.
    It adjusts the scale of the positional embeddings by dividing the
    position indices by a scaling factor before computing the rotary
    frequencies, which is useful for extending the model to sequence
    lengths beyond the original training context.
Credits to Reddit User /u/kaiokendev for this extension.
Args:
dim (int): The dimensionality of the embeddings.
max_position_embeddings (int, optional): The maximum number of
position embeddings. Default is 2048.
base (int, optional): The base value for the inverse frequency
calculation. Default is 10000.
device (str, optional): The device to run operations on. If None,
defaults to the device of the model.
scaling_factor (float, optional): Scaling factor used in adjusting
the scale of positional embeddings. Default is 1.0.
"""
def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
"""Sets the cached sine and cosine values for the specified sequence
length.
Args:
seq_len (int): The sequence length for which to set the cache.
device (str): The device to use for computation.
dtype (torch.dtype): The data type to use for tensors.
"""
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached,
device=device,
dtype=self.inv_freq.dtype)
t = t / self.scaling_factor
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # The permutation differs from the paper, but the resulting
        # embedding is equivalent.
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer('cos_cached',
emb.cos()[None, None, :, :].to(dtype),
persistent=False)
self.register_buffer('sin_cached',
emb.sin()[None, None, :, :].to(dtype),
persistent=False)
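# Sketch of the linear-scaling behaviour (illustrative): with
# scaling_factor=4.0 the position index t=4096 is embedded as if it were
# t=1024, so a model trained with max_position_embeddings=2048 can address
# longer contexts at proportionally reduced positional resolution.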
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling.
Credits to the Reddit users /u/bloc97 and /u/emozilla
"""
def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * ((self.scaling_factor * seq_len /
self.max_position_embeddings) -
(self.scaling_factor - 1))**(self.dim /
(self.dim - 2))
inv_freq = 1.0 / (base**(
torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer('inv_freq', inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached,
device=device,
dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # The permutation differs from the paper, but the resulting
        # embedding is equivalent.
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer('cos_cached',
emb.cos()[None, None, :, :].to(dtype),
persistent=False)
self.register_buffer('sin_cached',
emb.sin()[None, None, :, :].to(dtype),
persistent=False)
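# Sketch of the NTK rescaling (illustrative): once seq_len exceeds
# max_position_embeddings, the rotary base is enlarged to
#   base' = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
# which stretches the low-frequency rotations instead of rescaling positions,
# so behaviour within the original context length is left untouched.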
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
"""Apply rotary positional embeddings to query and key tensors.
This function applies the cosine and sine positional embeddings on the
input query (q) and key (k) tensors using element-wise multiplication and
addition.
"""
# The first two dimensions of cos and sin are always 1,
# so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
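# Worked example (illustrative): for head_dim=4, rotate_half maps
# [x0, x1, x2, x3] -> [-x2, -x3, x0, x1] and, with cos=[c0, c1, c0, c1] and
# sin=[s0, s1, s0, s1] from the duplicated-frequency cache, the rotated query is
#   [x0*c0 - x2*s0, x1*c1 - x3*s1, x2*c0 + x0*s0, x3*c1 + x1*s1],
# i.e. the pairs (x0, x2) and (x1, x3) are rotated by their position angles.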
class LlamaMLP(nn.Module):
"""MLP for Llama Model."""
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size,
self.intermediate_size,
bias=False)
self.up_proj = nn.Linear(self.hidden_size,
self.intermediate_size,
bias=False)
self.down_proj = nn.Linear(self.intermediate_size,
self.hidden_size,
bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
gate_proj = torch.cat([
F.linear(x, gate_proj_slices[i])
for i in range(self.config.pretraining_tp)
],
dim=-1)
up_proj = torch.cat([
F.linear(x, up_proj_slices[i])
for i in range(self.config.pretraining_tp)
],
dim=-1)
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(
slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i])
for i in range(self.config.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(
self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
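# Sketch (illustrative, hypothetical sizes): with pretraining_tp=2 and
# intermediate_size=11008, the gate_proj/up_proj weights are split into two
# (5504, hidden_size) slices, each applied via F.linear and re-concatenated,
# which reproduces the numerics of the tensor-parallel pretraining layout on a
# single device; with pretraining_tp == 1 the plain gated-MLP branch is used.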
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""This is the equivalent of torch.repeat_interleave(x, dim=1,
repeats=n_rep).
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
(batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :,
None, :, :].expand(batch,
num_key_value_heads,
n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
head_dim)
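# Shape sketch (illustrative): with batch=2, num_key_value_heads=4, n_rep=2,
# seqlen=16 and head_dim=64, the input (2, 4, 16, 64) is expanded to
# (2, 4, 2, 16, 64) and reshaped to (2, 8, 16, 64), so every kv head is shared
# by n_rep query heads in grouped-query attention.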
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper."""
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError('hidden_size must be divisible by num_heads '
f'(got `hidden_size`: {self.hidden_size}'
f' and `num_heads`: {self.num_heads}).')
self.q_proj = nn.Linear(self.hidden_size,
self.num_heads * self.head_dim,
bias=False)
self.k_proj = nn.Linear(self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=False)
self.v_proj = nn.Linear(self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim,
self.hidden_size,
bias=False)
self._init_rope()
def _init_rope(self):
"""Initialize the Rotary Embedding Module."""
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling['type']
scaling_factor = self.config.rope_scaling['factor']
if scaling_type == 'linear':
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == 'dynamic':
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Forward propagation method for the attention layer."""
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads *
self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp,
dim=0)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i])
for i in range(self.config.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i])
for i in range(self.config.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i])
for i in range(self.config.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(
2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
'Attention weights should be of size '
f'{(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
f' {attn_weights.size()}')
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError('Attention mask should be of size '
f'{(bsz, 1, q_len, kv_seq_len)}, '
f'but is {attention_mask.size()}')
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(
query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
'`attn_output` should be of size '
f'{(bsz, self.num_heads, q_len, self.head_dim)}, but is'
f' {attn_output.size()}')
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size //
self.config.pretraining_tp,
dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.config.pretraining_tp)
])
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class LlamaDecoderLayer(nn.Module):
"""Decoder layer for Llama Model."""
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape
`(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask
of size `(batch, 1, tgt_len, src_len)` where padding elements
are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all
attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are
returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached
past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, )
if output_attentions:
outputs += (self_attn_weights, )
if use_cache:
outputs += (present_key_value, )
return outputs
LLAMA_START_DOCSTRING = r""" # noqa: E501
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
    etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`LlamaConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
'The bare LLaMA Model outputting raw hidden-states without any specific head on top.', # noqa: E501
LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
config_class = LlamaConfig
base_model_prefix = 'model'
supports_gradient_checkpointing = True
_no_split_modules = ['LlamaDecoderLayer']
_skip_keys_device_placement = 'past_key_values'
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LlamaModel):
module.gradient_checkpointing = value
LLAMA_INPUTS_DOCSTRING = r""" # noqa: E501
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
'The bare LLaMA Model outputting raw hidden-states without any specific head on top.', # noqa: E501
LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
"""Transformer decoder consisting of *config.num_hidden_layers* layers.
Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
self.padding_idx)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask # noqa
def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1]).to(
inputs_embeds.device)
combined_attention_mask = (expanded_attn_mask
if combined_attention_mask is None else
expanded_attn_mask +
combined_attention_mask)
return combined_attention_mask
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (output_attentions if output_attentions is not None
else self.config.output_attentions)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else
self.config.output_hidden_states)
use_cache = (use_cache
if use_cache is not None else self.config.use_cache)
return_dict = (return_dict if return_dict is not None else
self.config.use_return_dict)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both decoder_input_ids '
                             'and decoder_inputs_embeds at the same time')
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
            raise ValueError('You have to specify either decoder_input_ids '
                             'or decoder_inputs_embeds')
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = (seq_length_with_past +
past_key_values_length)
if position_ids is None:
device = (input_ids.device
if input_ids is not None else inputs_embeds.device)
position_ids = torch.arange(past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds,
past_key_values_length)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
'`use_cache=True` is incompatible with gradient'
' checkpointing. Setting `use_cache=False`...')
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states, )
past_key_value = past_key_values[
idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, past_key_value,
output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (
layer_outputs[2 if output_attentions else 1], )
if output_attentions:
all_self_attns += (layer_outputs[1], )
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states, )
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v for v in
[hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class LlamaForCausalLM(LlamaPreTrainedModel):
"""This class extends the `LlamaPreTrainedModel` to enable causal language
modeling.
It wraps the basic Llama model (`LlamaModel`) and includes a linear layer
as a language model head (`lm_head`). The purpose is to predict token
probabilities, given the previous tokens in the sequence.
"""
_tied_weights_keys = ['lm_head.weight']
def __init__(self, config):
super().__init__(config)
self.model = LlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size,
config.vocab_size,
bias=False)
# Initialize weights and apply final processing
self.post_init()
convert_to_qmodules(self)
def get_input_embeddings(self):
"""Get the token embedding layer."""
return self.model.embed_tokens
def set_input_embeddings(self, value):
"""Set the token embedding layer."""
self.model.embed_tokens = value
def get_output_embeddings(self):
"""Get the output embedding layer."""
return self.lm_head
def set_output_embeddings(self, new_embeddings):
"""Set the output embedding layer."""
self.lm_head = new_embeddings
def set_decoder(self, decoder):
"""Set the decoder model."""
self.model = decoder
def get_decoder(self):
"""Get the decoder model."""
return self.model
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast,
config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r""" # noqa: E501
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (output_attentions if output_attentions is not None
else self.config.output_attentions)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = (return_dict if return_dict is not None else
self.config.use_return_dict)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) # noqa: E501
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(
self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [
F.linear(hidden_states, lm_head_slices[i])
for i in range(self.config.pretraining_tp)
]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits, ) + outputs[1:]
return (loss, ) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs):
"""Prepare inputs for generating sequences using the model.
Args:
input_ids (torch.Tensor): Input token ids.
past_key_values (list[torch.Tensor], optional): List of past key
and value states.
attention_mask (torch.Tensor, optional): Mask indicating which
tokens should be attended to.
inputs_embeds (torch.FloatTensor, optional): Optionally,
the input embeddings instead of token ids.
Returns:
dict: Dictionary containing prepared inputs for model generation.
"""
if past_key_values:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get('position_ids', None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if `inputs_embeds` are passed, we only want to use them
# in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
model_inputs = {'input_ids': input_ids}
model_inputs.update({
'position_ids': position_ids,
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
})
return model_inputs
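    # Sketch (illustrative): for attention_mask = [[1, 1, 1, 0]] the derived
    # position_ids are cumsum - 1 = [[0, 1, 2, 2]] with the padded slot then
    # overwritten to 1, and after the first step only the last position id and
    # the last input token are fed together with the cached key/values.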
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
"""Reorder cached past key-values during generation using beam search.
This function reorders the cached past key-values according to the
given indices. It's useful in beam search where the order of hypotheses
can change from one time-step to another.
"""
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past), )
return reordered_past
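    # Sketch (illustrative): with beam_idx = torch.tensor([2, 0, 1]) every
    # cached key/value tensor of shape (num_beams * batch, num_heads, seq,
    # head_dim) is gathered along dim 0, so beam 0 now carries the cache that
    # beam 2 produced in the previous step.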
@add_start_docstrings(
""" # noqa: E501
The LLaMa Model transformer with a sequence classification head on top (linear layer).
[`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
    Since it does classification on the last token, it needs to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassification(LlamaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = LlamaModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r""" # noqa: E501
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = (return_dict if return_dict is not None else
self.config.use_return_dict)
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError(
'Cannot handle batch sizes > 1 if no padding token is defined.'
)
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.eq(
input_ids, self.config.pad_token_id).long().argmax(-1) -
1).to(logits.device)
else:
sequence_lengths = -1
pooled_logits = logits[torch.arange(batch_size, device=logits.device),
sequence_lengths]
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = 'regression'
elif self.num_labels > 1 and (labels.dtype == torch.long
or labels.dtype == torch.int):
self.config.problem_type = 'single_label_classification'
else:
self.config.problem_type = 'multi_label_classification'
if self.config.problem_type == 'regression':
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == 'single_label_classification':
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels),
labels.view(-1))
elif self.config.problem_type == 'multi_label_classification':
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits, ) + transformer_outputs[1:]
return ((loss, ) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
# Copyright (c) OpenMMLab. All rights reserved.
from .patch import patch
from .q_modules import QLinear, QRMSNorm
__all__ = ['patch', 'QLinear', 'QRMSNorm']
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
from transformers.modeling_outputs import BaseModelOutputWithPast
from ..dist_utils import (colwise_parallelize_linear_fn,
rowwise_parallelize_linear_fn, try_to_local)
from ..kernels import apply_rotary_pos_emb
from ..kernels.alibi_pagedattention import alibi_paged_attention_fwd
from ..kernels.fill_kv_cache import fill_kv_cache
from ..kernels.pagedattention import paged_attention_fwd
class PatchedRMSNorm(nn.Module):
"""Rewrite RMSNorm."""
def forward(self, hidden_states):
"""forward."""
from ..kernels import rms_norm
epsilon = getattr(self, 'epsilon', None)
if epsilon is None:
epsilon = getattr(self, 'variance_epsilon', 1e-10)
ret = rms_norm(hidden_states, self.weight, epsilon)
return ret
def _attention_partition_fn(mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""A function for attention partition."""
def __w_pack_linear_fn(mod: nn.Module):
"""fn for w pack linear."""
for name, param in mod.named_parameters():
param = param.unflatten(0, (3, -1))
dist_tensor = distribute_tensor(param, device_mesh, [Shard(1)])
dist_tensor = try_to_local(dist_tensor)
dist_tensor = dist_tensor.flatten(0, 1)
dist_param = torch.nn.Parameter(dist_tensor)
mod.register_parameter(name, dist_param)
def __w_pack_lora_linear_fn(mod: nn.Module):
"""fn for w pack lora linear."""
mod._tp_mode = 'colwise'
base_layer = mod.base_layer
__w_pack_linear_fn(base_layer)
for lora_a_mod in mod.lora_A.values():
colwise_parallelize_linear_fn(lora_a_mod,
device_mesh=device_mesh,
to_local=True)
for lora_b_mod in mod.lora_B.values():
__w_pack_linear_fn(lora_b_mod)
if mod_name in ['W_pack']:
from peft.tuners.lora import Linear as LoraLinear
if isinstance(mod, LoraLinear):
__w_pack_lora_linear_fn(mod)
else:
__w_pack_linear_fn(mod)
elif mod_name in ['o_proj']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
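# Sketch (illustrative): Baichuan packs q/k/v into a single W_pack weight of
# shape (3 * hidden_size, hidden_size). unflatten(0, (3, -1)) exposes the
# q/k/v axis so that Shard(1) column-splits each of the three projections
# across ranks, and flatten(0, 1) restores the packed layout on every rank.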
class Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper."""
@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
return _attention_partition_fn(mod_name, mod, device_mesh)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs[0])
return outputs
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite of Attention.forward."""
world_size = 1
if dist.is_initialized():
world_size = dist.get_world_size()
return self._contiguous_batching_forward(
hidden_states,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
world_size=world_size,
)
def _contiguous_batching_forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
world_size: int = 1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite implementation of Attention.forward.
Add continuous batching support. Add paged attention support. TP
support.
"""
assert not output_attentions
context = self.context.context
max_kv_seq_length = context.max_kv_seq_length
kv_seq_length = context.kv_seq_length
q_seq_length = context.q_seq_length
q_start_loc = context.q_start_loc
block_offsets = context.block_offsets
max_q_seq_length = context.max_q_seq_length
num_heads = self.num_heads // world_size
num_kv_heads = self.num_heads // world_size
head_dim = self.head_dim
def _qkv_proj(hidden_states):
"""qkv proj."""
proj = self.W_pack(hidden_states)
return proj.chunk(3, -1)
def _rotary_emb_fn(query_states, key_states, value_states):
if hasattr(self, 'rotary_emb'):
cos, sin = self.rotary_emb(value_states,
seq_len=max_kv_seq_length)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids,
context.position_ids_1d)
return query_states, key_states, value_states
query_states, key_states, value_states = _qkv_proj(hidden_states)
query_states = query_states.view(-1, num_heads, head_dim)
key_states = key_states.view(-1, num_kv_heads, head_dim)
value_states = value_states.view(-1, num_kv_heads, head_dim)
query_states, key_states, value_states = _rotary_emb_fn(
query_states, key_states, value_states)
fill_kv_cache(
key_states,
value_states,
past_key_value[0],
past_key_value[1],
q_start_loc,
q_seq_length,
kv_seq_length=kv_seq_length,
max_q_seq_length=max_q_seq_length,
block_offsets=block_offsets,
)
attn_output = query_states
paged_attention_fwd(
query_states,
past_key_value[0],
past_key_value[1],
attn_output,
block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seq_length,
kv_seqlens=kv_seq_length,
max_seqlen=max_q_seq_length,
)
hidden_size = num_heads * head_dim
attn_output = attn_output.reshape(*hidden_states.shape[:-1],
hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
class BaichuanAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper."""
@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
return _attention_partition_fn(mod_name, mod, device_mesh)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs[0])
return outputs
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite of BaichuanAttention.forward."""
world_size = 1
if dist.is_initialized():
world_size = dist.get_world_size()
return self._contiguous_batching_forward(
hidden_states,
past_key_value=past_key_value,
output_attentions=output_attentions,
world_size=world_size,
)
def _contiguous_batching_forward(
self,
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
world_size: int = 1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite implementation of BaichuanAttention.forward.
Add continuous batching support. Add paged attention support. TP
support.
"""
assert not output_attentions
context = self.context.context
kv_seq_length = context.kv_seq_length
q_seq_length = context.q_seq_length
q_start_loc = context.q_start_loc
block_offsets = context.block_offsets
max_q_seq_length = context.max_q_seq_length
num_heads = self.num_heads // world_size
num_kv_heads = self.num_heads // world_size
head_dim = self.head_dim
def _qkv_proj(hidden_states):
proj = self.W_pack(hidden_states)
return proj.chunk(3, -1)
query_states, key_states, value_states = _qkv_proj(hidden_states)
query_states = query_states.view(-1, num_heads, head_dim)
key_states = key_states.view(-1, num_kv_heads, head_dim)
value_states = value_states.view(-1, num_kv_heads, head_dim)
fill_kv_cache(
key_states,
value_states,
past_key_value[0],
past_key_value[1],
q_start_loc,
q_seq_length,
kv_seq_length=kv_seq_length,
max_q_seq_length=max_q_seq_length,
block_offsets=block_offsets,
)
attn_output = query_states
num_heads_full = num_heads
head_offset = 0
if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
num_heads_full = num_heads * world_size
head_offset = num_heads * rank
alibi_paged_attention_fwd(query_states,
past_key_value[0],
past_key_value[1],
attn_output,
block_offsets,
b_start_loc=q_start_loc,
b_seq_len=q_seq_length,
b_kv_seq_len=kv_seq_length,
max_input_len=max_q_seq_length,
head_offset=head_offset,
num_heads=num_heads_full)
hidden_size = num_heads * head_dim
attn_output = attn_output.reshape(*hidden_states.shape[:-1],
hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
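# Note (illustrative): ALiBi slopes depend on the global head index, so under
# tensor parallelism each rank passes head_offset = num_heads * rank together
# with the full head count, letting alibi_paged_attention_fwd pick the correct
# slope for every local head.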
class BaichuanModel(nn.Module):
def _continuous_batching_forward_7b(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""Rewrite implementation of 7b BaichuanModel.forward."""
output_attentions = False
use_cache = True
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# Attention mask is not necessary in continuous batching
attention_mask = None
hidden_states = inputs_embeds
# decoder layers
for idx, decoder_layer in enumerate(self.layers):
past_key_value = past_key_values[idx]
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=None,
attentions=None,
)
def _continuous_batching_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""Rewrite implementation of BaichuanModel.forward.
Add continuous batching support. Add paged attention support. TP
support.
"""
use_cache = False
output_attentions = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# Attention mask is not necessary in continuous batching
attention_mask = None
hidden_states = inputs_embeds
# decoder layers
for idx, decoder_layer in enumerate(self.layers):
past_key_value = past_key_values[idx]
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=None,
attentions=None,
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
):
"""Rewrite of BaichuanModel.forward."""
if position_ids is not None:
return self._continuous_batching_forward_7b(
input_ids,
attention_mask,
position_ids,
past_key_values,
inputs_embeds,
)
else:
return self._continuous_batching_forward(
input_ids,
attention_mask,
past_key_values,
inputs_embeds,
)
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py # noqa: E501
from typing import List, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
from transformers.modeling_outputs import BaseModelOutputWithPast
from ..dist_utils import (colwise_parallelize_linear,
rowwise_parallelize_linear_fn, try_to_local)
from ..kernels import paged_attention_fwd
from .functional import fill_kv_cache
class PatchedRMSNorm(nn.Module):
"""Rewrite RMSNorm."""
def forward(self, hidden_states):
"""forward."""
        # A torch.nn.functional.normalize based implementation might lead
        # to wrong outputs.
from ..kernels import rms_norm
ret = rms_norm(hidden_states.permute(1, 0, 2), self.weight, self.eps)
return ret.permute(1, 0, 2)
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
tensor_list = tensor.chunk(num_partitions, dim=-1)
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
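# Usage sketch (illustrative): a mixed projection of shape (sq, b, np, 3 * hn)
# chunked with num_partitions=3 yields three (sq, b, np, hn) views; pass
# contiguous_split_chunks=True when a downstream kernel requires contiguous
# memory.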
def apply_rotary_pos_emb(x: torch.Tensor,
rope_cache: torch.Tensor) -> torch.Tensor:
# x: [sq, b, np, hn]
sq, hn = x.size(0), x.size(-1)
xslice = x[..., :hn // 2]
rope_cache = rope_cache[:sq]
xshaped = xslice.unflatten(-1, (-1, 2))
rope_cache = rope_cache.unsqueeze(2)
# inplace
torch.stack(
[
xshaped[..., 0] * rope_cache[..., 0] -
xshaped[..., 1] * rope_cache[..., 1],
xshaped[..., 1] * rope_cache[..., 0] +
xshaped[..., 0] * rope_cache[..., 1],
],
-1,
out=xshaped,
)
return x
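# Note (illustrative): this variant rotates only the first half of each head
# dimension (the xslice view) in place via torch.stack(..., out=xshaped); the
# second half of x passes through unchanged, matching the partial-rotary
# scheme of the ChatGLM2 code this file is adapted from.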
class PatchedSelfAttention(nn.Module):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [s, b, h] and returns output of
the same size.
"""
def _distribute_qkv_linear(self, mod: nn.Module, device_mesh: DeviceMesh):
"""distribute qkv linear."""
sections = [
self.num_attention_heads_per_partition *
self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition *
self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition *
self.hidden_size_per_attention_head,
]
for name, param in mod.named_parameters():
            split_params = param.split(sections, dim=0)
            updated_param = []
            for p in split_params:
dist_tensor = distribute_tensor(p, device_mesh, [Shard(0)])
dist_tensor = try_to_local(dist_tensor)
updated_param.append(dist_tensor)
param = torch.cat(updated_param)
dist_param = torch.nn.Parameter(param)
mod.register_parameter(name, dist_param)
def _distribute_qkv_lora_linear(self, module: nn.Module,
device_mesh: DeviceMesh):
"""distribute qkv lora linear."""
to_local = True
self._distribute_qkv_linear(
module.base_layer,
device_mesh=device_mesh,
)
for mod in module.lora_A.values():
colwise_parallelize_linear(mod,
device_mesh=device_mesh,
to_local=to_local)
for mod in module.lora_B.values():
self._distribute_qkv_linear(
mod,
device_mesh=device_mesh,
)
module._tp_mode = 'colwise'
def _distribute_partition_fn(self, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['query_key_value']:
from peft.tuners.lora import Linear as LoraLinear
if isinstance(mod, LoraLinear):
self._distribute_qkv_lora_linear(mod, device_mesh)
else:
self._distribute_qkv_linear(mod, device_mesh)
elif mod_name in ['dense']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs[0])
return outputs
def _contiguous_batching_forward(
self,
hidden_states: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
kv_cache: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
# =====================
# Query, Key, and Value
# =====================
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
world_size = 1
if dist.is_initialized():
world_size = dist.get_world_size()
context = self.context.context
max_q_seq_length = context.max_q_seq_length
q_start_loc = context.q_start_loc
q_seq_length = context.q_seq_length
kv_seq_length = context.kv_seq_length
block_offsets = context.block_offsets
mixed_x_layer = self.query_key_value(hidden_states)
if self.multi_query_attention:
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
[
self.num_attention_heads_per_partition *
self.hidden_size_per_attention_head // world_size,
self.num_multi_query_groups_per_partition *
self.hidden_size_per_attention_head // world_size,
self.num_multi_query_groups_per_partition *
self.hidden_size_per_attention_head // world_size,
],
dim=-1,
)
query_layer = query_layer.unflatten(
-1, (-1, self.hidden_size_per_attention_head))
key_layer = key_layer.unflatten(
-1, (-1, self.hidden_size_per_attention_head))
value_layer = value_layer.unflatten(
-1, (-1, self.hidden_size_per_attention_head))
else:
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition // world_size,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer,
value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
# apply relative positional encoding (rotary embedding)
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
# [b, sq, np, hn]
query_layer, key_layer, value_layer = [
k.transpose(0, 1) for k in [query_layer, key_layer, value_layer]
]
# adjust key and value for inference
cache_k, cache_v = kv_cache
fill_kv_cache(
key_layer[0],
value_layer[0],
cache_k,
cache_v,
q_start_loc,
q_seq_length,
kv_seq_length=kv_seq_length,
max_q_seq_length=max_q_seq_length,
block_offsets=block_offsets,
)
# ==================================
# core attention computation
# ==================================
context_layer = query_layer
paged_attention_fwd(query_layer,
cache_k,
cache_v,
context_layer,
block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seq_length,
kv_seqlens=kv_seq_length,
max_seqlen=max_q_seq_length)
context_layer = context_layer.transpose(1, 0).flatten(-2)
# =================
# Output. [sq, b, h]
# =================
output = self.dense(context_layer)
return output, kv_cache
def forward(
self,
hidden_states,
attention_mask,
rotary_pos_emb,
kv_cache=None,
use_cache=True,
output_attentions=False,
):
return self._contiguous_batching_forward(
hidden_states,
rotary_pos_emb,
kv_cache,
)
class MLP(nn.Module):
@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['dense_h_to_4h']:
for name, param in mod.named_parameters():
dist_tensor = distribute_tensor(param.unflatten(0, (2, -1)),
device_mesh, [Shard(1)])
dist_tensor = try_to_local(dist_tensor)
dist_param = torch.nn.Parameter(dist_tensor.flatten(0, 1))
mod.register_parameter(name, dist_param)
elif mod_name in ['dense_4h_to_h']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs)
return outputs
class PatchedChatGLMModel(nn.Module):
def _contiguous_batching_forward(
self,
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
full_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor],
...]] = None,
inputs_embeds: Optional[torch.Tensor] = None):
output_hidden_states = False
use_cache = True
batch_size, seq_length = input_ids.shape
if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)
if self.pre_seq_len is not None:
if past_key_values is None:
past_key_values = self.get_prompt(batch_size=batch_size,
device=input_ids.device,
dtype=inputs_embeds.dtype)
# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
if position_ids is not None:
context = self.context.context
position_ids_1d = context.position_ids_1d
rotary_pos_emb = rotary_pos_emb[position_ids_1d[None]]
else:
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
# Run encoder.
(hidden_states, presents, all_hidden_states,
all_self_attentions) = self.encoder(
inputs_embeds,
full_attention_mask,
rotary_pos_emb=rotary_pos_emb,
kv_caches=past_key_values,
use_cache=use_cache,
output_hidden_states=output_hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
def forward(
self,
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
full_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor],
...]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
return self._contiguous_batching_forward(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
full_attention_mask=full_attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._tensor import DeviceMesh
from ..dist_utils import (colwise_parallelize_linear_fn,
rowwise_parallelize_linear_fn)
from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd
class PatchedDeepseekAttention(nn.Module):
@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['q_proj', 'k_proj', 'v_proj']:
colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
elif mod_name in ['o_proj']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs[0])
return outputs
def _contiguous_batching_forward_impl(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
world_size: int = 1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite implementation of forward.
Add continuous batching support. Add paged attention support. TP
support.
"""
context = self.context.context
kv_seq_length = context.kv_seq_length
q_seq_length = context.q_seq_length
q_start_loc = context.q_start_loc
block_offsets = context.block_offsets
max_q_seq_length = context.max_q_seq_length
max_kv_seq_length = context.max_kv_seq_length
num_heads = self.num_heads // world_size
num_kv_heads = self.num_key_value_heads // world_size
head_dim = self.head_dim
hidden_size = num_heads * head_dim
def __qkv_proj(hidden_states):
"""qkv proj."""
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
return query_states, key_states, value_states
def __rotary_emb_fn(query_states, key_states, value_states):
if hasattr(self, 'rotary_emb'):
cos, sin = self.rotary_emb(value_states,
seq_len=max_kv_seq_length)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids,
context.position_ids_1d)
return query_states, key_states, value_states
query_states, key_states, value_states = __qkv_proj(hidden_states)
query_states = query_states.view(-1, num_heads, head_dim)
key_states = key_states.view(-1, num_kv_heads, head_dim)
value_states = value_states.view(-1, num_kv_heads, head_dim)
query_states, key_states, value_states = __rotary_emb_fn(
query_states, key_states, value_states)
fill_kv_cache(
key_states,
value_states,
past_key_value[0],
past_key_value[1],
q_start_loc,
q_seq_length,
kv_seq_length=kv_seq_length,
max_q_seq_length=max_q_seq_length,
block_offsets=block_offsets,
)
attn_output = query_states
paged_attention_fwd(
query_states,
past_key_value[0],
past_key_value[1],
attn_output,
block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seq_length,
kv_seqlens=kv_seq_length,
max_seqlen=max_q_seq_length,
)
attn_output = attn_output.reshape(*hidden_states.shape[:-1],
hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""forward."""
world_size = 1
if dist.is_initialized():
world_size = dist.get_world_size()
return self._contiguous_batching_forward_impl(
hidden_states,
position_ids,
past_key_value,
output_attentions,
world_size=world_size,
)
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from:
# https://huggingface.co/tiiuae/falcon-7b-instruct
# https://github.com/huggingface/transformers/blob/v4.33-release/src/transformers/models/falcon/modeling_falcon.py # noqa
from typing import Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from torch.distributed._tensor import DeviceMesh
from transformers.modeling_outputs import \
BaseModelOutputWithPastAndCrossAttentions
from ..dist_utils import (colwise_parallelize_linear_fn,
rowwise_parallelize_linear_fn)
from ..kernels import (alibi_paged_attention_fwd, apply_rotary_pos_emb,
fill_kv_cache, fused_rotary_emb, paged_attention_fwd)
class PatchedFalconAttention(nn.Module):
# @classmethod
def _distribute_partition_fn(self, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
world_size = dist.get_world_size()
if mod_name in ['query_key_value']:
if self.new_decoder_architecture:
# e.g. 40b-instruct, GQA
# split qkv across groups
# no finer-grained partitioning
mod.weight.data = mod.weight.reshape(
-1, # num groups
(self.num_heads + self.num_kv_heads * 2) * self.head_dim,
self.hidden_size,
)
elif self.multi_query:
# e.g. 7b-instruct, MQA
# split to q, copy kv
weight = mod.weight.unflatten(0, (-1, self.head_dim))
q_weight = weight[:self.num_heads]
kv_weight = weight[-2:]
q_weight_shards = q_weight.chunk(world_size, 0)
weight_shards = []
for q in q_weight_shards:
# only shard q heads but
# copy single k/v head to all ranks
weight_shards.append(q)
weight_shards.append(kv_weight)
mod.weight.data = torch.cat(weight_shards, dim=0)
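                # Illustrative layout (hypothetical world_size=2):
                # [q_0 .. q_{h/2-1}, k, v, q_{h/2} .. q_{h-1}, k, v], so the
                # column-parallel split below gives every rank its q shard
                # plus a copy of the shared k/v head.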
            # keep the weight 3D here so that
            # the column-parallel split yields
            # an integer number of heads per rank
# no bias for 7b-instruct and 40b-instruct
colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
if self.new_decoder_architecture or self.multi_query:
# return to 2D for later matmul
mod.weight.data = mod.weight.data.reshape(-1, self.hidden_size)
elif mod_name in ['dense']:
if self.new_decoder_architecture:
# e.g. 40b-instruct, GQA
mod.weight.data = mod.weight.reshape(
self.hidden_size,
-1, # num groups
self.num_heads * self.head_dim,
)
elif self.multi_query:
# e.g. 7b-instruct, MQA
mod.weight.data = mod.weight.reshape(self.hidden_size, -1,
self.head_dim)
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
if self.new_decoder_architecture or self.multi_query:
mod.weight.data = mod.weight.reshape(self.hidden_size, -1)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs[0])
return outputs
def _split_heads(
self, fused_qkv: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Split the last dimension into (num_heads, head_dim), results share
same memory storage as `fused_qkv`
Args:
fused_qkv (`torch.tensor`, *required*):
[batch_size, seq_length, num_heads * 3 * head_dim]
Returns:
query: [batch_size, seq_length, num_heads, head_dim]
key: [batch_size, seq_length, num_heads, head_dim]
value: [batch_size, seq_length, num_heads, head_dim]
"""
if self.new_decoder_architecture:
# e.g. 40b-instruct model
batch, seq_len, _ = fused_qkv.shape
qkv = fused_qkv.view(batch, seq_len, -1,
self.num_heads // self.num_kv_heads + 2,
self.head_dim)
query = qkv[:, :, :, :-2]
key = qkv[:, :, :, [-2]]
value = qkv[:, :, :, [-1]]
            # because the cache engine & kernel
            # already handle grouped attention,
            # removing the broadcast makes it faster and saves memory
# key = torch.broadcast_to(key, query.shape)
# value = torch.broadcast_to(value, query.shape)
query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
return query, key, value
elif not self.multi_query:
# e.g. rw-1b model
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length,
self.num_heads // dist.get_world_size(),
3, self.head_dim)
return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[...,
2, :]
else:
# e.g. 7b-instruct model
fused_qkv = fused_qkv.unflatten(-1, (-1, self.head_dim))
split_shape = (fused_qkv.size(-2) - 2, 1, 1)
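            # Shape sketch (illustrative MQA sizes, e.g. 71 query heads with
            # head_dim 64): fused_qkv [b, s, (71 + 2) * 64] is unflattened to
            # [b, s, 73, 64] and split along dim -2 into 71 query heads plus
            # a single key head and a single value head.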
return fused_qkv.split(split_shape, dim=-2)
def _contiguous_batching_forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
output_attentions: bool = False,
):
# prepare inputs for continuous batch forwarding
context = self.context.context
q_start_loc = context.q_start_loc
q_seq_length = context.q_seq_length
kv_seq_length = context.kv_seq_length
max_q_seq_length = context.max_q_seq_length
block_offsets = context.block_offsets
position_ids_1d = context.position_ids_1d
max_kv_seq_length = context.max_kv_seq_length
def __maybe_rotary_fn(query_states, key_states, value_states):
scaling_factor = 1.0
inv_freq = self.maybe_rotary.inv_freq
query_states, key_states = fused_rotary_emb(
query_states[None],
key_states[None],
position_ids_1d[None],
inv_freq=inv_freq,
scaling_factor=scaling_factor,
out_q=query_states[None],
out_k=key_states[None])
return query_states[0], key_states[0], value_states
def __rotary_emb_fn(query_states, key_states, value_states):
"""rotary embedding func."""
cos, sin = self.rotary_emb(value_states.transpose(0, 1),
max_kv_seq_length)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, context.position_ids,
position_ids_1d)
return query_states, key_states, value_states
fused_qkv = self.query_key_value(hidden_states)
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
query_layer = query_layer.flatten(0, 1)
key_layer = key_layer.flatten(0, 1)
value_layer = value_layer.flatten(0, 1)
if hasattr(self, 'maybe_rotary'):
query_layer, key_layer, value_layer = __maybe_rotary_fn(
query_layer, key_layer, value_layer)
elif hasattr(self, 'rotary_emb'):
query_layer, key_layer, value_layer = __rotary_emb_fn(
query_layer, key_layer, value_layer)
past_key, past_value = layer_past
fill_kv_cache(
key_layer.contiguous(),
value_layer.contiguous(),
past_key,
past_value,
q_start_loc,
q_seq_length,
kv_seq_length=kv_seq_length,
max_q_seq_length=max_q_seq_length,
block_offsets=block_offsets,
)
attn_output = query_layer
if not alibi:
paged_attention_fwd(q=query_layer,
k=past_key,
v=past_value,
o=attn_output,
block_offsets=block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seq_length,
kv_seqlens=kv_seq_length,
max_seqlen=max_q_seq_length)
else:
num_heads_full = self.num_heads
head_offset = 0
if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
head_offset = self.num_heads // world_size * rank
alibi_paged_attention_fwd(q=query_layer,
k=past_key,
v=past_value,
o=attn_output,
block_offsets=block_offsets,
b_start_loc=q_start_loc,
b_seq_len=q_seq_length,
b_kv_seq_len=kv_seq_length,
max_input_len=max_q_seq_length,
head_offset=head_offset,
num_heads=num_heads_full,
alibi_scale=self.inv_norm_factor)
attn_output = attn_output[None].flatten(-2, -1)
output_tensor = self.dense(attn_output)
if output_attentions:
return output_tensor, layer_past, None
else:
return output_tensor, layer_past
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
return self._contiguous_batching_forward(hidden_states, alibi,
layer_past)
class PatchedFalconMLP(nn.Module):
@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['dense_h_to_4h']:
colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
elif mod_name in ['dense_4h_to_h']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs)
return outputs
class PatchedFalconModel(nn.Module):
def _contiguous_batching_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor],
...]] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor, ...],
BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = False
use_cache = True
use_alibi = getattr(self, 'use_alibi', getattr(self, 'alibi', False))
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
head_mask = self.get_head_mask(head_mask,
self.config.num_hidden_layers)
hidden_states = inputs_embeds
        # alibi biases are computed inside the attention kernel; only the
        # use_alibi flag is forwarded to each block
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=None,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=use_alibi,
)
hidden_states = outputs[0]
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=None,
attentions=None,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor],
...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor, ...],
BaseModelOutputWithPastAndCrossAttentions]:
return self._contiguous_batching_forward(
input_ids=input_ids, past_key_values=past_key_values)
class PatchedFalconForCausalLM(nn.Module):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor],
...]] = None,
return_dict: Optional[bool] = True,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
use_origin: Optional[bool] = True,
) -> Union[Tuple[torch.Tensor, ...],
BaseModelOutputWithPastAndCrossAttentions]:
"""Forward function, patched to ignore position_ids."""
outputs = self.origin_mod(input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)
return outputs
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Any, Callable, Optional, Sequence, Tuple
import numpy as np
import torch
# import torch.nn.functional as F
from torch import Tensor
from ..kernels import apply_rotary_pos_emb, fill_kv_cache, rerope_attention_fwd
__all__ = ['apply_rotary_pos_emb']
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
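# Illustrative behaviour: for head_dim = 4 and x = [x1, x2, x3, x4],
# rotate_half(x) = [-x3, -x4, x1, x2], so q * cos + rotate_half(q) * sin
# realizes the 2D rotations that RoPE applies to each (x_i, x_{i + d/2}) pair.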
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""This is the equivalent of torch.repeat_interleave(x, dim=1,
repeats=n_rep).
The hidden states go from (num_key_value_heads, seqlen, head_dim) to
(num_attention_heads, seqlen, head_dim)
"""
if n_rep == 1:
return hidden_states
num_key_value_heads, slen, head_dim = hidden_states.shape
hidden_states = hidden_states[:,
None, :, :].expand(num_key_value_heads,
n_rep, slen, head_dim)
return hidden_states.reshape(num_key_value_heads * n_rep, slen, head_dim)
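# Illustrative shape check (hypothetical sizes):
#   repeat_kv(torch.empty(2, 5, 64), n_rep=4).shape == (8, 5, 64)
# i.e. 2 key/value heads are broadcast to 8 attention heads; the expand is a
# view and only the final reshape materializes the copies.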
def generate_batched_mask(q_lens,
k_lens,
max_q_len: int = None,
max_k_len: int = None,
device='cuda'):
"""Generate batched mask."""
if max_q_len is None:
max_q_len = max(q_lens)
if max_k_len is None:
max_k_len = max(k_lens)
q_range = torch.arange(max_q_len).to(device)
k_range = torch.arange(max_k_len).to(device)
cross = k_range.unsqueeze(0) - q_range.unsqueeze(1)
cross = cross.unsqueeze(0)
threshold = (k_lens - q_lens).view(-1, 1, 1)
mask = torch.where(cross <= threshold, 1, 0).to(device)
for idx, q_len in enumerate(q_lens):
mask[idx, q_len:, :] = 0
return mask
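# Illustrative output (hypothetical lengths): for q_lens=[2], k_lens=[4] the
# mask is
#   [[1, 1, 1, 0],
#    [1, 1, 1, 1]]
# i.e. a causal mask shifted by the cached length (k_len - q_len), so every
# query token attends to the full history plus itself.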
def get_slopes(n: int):
"""Get alibi slopes."""
def _get_interleave_power_of_2(n):
start = 2**(-(2**-(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return _get_interleave_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return (
_get_interleave_power_of_2(closest_power_of_2) +
get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
@torch.no_grad()
def get_alibi_biases(n_heads: int, mask: torch.Tensor):
"""Get alibi bias."""
m = torch.tensor(get_slopes(n_heads)).to(mask.device)
distance = mask.cumsum(dim=-1) - 1
return distance * m[None, :, None, None]
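# Illustrative values: get_slopes(4) returns [1/4, 1/16, 1/64, 1/256];
# get_alibi_biases multiplies each head's slope by the mask-derived token
# distance to obtain the additive attention bias.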
def quant_kv(key: torch.Tensor, value: torch.Tensor, out_type: torch.dtype):
"""Quantize key and value of attention to `out_type`.
Args:
key (torch.Tensor): Attention key.
value (torch.Tensor): Attention value.
out_type (torch.dtype): Output data type.
"""
assert out_type is torch.int8
# quantize key and value
_min = torch.min(key, axis=-1).values
_max = torch.max(key, axis=-1).values
key_zp = (_min + _max) / 2
key_scale = (_max - key_zp) / 127
key_int8 = torch.round(
(key - key_zp[:, :, None]) / key_scale[:, :, None]).to(out_type)
_min = torch.min(value, axis=-1).values
_max = torch.max(value, axis=-1).values
value_zp = (_min + _max) / 2
value_scale = (_max - value_zp) / 127
value_int8 = torch.round(
(value - value_zp[:, :, None]) / value_scale[:, :, None]).to(out_type)
# wrap zp and scale to qparams
qparams = {
'key_zp': key_zp,
'key_scale': key_scale,
'value_zp': value_zp,
'value_scale': value_scale,
}
return key_int8, value_int8, qparams
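# The scheme above is symmetric quantization around the per-(token, head)
# midpoint: zp = (min + max) / 2, scale = (max - zp) / 127,
# q = round((x - zp) / scale) stored as int8, and x is recovered as
# q * scale + zp in dequant_kv below; scale/zp travel through the step
# context via sync_qparam_to_context.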
def dequant_kv(context: Any, layer_id: str, key_int8: torch.Tensor,
value_int8: torch.Tensor, out_type: torch.dtype):
"""Dequantize key and value of attention to `out_type`.
Args:
context (Any): StepContext during inference.
layer_id (str): Layer object id.
        key_int8 (torch.Tensor): Quantized attention key.
        value_int8 (torch.Tensor): Quantized attention value.
        out_type (torch.dtype): Output data type.
"""
qparams = context.get_output(layer_id)
key_scale = qparams['key_scale']
key_zp = qparams['key_zp']
key_float = (key_int8 * key_scale[:, :, None] +
key_zp[:, :, None]).to(out_type)
value_scale = qparams['value_scale']
value_zp = qparams['value_zp']
value_float = (value_int8 * value_scale[:, :, None] +
value_zp[:, :, None]).to(out_type)
return key_float, value_float
def sync_qparam_to_context(context: Any, layer_id: str, qparams: dict):
"""Merge quantization param to context.
Args:
context (Any): StepContext during inference.
layer_id (str): Layer object id.
qparams (dict): Quantization param of current step.
"""
if context.inputs.meta is not None:
last_qparam = context.inputs.meta[layer_id]
for _k in last_qparam.keys():
_v = torch.concat([last_qparam[_k], qparams[_k]], axis=0)
last_qparam[_k] = _v
context.set_output(layer_id, last_qparam)
else:
context.set_output(layer_id, qparams)
@torch.no_grad()
def attention_forward_with_rerope(
hidden_states: Tensor,
history_lengths: Sequence,
block_offsets: Tensor,
num_heads: int,
num_kv_heads: int,
head_dim: int,
position_ids: torch.LongTensor,
past_key_value: Tuple[Tensor],
attention_mask: Tensor,
context: Any = None,
q_proj: Optional[Callable] = None,
k_proj: Optional[Callable] = None,
v_proj: Optional[Callable] = None,
qkv_proj: Optional[Callable] = None,
o_proj: Optional[Callable] = None,
rotary_emb_context_fn: Optional[Callable] = None,
rotary_emb_generate_fn: Optional[Callable] = None,
bias_type: str = 'default',
training_length=4096,
window=512,
layer_id: str = None) -> Tensor:
"""Attention module forward with ReRoPE.
Args:
hidden_states (Tensor): Input of attention layer.
history_lengths (Sequence): Cache lengths of each data in batch.
block_offsets (Tensor): Block table of the key/value caches,
used by paged attention.
        num_heads (int): Number of query heads.
        num_kv_heads (int): Number of key/value heads.
head_dim (int): Feature dimension of heads.
position_ids (LongTensor): position ids of the input.
past_key_value (Tuple[Tensor]): key value cache.
q_proj (Callable): query project module/function.
k_proj (Callable): key project module/function.
v_proj (Callable): value project module/function.
qkv_proj (Callable): query/key/value project module/function.
o_proj (Callable): output project module/function.
rotary_emb_context_fn (Callable): rotary embedding context callback.
rotary_emb_generate_fn (Callable): rotary embedding generate callback.
        bias_type (str): Type of attention bias. Supported: ['default'].
        training_length (int): Model sequence length during training.
        window (int): ReRoPE window size. Defaults to 512.
"""
hidden_size = -1
if qkv_proj is not None:
assert q_proj is None
assert k_proj is None
assert v_proj is None
query_states, key_states, value_states = qkv_proj(hidden_states)
else:
assert qkv_proj is None
assert q_proj is not None
assert k_proj is not None
assert v_proj is not None
query_states = q_proj(hidden_states)
key_states = k_proj(hidden_states)
value_states = v_proj(hidden_states)
hidden_size = num_heads * head_dim
query_states = query_states.view(-1, num_heads, head_dim)
key_states = key_states.view(-1, num_kv_heads, head_dim)
value_states = value_states.view(-1, num_kv_heads, head_dim)
query_states *= ((position_ids.flatten() + 1)[:, None, None].log() /
np.log(training_length)).clip(1).to(query_states.dtype)
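    # log-n attention scaling: queries at positions beyond `training_length`
    # are scaled by log(pos + 1) / log(training_length) (clipped to >= 1) to
    # keep the attention sharpness stable past the trained context length.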
kv_seq_length = (position_ids[..., -1] + 1).item()
q_seq_length = getattr(context, 'q_seq_length', None)
if q_seq_length is None:
q_seq_length = kv_seq_length - kv_seq_length.new_tensor(
history_lengths)
q_start_loc = getattr(context, 'q_start_loc', None)
if q_start_loc is None:
q_start_loc = q_seq_length.cumsum(0)
q_start_loc = torch.cat([q_start_loc.new_zeros(1), q_start_loc[:-1]])
    # `quant` must be defined on both branches; otherwise the later
    # `if quant:` check raises a NameError when no quantization is needed.
    quant = past_key_value[0].dtype != hidden_states.dtype
    if quant:
        # dynamically quantize the key/value states to the kv cache dtype
        # and record the quantization params for later dequantization
qkey, qvalue, qparams = quant_kv(key_states, value_states,
past_key_value[0].dtype)
sync_qparam_to_context(context=context,
layer_id=layer_id,
qparams=qparams)
fill_kv_cache(qkey,
qvalue,
past_key_value[0],
past_key_value[1],
q_start_loc,
q_seq_length,
block_offsets=block_offsets,
history_lengths=history_lengths,
context=context)
else:
fill_kv_cache(key_states,
value_states,
past_key_value[0],
past_key_value[1],
q_start_loc,
q_seq_length,
block_offsets=block_offsets,
history_lengths=history_lengths,
context=context)
bsz, q_len, _ = hidden_states.size()
if bias_type.lower() == 'default':
if q_len == 1:
key_states = past_key_value[0][block_offsets].view(
-1, num_heads, head_dim)[0:history_lengths[-1] + 1]
value_states = past_key_value[1][block_offsets].view(
-1, num_heads, head_dim)[0:history_lengths[-1] + 1]
if quant:
# dequant int8 tensor to hidden_states.dtype
key_states, value_states = dequant_kv(
context=context,
layer_id=layer_id,
key_int8=key_states,
value_int8=value_states,
out_type=hidden_states.dtype)
full_position_ids = torch.arange(
position_ids.item() + 1,
device=position_ids.device).unsqueeze(0)
key_states, value_states = rotary_emb_generate_fn(
key_states, value_states, full_position_ids, window)
attn_weights = torch.matmul(query_states.transpose(
0, 1), key_states.permute(1, 2, 0)) / math.sqrt(head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(
query_states.dtype)
attn_output = torch.matmul(attn_weights,
value_states.transpose(0, 1))
else:
query_states1, query_states2, key_states1, key_states2, value_states = rotary_emb_context_fn( # noqa: E501
query_states, key_states, value_states, position_ids, window)
sm_scale = 1.0 / math.sqrt(head_dim)
PADDING_UNIT = past_key_value[0].shape[1]
assert PADDING_UNIT in {16, 32, 64, 128, 256}
# padding_len = -query_states1.shape[2] % PADDING_UNIT
# query_states1 = F.pad(query_states1,
# (0, 0, 0, padding_len)).contiguous()
# query_states2 = F.pad(query_states2,
# (0, 0, 0, padding_len)).contiguous()
# key_states1 = F.pad(key_states1,
# (0, 0, 0, padding_len)).contiguous()
# key_states2 = F.pad(key_states2,
# (0, 0, 0, padding_len)).contiguous()
# value_states = F.pad(value_states,
# (0, 0, 0, padding_len)).contiguous()
query_states1 = query_states1.contiguous()
query_states2 = query_states2.contiguous()
key_states1 = key_states1.contiguous()
key_states2 = key_states2.contiguous()
value_states = value_states.contiguous()
attn_output = rerope_attention_fwd(query_states1,
query_states2,
key_states1,
key_states2,
value_states,
True,
sm_scale,
window,
BLOCK_M=PADDING_UNIT).squeeze(0)
# attn_output = attn_output[:, 0:q_len]
if attn_output.size() != (num_heads, q_len, head_dim):
raise ValueError(
                f'`attn_output` should be of size {(num_heads, q_len, head_dim)}, but is'  # noqa: E501
f' {attn_output.size()}')
attn_output = attn_output.transpose(0, 1).reshape(
bsz, q_len, hidden_size).contiguous()
else:
raise ValueError(f'Unknown bias type: {bias_type}')
if o_proj is not None:
attn_output = o_proj(attn_output)
return attn_output
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._tensor import DeviceMesh
from transformers.modeling_outputs import BaseModelOutputWithPast
from ..dist_utils import (colwise_parallelize_linear_fn,
rowwise_parallelize_linear_fn)
from ..kernels import fill_kv_cache, fused_rotary_emb, paged_attention_fwd
class PatchedGemmaRMSNorm(nn.Module):
"""Rewrite RMSNorm."""
def forward(self, x):
"""forward."""
        # a torch.nn.functional.normalize based implementation might lead
        # to wrong outputs
from ..kernels import rms_norm
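        # Gemma stores the RMSNorm weight as an offset from one, so the
        # kernel is invoked with (weight + 1).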
ret = rms_norm(x.contiguous(), self.weight + 1, self.eps)
return ret
def _make_inv_freq(self, device: torch.device):
if self.inv_freq is None:
self.inv_freq = 1.0 / (self.base**(torch.arange(
0, self.dim, 2, dtype=torch.int64, device=device).float() /
self.dim))
class PatchedGemmaAttention(nn.Module):
@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['q_proj', 'k_proj', 'v_proj']:
colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
elif mod_name in ['o_proj']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs[0])
return outputs
def _contiguous_batching_forward_impl(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
attention_mask: Optional[torch.Tensor] = None,
world_size: int = 1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite implementation of forward.
Add continuous batching support. Add paged attention support. TP
support.
"""
context = self.context.context
kv_seq_length = context.kv_seq_length
q_seq_length = context.q_seq_length
q_start_loc = context.q_start_loc
block_offsets = context.block_offsets
max_q_seq_length = context.max_q_seq_length
num_heads = self.num_heads // world_size
num_kv_heads = self.num_key_value_heads // world_size
head_dim = self.head_dim
hidden_size = num_heads * head_dim
def __qkv_proj(hidden_states):
"""qkv proj."""
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
return query_states, key_states, value_states
def __rotary_emb_fn(query_states, key_states, value_states):
scaling_factor = 1.0
_make_inv_freq(self.rotary_emb, query_states.device)
inv_freq = self.rotary_emb.inv_freq
query_states, key_states = fused_rotary_emb(
query_states[None],
key_states[None],
context.position_ids_1d[None],
inv_freq=inv_freq,
scaling_factor=scaling_factor,
out_q=query_states[None],
out_k=key_states[None])
return query_states[0], key_states[0], value_states
query_states, key_states, value_states = __qkv_proj(hidden_states)
query_states = query_states.view(-1, num_heads, head_dim)
key_states = key_states.view(-1, num_kv_heads, head_dim)
value_states = value_states.view(-1, num_kv_heads, head_dim)
query_states, key_states, value_states = __rotary_emb_fn(
query_states, key_states, value_states)
fill_kv_cache(
key_states,
value_states,
past_key_value[0],
past_key_value[1],
q_start_loc,
q_seq_length,
kv_seq_length=kv_seq_length,
max_q_seq_length=max_q_seq_length,
block_offsets=block_offsets,
)
attn_output = query_states
paged_attention_fwd(
query_states,
past_key_value[0],
past_key_value[1],
attn_output,
block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seq_length,
kv_seqlens=kv_seq_length,
max_seqlen=max_q_seq_length,
)
attn_output = attn_output.reshape(*hidden_states.shape[:-1],
hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite of forward."""
world_size = 1
if dist.is_initialized():
world_size = dist.get_world_size()
return self._contiguous_batching_forward_impl(
hidden_states,
position_ids,
past_key_value,
output_attentions,
attention_mask=attention_mask,
world_size=world_size,
)
class PatchedGemmaModel(nn.Module):
def _continuous_batching_forward(
self,
input_ids: torch.LongTensor = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""Rewrite implementation of LlamaModel.forward."""
output_attentions = False
use_cache = True
# Attention mask is not necessary in continuous batching
attention_mask = None
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
        # Gemma-specific: scale the embedding output by sqrt(hidden_size)
hidden_states = hidden_states * (self.config.hidden_size**0.5)
for idx, decoder_layer in enumerate(self.layers):
past_key_value = (past_key_values[idx]
if past_key_values is not None else None)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=None,
attentions=None,
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""Rewrite of LlamaModel.forward."""
return self._continuous_batching_forward(
input_ids,
position_ids,
past_key_values,
inputs_embeds,
)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._tensor import DeviceMesh
from ..dist_utils import (colwise_parallelize_linear_fn,
rowwise_parallelize_linear_fn)
from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd
class PatchedInternLMAttention(nn.Module):
@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['q_proj', 'k_proj', 'v_proj']:
colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
elif mod_name in ['o_proj']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs[0])
return outputs
def _contiguous_batching_forward_impl(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
world_size: int = 1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite implementation of LlamaAttention.forward.
Add continuous batching support. Add paged attention support. TP
support.
"""
context = self.context.context
q_start_loc = context.q_start_loc
kv_seq_length = context.kv_seq_length
q_seq_length = context.q_seq_length
block_offsets = context.block_offsets
max_q_seq_length = context.max_q_seq_length
max_kv_seq_length = context.max_kv_seq_length
num_heads = self.num_heads // world_size
num_kv_heads = num_heads
head_dim = self.head_dim
hidden_size = num_heads * head_dim
def __qkv_proj(hidden_states):
"""qkv proj."""
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
return query_states, key_states, value_states
def __rotary_emb_fn(query_states, key_states, value_states):
"""rotary embedding func."""
cos, sin = self.rotary_emb(value_states, seq_len=max_kv_seq_length)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids,
context.position_ids_1d)
return query_states, key_states, value_states
query_states, key_states, value_states = __qkv_proj(hidden_states)
query_states = query_states.view(-1, num_heads, head_dim)
key_states = key_states.view(-1, num_kv_heads, head_dim)
value_states = value_states.view(-1, num_kv_heads, head_dim)
query_states, key_states, value_states = __rotary_emb_fn(
query_states, key_states, value_states)
fill_kv_cache(
key_states,
value_states,
past_key_value[0],
past_key_value[1],
q_start_loc,
q_seq_length,
kv_seq_length=kv_seq_length,
max_q_seq_length=max_q_seq_length,
block_offsets=block_offsets,
)
attn_output = query_states
paged_attention_fwd(
query_states,
past_key_value[0],
past_key_value[1],
attn_output,
block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seq_length,
kv_seqlens=kv_seq_length,
max_seqlen=max_q_seq_length,
)
attn_output = attn_output.reshape(*hidden_states.shape[:-1],
hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""forward."""
world_size = 1
if dist.is_initialized():
world_size = dist.get_world_size()
return self._contiguous_batching_forward_impl(
hidden_states,
position_ids,
past_key_value,
world_size=world_size,
)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from einops import rearrange
from torch import nn
from torch.distributed._tensor import DeviceMesh
from transformers.modeling_outputs import BaseModelOutputWithPast
from ..dist_utils import (colwise_parallelize_linear_fn,
rowwise_parallelize_linear_fn)
from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd
class PatchedInternLM2Attention(nn.Module):
def _distribute_partition_fn(self, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['wqkv']:
colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
elif mod_name in ['wo']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs[0])
return outputs
def _contiguous_batching_forward_impl(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
world_size: int = 1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite implementation of LlamaAttention.forward.
Add continuous batching support. Add paged attention support. TP
support.
"""
context = self.context.context
q_start_loc = context.q_start_loc
q_seq_length = context.q_seq_length
kv_seq_length = context.kv_seq_length
block_offsets = context.block_offsets
max_q_seq_length = context.max_q_seq_length
max_kv_seq_length = context.max_kv_seq_length
def __qkv_proj(hidden_states):
"""qkv_proj."""
qkv_states = self.wqkv(hidden_states)
qkv_states = rearrange(
qkv_states,
'b q (h gs d) -> (b q) h gs d',
gs=2 + self.num_key_value_groups,
d=self.head_dim,
)
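            # wqkv packs each kv group as [q_0 .. q_{groups-1}, k, v] along
            # the output dim, hence gs = num_key_value_groups + 2; the slices
            # below pick the queries, the key and the value of every group.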
query_states = qkv_states[..., :self.num_key_value_groups, :]
query_states = query_states.flatten(1, 2)
key_states = qkv_states[..., -2, :]
value_states = qkv_states[..., -1, :]
return query_states, key_states, value_states
def __rotary_emb_fn(query_states, key_states, value_states):
"""rotary embedding func."""
cos, sin = self.rotary_emb(value_states.transpose(0, 1),
seq_len=max_kv_seq_length)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids,
context.position_ids_1d)
return query_states, key_states, value_states
query_states, key_states, value_states = __qkv_proj(hidden_states)
query_states, key_states, value_states = __rotary_emb_fn(
query_states, key_states, value_states)
fill_kv_cache(
key_states,
value_states,
past_key_value[0],
past_key_value[1],
q_start_loc,
q_seq_length,
kv_seq_length=kv_seq_length,
max_q_seq_length=max_q_seq_length,
block_offsets=block_offsets,
)
attn_output = query_states
paged_attention_fwd(
query_states,
past_key_value[0],
past_key_value[1],
attn_output,
block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seq_length,
kv_seqlens=kv_seq_length,
max_seqlen=max_q_seq_length,
)
attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)
attn_output = self.wo(attn_output)
return attn_output, None, past_key_value
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""forward."""
world_size = 1
if dist.is_initialized():
world_size = dist.get_world_size()
return self._contiguous_batching_forward_impl(
hidden_states,
position_ids,
past_key_value,
output_attentions,
world_size=world_size,
)
class PatchedInternLM2MLP(nn.Module):
@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['w1', 'w3']:
colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
elif mod_name in ['w2']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs)
return outputs
class PatchedInternLM2Model(nn.Module):
def _continuous_batching_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""Rewrite implementation of LlamaModel.forward."""
if inputs_embeds is None:
inputs_embeds = self.tok_embeddings(input_ids)
# Attention mask is not necessary in continuous batching
attention_mask = None
hidden_states = inputs_embeds
# decoder layers
for idx, decoder_layer in enumerate(self.layers):
past_key_value = (past_key_values[idx]
if past_key_values is not None else None)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
            past_key_values=past_key_values,
hidden_states=None,
attentions=None,
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""Rewrite of LlamaModel.forward."""
return self._continuous_batching_forward(
input_ids,
attention_mask,
position_ids,
past_key_values,
inputs_embeds,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import transformers
from packaging import version
from torch import nn
from torch.distributed._tensor import DeviceMesh
from transformers.modeling_outputs import BaseModelOutputWithPast
from ..dist_utils import (colwise_parallelize_linear_fn,
rowwise_parallelize_linear_fn)
from ..kernels import apply_rotary_pos_emb as apply_rotary_pos_emb_old
from ..kernels import fill_kv_cache, fused_rotary_emb, paged_attention_fwd
from .functional import attention_forward_with_rerope, repeat_kv
TRANSFORMERS_VERSION = version.parse(transformers.__version__)
class LlamaRMSNorm(nn.Module):
"""Rewrite RMSNorm."""
def forward(self, hidden_states):
"""forward."""
        # a torch.nn.functional.normalize based implementation might lead
        # to wrong outputs
from ..kernels import rms_norm
ret = rms_norm(hidden_states, self.weight, self.variance_epsilon)
return ret
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors."""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
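# Standard RoPE application: `cos`/`sin` of shape [seq, head_dim] gain a head
# axis via `unsqueeze_dim` so they broadcast over query/key states of shape
# [seq, num_heads, head_dim]; each feature pair is rotated by its position
# angle.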
class LlamaAttention(nn.Module):
"""Rewrite module of LlamaAttention."""
@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['q_proj', 'k_proj', 'v_proj']:
colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
elif mod_name in ['o_proj']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs[0])
return outputs
def _contiguous_batching_forward_rerope_impl(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
attention_mask: Optional[torch.Tensor] = None,
world_size: int = 1
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""rerope rewrite."""
context = self.context.context
history_lengths = context.history_lengths
def apply_rotary_pos_emb_rerope(q, k, cos, sin, position_ids):
assert 1 == position_ids.shape[0]
_, seq_len = position_ids.shape
_, dim = cos.shape
cos = cos[position_ids].reshape(
seq_len, 1, dim) # [bs, seq_len, dim] to [seq_len, 1, dim]
sin = sin[position_ids].reshape(
seq_len, 1, dim) # [bs, seq_len, dim] to [seq_len, 1, dim]
q_embed = ((q * cos[-q.shape[0]:]) +
(rotate_half(q) *
sin[-q.shape[0]:])) if q is not None else None
k_embed = ((k * cos) +
(rotate_half(k) * sin)) if k is not None else None
return q_embed, k_embed
def _rotary_emb_context_rerope_fn(query_states, key_states,
value_states, position_ids, window):
kv_seq_len, num_dim, dim = key_states.shape
cos, sin = self.rotary_emb(value_states,
seq_len=max(kv_seq_len, window + 1))
query_states1, key_states1 = apply_rotary_pos_emb_rerope(
query_states, key_states, cos, sin, position_ids)
query_states2, _ = apply_rotary_pos_emb_rerope(
query_states, None, cos, sin, position_ids * 0 + window)
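            # ReRoPE keeps two query streams: `query_states1` uses the true
            # rotary positions (valid inside the local window), while
            # `query_states2` is rotated with the constant position `window`
            # so every key farther than `window` tokens away sees the same
            # capped relative offset; the rerope kernel picks between them.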
# repeat k/v heads if n_kv_heads < n_heads
if self.num_key_value_groups > 1:
key_states1 = repeat_kv(key_states1, self.num_key_value_groups)
key_states2 = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states,
self.num_key_value_groups)
else:
key_states2 = key_states
query_states1 = query_states1.transpose(0, 1).reshape(
1, num_dim, kv_seq_len, dim)
query_states2 = query_states2.transpose(0, 1).reshape(
1, num_dim, kv_seq_len, dim)
key_states1 = key_states1.transpose(0, 1).reshape(
1, num_dim, kv_seq_len, dim)
key_states2 = key_states2.transpose(0, 1).reshape(
1, num_dim, kv_seq_len, dim)
value_states = value_states.transpose(0, 1).reshape(
1, num_dim, kv_seq_len, dim)
return query_states1, query_states2, key_states1, key_states2, value_states # noqa: E501
def _rotary_emb_generate_rerope_fn(key_states, value_states,
position_ids, window):
kv_seq_len = key_states.shape[0]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
position_ids = (position_ids[:, -1] -
position_ids).clip(max=window)
_, key_states = apply_rotary_pos_emb_rerope(
None, key_states, cos, -sin, position_ids)
if self.num_key_value_groups > 1:
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states,
self.num_key_value_groups)
return key_states, value_states
attn_output = attention_forward_with_rerope(
hidden_states,
history_lengths=history_lengths,
block_offsets=context.block_offsets,
num_heads=self.num_heads // world_size,
num_kv_heads=self.num_key_value_heads // world_size,
head_dim=self.head_dim,
position_ids=position_ids,
past_key_value=past_key_value,
attention_mask=attention_mask,
context=context,
q_proj=self.q_proj,
k_proj=self.k_proj,
v_proj=self.v_proj,
o_proj=self.o_proj,
rotary_emb_context_fn=_rotary_emb_context_rerope_fn,
rotary_emb_generate_fn=_rotary_emb_generate_rerope_fn,
layer_id=id(self))
return attn_output, None, past_key_value
def _contiguous_batching_forward_default_impl(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
attention_mask: Optional[torch.Tensor] = None,
world_size: int = 1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""default rewrite."""
context = self.context.context
kv_seq_length = context.kv_seq_length
q_seq_length = context.q_seq_length
q_start_loc = context.q_start_loc
block_offsets = context.block_offsets
max_q_seq_length = context.max_q_seq_length
max_kv_seq_length = context.max_kv_seq_length
num_heads = self.num_heads // world_size
num_kv_heads = self.num_key_value_heads // world_size
head_dim = self.head_dim
hidden_size = num_heads * head_dim
def __qkv_proj(hidden_states):
"""qkv proj."""
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
return query_states, key_states, value_states
def __rotary_emb_fn_old(query_states, key_states, value_states):
"""rotary embedding old."""
if max_kv_seq_length >= self.rotary_emb.max_seq_len_cached:
# create larger cache
cos, sin = self.rotary_emb(value_states,
seq_len=max_kv_seq_length + 128)
cos = self.rotary_emb.cos_cached
sin = self.rotary_emb.sin_cached
query_states, key_states = apply_rotary_pos_emb_old(
query_states,
key_states,
cos,
sin,
position_ids,
context.position_ids_1d,
q_embed=query_states,
k_embed=key_states)
return query_states, key_states, value_states
def __rotary_emb_fn_438_naive(query_states, key_states, value_states):
"""rotary embedding transformers>4.38."""
cos, sin = self.rotary_emb(value_states,
context.position_ids_1d[None])
cos = cos[0]
sin = sin[0]
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin)
return query_states, key_states, value_states
def __rotary_emb_fn_438_fused(query_states, key_states, value_states):
scaling_factor = getattr(self.rotary_emb, 'scaling_factor', 1.0)
inv_freq = self.rotary_emb.inv_freq
query_states, key_states = fused_rotary_emb(
query_states[None],
key_states[None],
context.position_ids_1d[None],
inv_freq=inv_freq,
scaling_factor=scaling_factor,
out_q=query_states[None],
out_k=key_states[None])
return query_states[0], key_states[0], value_states
def __rotary_emb_fn_438(query_states, key_states, value_states):
rotary_name = type(self.rotary_emb).__name__
if rotary_name in [
'LlamaRotaryEmbedding', 'LlamaLinearScalingRotaryEmbedding'
]:
return __rotary_emb_fn_438_fused(query_states, key_states,
value_states)
else:
return __rotary_emb_fn_438_naive(query_states, key_states,
value_states)
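        # Note: the fused rotary kernel above is only used for the plain and
        # linear-scaling rotary classes; any other variant (e.g. dynamic NTK
        # scaling) falls back to the naive cos/sin path.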
def __rotary_emb_fn(query_states, key_states, value_states):
"""rotary embedding."""
if TRANSFORMERS_VERSION >= version.parse('4.38.0'):
return __rotary_emb_fn_438(query_states, key_states,
value_states)
else:
return __rotary_emb_fn_old(query_states, key_states,
value_states)
query_states, key_states, value_states = __qkv_proj(hidden_states)
query_states = query_states.view(-1, num_heads, head_dim)
key_states = key_states.view(-1, num_kv_heads, head_dim)
value_states = value_states.view(-1, num_kv_heads, head_dim)
query_states, key_states, value_states = __rotary_emb_fn(
query_states, key_states, value_states)
fill_kv_cache(
key_states,
value_states,
past_key_value[0],
past_key_value[1],
q_start_loc,
q_seq_length,
kv_seq_length=kv_seq_length,
max_q_seq_length=max_q_seq_length,
block_offsets=block_offsets,
)
attn_output = query_states
paged_attention_fwd(
query_states,
past_key_value[0],
past_key_value[1],
attn_output,
block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seq_length,
kv_seqlens=kv_seq_length,
max_seqlen=max_q_seq_length,
)
attn_output = attn_output.reshape(*hidden_states.shape[:-1],
hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def _contiguous_batching_forward_impl(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
attention_mask: Optional[torch.Tensor] = None,
world_size: int = 1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite implementation of LlamaAttention.forward.
Add continuous batching support. Add paged attention support. TP
support.
"""
assert not output_attentions
json_config = self.context.context.json_config
use_rerope = False
if json_config is not None:
use_rerope = json_config.get('rerope', False)
if use_rerope:
return self._contiguous_batching_forward_rerope_impl(
hidden_states,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
attention_mask=attention_mask,
world_size=world_size)
else:
return self._contiguous_batching_forward_default_impl(
hidden_states,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
attention_mask=attention_mask,
world_size=world_size)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite of LlamaAttention.forward."""
world_size = 1
if dist.is_initialized():
world_size = dist.get_world_size()
return self._contiguous_batching_forward_impl(
hidden_states,
position_ids,
past_key_value,
output_attentions,
attention_mask=attention_mask,
world_size=world_size,
)
class LlamaMLP(nn.Module):
@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['gate_proj', 'up_proj']:
colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
elif mod_name in ['down_proj']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs)
return outputs
class LlamaModel(nn.Module):
def _continuous_batching_forward(
self,
input_ids: torch.LongTensor = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None
) -> Union[Tuple, BaseModelOutputWithPast]:
"""Rewrite implementation of LlamaModel.forward."""
output_attentions = False
use_cache = True
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# Attention mask is not necessary in continuous batching
attention_mask = None
hidden_states = inputs_embeds
for idx, decoder_layer in enumerate(self.layers):
past_key_value = past_key_values[idx]
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=None,
attentions=None,
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""Rewrite of LlamaModel.forward."""
return self._continuous_batching_forward(
input_ids,
position_ids,
past_key_values,
inputs_embeds,
)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._tensor import DeviceMesh
from ..dist_utils import (colwise_parallelize_linear_fn,
rowwise_parallelize_linear_fn)
from ..kernels import apply_rotary_pos_emb
from ..kernels.fill_kv_cache import fill_kv_cache
from ..kernels.pagedattention import paged_attention_fwd
class MistralFlashAttention2(nn.Module):
@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['q_proj', 'k_proj', 'v_proj']:
colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
elif mod_name in ['o_proj']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs[0])
return outputs
def _contiguous_batching_forward_impl(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
attention_mask: Optional[torch.Tensor] = None,
world_size: int = 1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite implementation of forward.
        Add continuous batching, paged attention and tensor parallelism
        support.
"""
context = self.context.context
kv_seq_length = context.kv_seq_length
q_seq_length = context.q_seq_length
q_start_loc = context.q_start_loc
block_offsets = context.block_offsets
max_q_seq_length = context.max_q_seq_length
max_kv_seq_length = context.max_kv_seq_length
num_heads = self.num_heads // world_size
num_kv_heads = self.num_key_value_heads // world_size
head_dim = self.head_dim
hidden_size = num_heads * head_dim
def __qkv_proj(hidden_states):
"""qkv proj."""
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
return query_states, key_states, value_states
def __rotary_emb_fn(query_states, key_states, value_states):
if hasattr(self, 'rotary_emb'):
cos, sin = self.rotary_emb(value_states,
seq_len=max_kv_seq_length)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids,
context.position_ids_1d)
return query_states, key_states, value_states
query_states, key_states, value_states = __qkv_proj(hidden_states)
query_states = query_states.view(-1, num_heads, head_dim)
key_states = key_states.view(-1, num_kv_heads, head_dim)
value_states = value_states.view(-1, num_kv_heads, head_dim)
query_states, key_states, value_states = __rotary_emb_fn(
query_states, key_states, value_states)
fill_kv_cache(
key_states,
value_states,
past_key_value[0],
past_key_value[1],
q_start_loc,
q_seq_length,
kv_seq_length=kv_seq_length,
max_q_seq_length=max_q_seq_length,
block_offsets=block_offsets,
)
attn_output = query_states
window_size = self.config.sliding_window
paged_attention_fwd(
query_states,
past_key_value[0],
past_key_value[1],
attn_output,
block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seq_length,
kv_seqlens=kv_seq_length,
max_seqlen=max_q_seq_length,
window_size=window_size,
)
attn_output = attn_output.reshape(*hidden_states.shape[:-1],
hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite of forward."""
world_size = 1
if dist.is_initialized():
world_size = dist.get_world_size()
return self._contiguous_batching_forward_impl(
hidden_states,
position_ids,
past_key_value,
output_attentions,
attention_mask=attention_mask,
world_size=world_size,
)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._tensor import DeviceMesh
from transformers.modeling_outputs import BaseModelOutputWithPast
from ..dist_utils import (colwise_parallelize_linear_fn,
rowwise_parallelize_linear_fn)
from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd
class PatchedMixtralAttention(nn.Module):
"""Rewrite module of MixtralAttention."""
@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['q_proj', 'k_proj', 'v_proj']:
colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
elif mod_name in ['o_proj']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs[0])
return outputs
def _contiguous_batching_forward_impl(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
attention_mask: Optional[torch.Tensor] = None,
world_size: int = 1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""default rewrite."""
context = self.context.context
kv_seq_length = context.kv_seq_length
q_seq_length = context.q_seq_length
q_start_loc = context.q_start_loc
block_offsets = context.block_offsets
max_q_seq_length = context.max_q_seq_length
max_kv_seq_length = context.max_kv_seq_length
num_heads = self.num_heads // world_size
num_kv_heads = self.num_key_value_heads // world_size
hidden_size = num_heads * self.head_dim
def __qkv_proj(hidden_states):
"""qkv proj."""
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
return query_states, key_states, value_states
def __rotary_emb_fn(query_states, key_states, value_states):
if hasattr(self, 'rotary_emb'):
cos, sin = self.rotary_emb(value_states,
seq_len=max_kv_seq_length)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids,
getattr(context, 'position_ids_1d', None))
return query_states, key_states, value_states
query_states, key_states, value_states = __qkv_proj(hidden_states)
query_states = query_states.view(-1, num_heads, self.head_dim)
key_states = key_states.view(-1, num_kv_heads, self.head_dim)
value_states = value_states.view(-1, num_kv_heads, self.head_dim)
query_states, key_states, value_states = __rotary_emb_fn(
query_states, key_states, value_states)
# fill kv cache
fill_kv_cache(
key_states,
value_states,
past_key_value[0],
past_key_value[1],
q_start_loc,
q_seq_length,
kv_seq_length=kv_seq_length,
max_q_seq_length=max_q_seq_length,
block_offsets=block_offsets,
)
# page attention
attn_output = query_states
window_size = self.config.sliding_window or -1
paged_attention_fwd(
query_states,
past_key_value[0],
past_key_value[1],
attn_output,
block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seq_length,
kv_seqlens=kv_seq_length,
max_seqlen=max_q_seq_length,
window_size=window_size,
)
attn_output = attn_output.reshape(*hidden_states.shape[:-1],
hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite of MistralAttention.forward."""
world_size = 1
if dist.is_initialized():
world_size = dist.get_world_size()
return self._contiguous_batching_forward_impl(
hidden_states,
position_ids,
past_key_value,
output_attentions,
attention_mask=attention_mask,
world_size=world_size,
)
class PatchedMixtralBLockSparseTop2MLP(nn.Module):
@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['w1', 'w3']:
colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
elif mod_name in ['w2']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs)
return outputs
class PatchedMixtralModel(nn.Module):
def _continuous_batching_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""Rewrite implementation of LlamaModel.forward."""
from transformers.modeling_outputs import MoeModelOutputWithPast
output_attentions = (output_attentions if output_attentions is not None
else self.config.output_attentions)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else
self.config.output_hidden_states)
if use_cache is None:
use_cache = self.config.use_cache
return_dict = (return_dict if return_dict is not None else
self.config.use_return_dict)
assert (
position_ids is not None
), 'position_ids can not be none when using continuous batching mode.'
assert position_ids.dim() == 2
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# Attention mask is not necessary in continuous batching
attention_mask = None
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states, )
past_key_value = (past_key_values[idx]
if past_key_values is not None else None)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (
layer_outputs[2 if output_attentions else 1], )
if output_attentions:
all_self_attns += (layer_outputs[1], )
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states, )
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v for v in
[hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None)
return MoeModelOutputWithPast(last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits='')
def forward(self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs) -> Union[Tuple, BaseModelOutputWithPast]:
"""Rewrite of LlamaModel.forward."""
return self._continuous_batching_forward(
input_ids,
attention_mask,
position_ids,
past_key_values,
inputs_embeds,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
)
# Copyright (c) OpenMMLab. All rights reserved.
LMDEPLOY_PYTORCH_MODEL_PATH = 'lmdeploy.pytorch.models'
# llama
MODULE_MAP = {
'transformers.models.llama.modeling_llama.LlamaFlashAttention2':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention',
'transformers.models.llama.modeling_llama.LlamaSdpaAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention',
'transformers.models.llama.modeling_llama.LlamaAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention',
'transformers.models.llama.modeling_llama.LlamaModel':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
'transformers.models.llama.modeling_llama.LlamaMLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
'transformers.models.llama.modeling_llama.LlamaRMSNorm':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm',
    # support modeling code rewritten by lmdeploy
'modeling_llama.LlamaAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention',
'modeling_llama.LlamaModel':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
'modeling_llama.LlamaMLP': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
}
# Falcon models in transformers / on hub
MODULE_MAP.update({
'modeling_falcon.FalconAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconAttention',
'modeling_falcon.FalconModel':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconModel',
'modeling_falcon.FalconMLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconMLP',
'modeling_falcon.FalconForCausalLM':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconForCausalLM',
# for old implementations on hub
'modelling_RW.Attention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconAttention',
'modelling_RW.MLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconMLP',
'modelling_RW.RWModel':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconModel',
'modelling_RW.RotaryEmbedding':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconRotaryEmbedding',
})
# baichuan
MODULE_MAP.update({
'modeling_baichuan.Model':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', # noqa
'modeling_baichuan.BaichuanModel':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanModel', # noqa
'modeling_baichuan.Attention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.Attention', # noqa
'modeling_baichuan.BaichuanAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanAttention', # noqa
'modeling_baichuan.MLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', # noqa
'modeling_baichuan.RMSNorm':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.PatchedRMSNorm',
})
# chatglm2
MODULE_MAP.update({
'modeling_chatglm.SelfAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedSelfAttention',
'modeling_chatglm.ChatGLMModel':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedChatGLMModel',
'modeling_chatglm.MLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.MLP',
'modeling_chatglm.RMSNorm':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedRMSNorm',
})
# internlm
MODULE_MAP.update({
'modeling_internlm.InternLMAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm.PatchedInternLMAttention',
'modeling_internlm.InternLMModel':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
'modeling_internlm.InternLMMLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
'modeling_internlm.InternLMRMSNorm':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm',
})
# internlm2
MODULE_MAP.update({
'modeling_internlm2.InternLM2Attention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2Attention',
'modeling_internlm2.InternLM2Model':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2Model',
'modeling_internlm2.InternLM2MLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2MLP',
'modeling_internlm2.InternLM2RMSNorm':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm',
})
# mistral
MODULE_MAP.update({
'transformers.models.mistral.modeling_mistral.MistralAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.MistralFlashAttention2',
'transformers.models.mistral.modeling_mistral.MistralFlashAttention2':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.MistralFlashAttention2',
'transformers.models.mistral.modeling_mistral.MistralSdpaAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.MistralFlashAttention2',
'transformers.models.mistral.modeling_mistral.MistralModel':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
'transformers.models.mistral.modeling_mistral.MistralMLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
'transformers.models.mistral.modeling_mistral.MistralRMSNorm':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm',
})
# gemma
MODULE_MAP.update({
'transformers.models.gemma.modeling_gemma.GemmaAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
'transformers.models.gemma.modeling_gemma.GemmaFlashAttention2':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
'transformers.models.gemma.modeling_gemma.GemmaSdpaAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
'transformers.models.gemma.modeling_gemma.GemmaModel':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaModel',
    'transformers.models.gemma.modeling_gemma.GemmaMLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
'transformers.models.gemma.modeling_gemma.GemmaRMSNorm':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaRMSNorm',
})
# deepseek
MODULE_MAP.update({
'modeling_deepseek.DeepseekAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek.PatchedDeepseekAttention',
'modeling_deepseek.DeepseekFlashAttention2':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek.PatchedDeepseekAttention',
'modeling_deepseek.DeepseekSdpaAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek.PatchedDeepseekAttention',
'modeling_deepseek.DeepseekModel':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
'modeling_deepseek.DeepseekMLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
'modeling_deepseek.DeepseekRMSNorm':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm',
})
# qwen1.5
MODULE_MAP.update({
'transformers.models.qwen2.modeling_qwen2.Qwen2Attention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2Attention',
'transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2Attention',
'transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2Attention',
'transformers.models.qwen2.modeling_qwen2.Qwen2Model':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
'transformers.models.qwen2.modeling_qwen2.Qwen2MLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
'transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm',
})
# peft
MODULE_MAP.update({
'peft.tuners.lora.layer.Linear':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.peft.LoRALinear'
})
# mixtral
MODULE_MAP.update({
'transformers.models.mixtral.modeling_mixtral.MixtralAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttention',
'transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttention',
'transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttention',
'transformers.models.mixtral.modeling_mixtral.MixtralModel':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralModel',
'transformers.models.mixtral.modeling_mixtral.MixtralBLockSparseTop2MLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralBLockSparseTop2MLP',
'transformers.models.mixtral.modeling_mixtral.MixtralBlockSparseTop2MLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralBLockSparseTop2MLP',
'transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm',
})
# Copyright (c) OpenMMLab. All rights reserved.
import importlib
import inspect
import re
from copy import copy
from typing import Any, Dict, Sequence
import torch
from addict import Addict
from torch.distributed._tensor import DeviceMesh
from lmdeploy.utils import get_logger
from ..dist_utils import partition_module, replicate_module
from .module_map import MODULE_MAP
logger = get_logger('lmdeploy')
def _get_rewrite_qualname(origin_qualname: str) -> str:
"""get rewrite module from origin module name.
Args:
origin_qualname (str): The origin qualname of the module.
Returns:
str: The rewrite qualname.
"""
if origin_qualname in MODULE_MAP:
return MODULE_MAP[origin_qualname]
for key, value in MODULE_MAP.items():
if re.search(key, origin_qualname):
return value
return None
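# Illustrative sketch (not part of the upstream API): the helper above first
# tries an exact key in MODULE_MAP and then falls back to regex matching, so
# trailing-module keys such as 'modeling_llama.LlamaMLP' also catch
# remote-code checkpoints. The qualname 'my_repo.modeling_llama.LlamaMLP' is
# a hypothetical example.
def _example_get_rewrite_qualname():
    """Demonstrate exact and pattern-based lookups."""
    exact = _get_rewrite_qualname(
        'transformers.models.llama.modeling_llama.LlamaMLP')
    pattern = _get_rewrite_qualname('my_repo.modeling_llama.LlamaMLP')
    return exact, pattern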
def _class_from_qualname(qualname: str) -> Any:
"""Import class with qualname.
Args:
qualname (str): Qualname of the class
Returns:
Any: class or builder of the class
"""
last_dot = qualname.rfind('.')
modname = qualname[:last_dot]
clsname = qualname[last_dot + 1:]
# get class at runtime
mod = importlib.import_module(modname)
assert mod is not None, f'failed to import module: {modname}'
cls_type = getattr(mod, clsname)
return cls_type
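# Minimal sketch of the runtime class import via qualname, assuming the target
# module is importable in the current environment.
def _example_class_from_qualname():
    """Illustrative only: resolve a well-known class by its qualname."""
    linear_cls = _class_from_qualname('torch.nn.Linear')
    assert linear_cls is torch.nn.Linear
    return linear_cls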
def _find_rewrite_module_qualname(model):
"""find rewrite module."""
module_name = inspect.getmodule(model).__name__
class_name = model.__class__.__name__
def _find_fullname():
origin_qualname = f'{module_name}.{class_name}'
rewrite_qualname = _get_rewrite_qualname(origin_qualname)
return rewrite_qualname
def _find_classname():
origin_qualname = class_name
rewrite_qualname = _get_rewrite_qualname(origin_qualname)
return rewrite_qualname
def _find_submodulename():
# name with first module
mod_name = module_name[module_name.rfind('.') + 1:]
origin_qualname = f'{mod_name}.{class_name}'
rewrite_qualname = _get_rewrite_qualname(origin_qualname)
return rewrite_qualname
rewrite_qualname = _find_fullname()
if rewrite_qualname is None:
rewrite_qualname = _find_classname()
if rewrite_qualname is None:
rewrite_qualname = _find_submodulename()
origin_qualname = f'{module_name}.{class_name}'
if rewrite_qualname is not None:
        logger.debug('Found rewrite of module\n'
f'{origin_qualname} <=> {rewrite_qualname}')
return rewrite_qualname
def _update_module_type(model: Any, cls_type: type, custom_attrs: dict = None):
"""Update class type of model."""
    # do not return the origin model directly: it would be registered as a
    # submodule of the patched model
old_type = type(model)
@property
def get_origin_mod(self):
origin_mod = copy(self)
origin_mod.__class__ = old_type
return origin_mod
attrs = dict(cls_type.__dict__)
custom_attrs = custom_attrs or dict()
custom_attrs['origin_mod'] = get_origin_mod
attrs.update(custom_attrs)
new_type = type(cls_type.__name__, (type(model), ), attrs)
model = copy(model)
model.__class__ = new_type
return model
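# Hedged sketch: how _update_module_type swaps a module's class in place while
# keeping its parameters, and how the injected `origin_mod` property recovers
# the original behaviour. `_Doubler` is a hypothetical stand-in, not an
# lmdeploy rewrite class.
def _example_update_module_type():
    class _Doubler(torch.nn.Module):
        def forward(self, x):
            return x * 2

    lin = torch.nn.Linear(4, 4)
    patched = _update_module_type(lin, _Doubler, dict(context=Addict()))
    x = torch.randn(1, 4)
    doubled = patched(x)              # runs _Doubler.forward
    original = patched.origin_mod(x)  # falls back to nn.Linear.forward
    return doubled, original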
def _patch(model: torch.nn.Module, context: Addict) -> torch.nn.Module:
"""patch the model with rewrite module.
Args:
model (Module): model to be patched.
        context (Addict): The environment info to be patched into the model.
Returns:
Module: The patched model
"""
def _recursive_children(context, named_children):
"""recursive children."""
for name, child in named_children:
patched_child = _patch(child, context)
if patched_child != child:
model.register_module(name, patched_child)
_recursive_children(context, model.named_children())
rewrite_qualname = _find_rewrite_module_qualname(model)
if rewrite_qualname is not None:
cls_type = _class_from_qualname(rewrite_qualname)
model = _update_module_type(model, cls_type, dict(context=context))
return model
def _update_model(model: torch.nn.Module):
"""Update model after patch and load.
Args:
model (Module): The model to be updated.
"""
# recursive over children
for _, child in model.named_children():
_update_model(child)
if hasattr(model, '_update_model_fn'):
model._update_model_fn()
def _dist_model(model: torch.nn.Module,
rank: int = 0,
device_mesh: DeviceMesh = None):
"""distribute model parameters."""
def _init_params():
"""init params."""
device = torch.device(f'cuda:{rank}')
for name, param in model.named_parameters(recurse=False):
if device != param.device:
if rank == 0:
new_param = param.to(device)
model.register_parameter(
name, torch.nn.Parameter(new_param,
requires_grad=False))
else:
new_param = torch.empty_like(param, device=device)
model.register_parameter(
name, torch.nn.Parameter(new_param,
requires_grad=False))
for name, param in model.named_buffers(recurse=False):
if device != param.device:
if rank == 0:
new_param = param.to(device)
model.register_buffer(name, new_param)
else:
new_param = torch.empty_like(param, device=device)
model.register_buffer(name, new_param)
def _dist_params():
"""dist params."""
if hasattr(model, '_distribute_partition_fn'):
partition_module(
model,
device_mesh=device_mesh,
func=model._distribute_partition_fn,
to_local=True,
)
else:
replicate_module(model, device_mesh=device_mesh)
torch.cuda.empty_cache()
def _register_hooks():
"""register hooks."""
if hasattr(model, '_distribute_input_fn'):
input_fn = model._distribute_input_fn
model.register_forward_pre_hook(
lambda _, inputs, inputs_dict: input_fn(
inputs, inputs_dict, device_mesh),
with_kwargs=True,
)
if hasattr(model, '_distribute_output_fn'):
output_fn = model._distribute_output_fn
model.register_forward_hook(
lambda mod, inputs, outputs: output_fn(outputs, device_mesh))
for name, child in model.named_children():
if rank == 0:
logger.debug(f'Distribute module: <{name}>')
new_child = _dist_model(child, rank, device_mesh)
if new_child != child:
            model.register_module(name, new_child)
_init_params()
_dist_params()
_register_hooks()
return model
class PatchedForward:
"""patched forward."""
def __init__(self, model, context, extra_args):
self._model = model
self._patch_context: Dict = context
self._extra_args: list = extra_args
def __call__(self, *args, **kwargs):
for arg_name in self._extra_args:
extra_arg = kwargs.pop(arg_name, None)
self._patch_context[arg_name] = extra_arg
output = self._model(*args, **kwargs)
self._patch_context.clear()
return output
def patch(
model: torch.nn.Module,
extra_args: Sequence[str] = None,
rank: int = 0,
world_size: int = 1,
):
"""Patch the model with rewrite modules.
    Extra arguments will be patched into the model forward; weights will be
    partitioned across ranks.
Args:
model (Module): Model to be patched.
extra_args (Sequence[str]): Extra arguments of model forward.
rank (int): Distribution rank.
world_size (int): Distribution world size.
Returns:
Module: The patched model.
"""
if extra_args is None:
extra_args = []
_patch_context = Addict()
model = _patch(model, _patch_context)
if world_size > 1:
if rank == 0:
logger.info('distribute model parameters.')
device_mesh = DeviceMesh('cuda', list(range(world_size)))
model = _dist_model(model, rank, device_mesh=device_mesh)
_update_model(model)
patched_forward = PatchedForward(model,
_patch_context,
extra_args=extra_args)
model.patched_forward = patched_forward
return model
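# Illustrative usage sketch of `patch` on a plain module in a single-process
# setting. The extra argument name `step_context` is hypothetical; the engine
# that drives the patched model may use different names.
def _example_patch_usage():
    """Illustrative only: patch a module and call its patched forward."""
    toy = torch.nn.Linear(8, 8)
    patched = patch(toy, extra_args=['step_context'])
    x = torch.randn(2, 8)
    # the extra kwarg is stripped off and stored in the patch context before
    # the wrapped forward runs; the context is cleared afterwards
    return patched.patched_forward(x, step_context=Addict())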
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
import torch
import torch.distributed as dist
from ..kernels.mbgmm import mbgmm_a, mbgmm_b
from ..kernels.mbgmv import mbgmv_a, mbgmv_b
from ..kernels.rearange_all_gather import rearange_all_gather
@dataclass
class PackedLoRAInput:
x: torch.Tensor
a_cache: torch.Tensor
b_cache: torch.Tensor
q_start_loc: torch.Tensor
q_seqlens: torch.Tensor
adapter_ids: torch.Tensor
scaling: torch.Tensor
rank_page_table: torch.Tensor
rank_page_start: torch.Tensor
ranks: torch.Tensor
max_seq_len: int
max_rank: int
is_decoding: bool
class LoRALinear(torch.nn.Module):
def _make_packed_lora_input(self, x):
context = self.context.context
# adapter cache
global_adapter_ids = context.global_adapter_ids
layer_idx = self.layer_idx
ranks = self.ranks[global_adapter_ids]
block_starts = self.block_starts[global_adapter_ids]
scaling = self.scaling[global_adapter_ids]
k_cache, v_cache = context.kv_caches[layer_idx]
cache_len = k_cache.size(0)
a_cache = k_cache.view(cache_len, -1)
b_cache = v_cache.view(cache_len, -1)
return PackedLoRAInput(x=x.flatten(0, -2).contiguous(),
a_cache=a_cache,
b_cache=b_cache,
q_start_loc=context.q_start_loc,
q_seqlens=context.q_seq_length,
adapter_ids=context.local_adapter_ids,
scaling=scaling,
rank_page_table=context.adapter_offsets,
rank_page_start=block_starts,
ranks=ranks,
max_seq_len=context.max_q_seq_length,
max_rank=context.max_rank,
is_decoding=context.is_decoding)
def _lora_forward_local(self, x):
"""lora forward no tp."""
lora_input = self._make_packed_lora_input(x)
out_size = self.base_layer.weight.size(0)
if not lora_input.is_decoding:
xa = mbgmm_a(lora_input.x,
lora_input.a_cache,
q_start_loc=lora_input.q_start_loc,
q_seqlens=lora_input.q_seqlens,
adapter_ids=lora_input.adapter_ids,
rank_page_table=lora_input.rank_page_table,
rank_page_start=lora_input.rank_page_start,
ranks=lora_input.ranks,
max_seq_len=lora_input.max_seq_len,
max_rank=lora_input.max_rank)
lora_out = mbgmm_b(xa,
lora_input.b_cache,
q_start_loc=lora_input.q_start_loc,
q_seqlens=lora_input.q_seqlens,
adapter_ids=lora_input.adapter_ids,
scaling=lora_input.scaling,
rank_page_table=lora_input.rank_page_table,
rank_page_start=lora_input.rank_page_start,
ranks=lora_input.ranks,
max_seq_len=lora_input.max_seq_len,
max_rank=lora_input.max_rank,
out_size=out_size)
else:
xa = mbgmv_a(lora_input.x,
lora_input.a_cache,
adapter_ids=lora_input.adapter_ids,
rank_page_table=lora_input.rank_page_table,
rank_page_start=lora_input.rank_page_start,
ranks=lora_input.ranks,
max_rank=lora_input.max_rank)
lora_out = mbgmv_b(xa,
lora_input.b_cache,
adapter_ids=lora_input.adapter_ids,
scaling=lora_input.scaling,
rank_page_table=lora_input.rank_page_table,
rank_page_start=lora_input.rank_page_start,
ranks=lora_input.ranks,
max_rank=lora_input.max_rank,
out_size=out_size)
base_out = self.base_layer(x)
lora_out = lora_out.reshape(base_out.shape)
output = base_out + lora_out
return output
def _lora_forward_tp_rowwise(self, x):
"""lora forward tp rowwise."""
lora_input = self._make_packed_lora_input(x)
rank = dist.get_rank()
world_size = dist.get_world_size()
out_size = self.base_layer.weight.size(0) // world_size
if not lora_input.is_decoding:
xa = mbgmm_a(lora_input.x,
lora_input.a_cache,
q_start_loc=lora_input.q_start_loc,
q_seqlens=lora_input.q_seqlens,
adapter_ids=lora_input.adapter_ids,
rank_page_table=lora_input.rank_page_table,
rank_page_start=lora_input.rank_page_start,
ranks=lora_input.ranks,
max_seq_len=lora_input.max_seq_len,
max_rank=lora_input.max_rank)
lora_out = mbgmm_b(xa,
lora_input.b_cache,
q_start_loc=lora_input.q_start_loc,
q_seqlens=lora_input.q_seqlens,
adapter_ids=lora_input.adapter_ids,
scaling=lora_input.scaling,
rank_page_table=lora_input.rank_page_table,
rank_page_start=lora_input.rank_page_start,
ranks=lora_input.ranks,
max_seq_len=lora_input.max_seq_len,
max_rank=lora_input.max_rank,
out_size=out_size)
else:
xa = mbgmv_a(lora_input.x,
lora_input.a_cache,
adapter_ids=lora_input.adapter_ids,
rank_page_table=lora_input.rank_page_table,
rank_page_start=lora_input.rank_page_start,
ranks=lora_input.ranks,
max_rank=lora_input.max_rank)
lora_out = mbgmv_b(xa,
lora_input.b_cache,
adapter_ids=lora_input.adapter_ids,
scaling=lora_input.scaling,
rank_page_table=lora_input.rank_page_table,
rank_page_start=lora_input.rank_page_start,
ranks=lora_input.ranks,
max_rank=lora_input.max_rank,
out_size=out_size)
base_out = self.base_layer(x)
out_shape = base_out.shape
base_out = base_out.flatten(0, -2)
slice_start = rank * out_size
slice_end = slice_start + out_size
base_out[:, slice_start:slice_end] += lora_out
base_out = base_out.reshape(out_shape)
return base_out
def _lora_forward_tp_colwise(self, x):
"""lora forward tp colwise."""
def __gather_xa(xa):
"""gather xa."""
gathered_xa = xa.new_empty(world_size, xa.size(0), xa.size(1))
dist.all_gather_into_tensor(gathered_xa, xa)
            # TODO: gather would fail when adapters have different ranks.
gathered_xa = gathered_xa.permute(1, 0, 2).flatten(-2, -1)
return gathered_xa
lora_input = self._make_packed_lora_input(x)
world_size = dist.get_world_size()
out_size = self.base_layer.weight.size(0)
if not lora_input.is_decoding:
xa = mbgmm_a(lora_input.x,
lora_input.a_cache,
q_start_loc=lora_input.q_start_loc,
q_seqlens=lora_input.q_seqlens,
adapter_ids=lora_input.adapter_ids,
rank_page_table=lora_input.rank_page_table,
rank_page_start=lora_input.rank_page_start,
ranks=lora_input.ranks,
max_seq_len=lora_input.max_seq_len,
max_rank=lora_input.max_rank,
rank_step=world_size)
gathered_xa = __gather_xa(xa)
if len(lora_input.ranks) > 1:
gathered_xa = rearange_all_gather(
gathered_xa,
b_start_loc=lora_input.q_start_loc,
b_seq_lens=lora_input.q_seqlens,
adapter_ids=lora_input.adapter_ids,
ranks=lora_input.ranks,
world_size=world_size,
max_seq_len=lora_input.max_seq_len,
output=gathered_xa)
lora_out = mbgmm_b(gathered_xa,
lora_input.b_cache,
q_start_loc=lora_input.q_start_loc,
q_seqlens=lora_input.q_seqlens,
adapter_ids=lora_input.adapter_ids,
scaling=lora_input.scaling,
rank_page_table=lora_input.rank_page_table,
rank_page_start=lora_input.rank_page_start,
ranks=lora_input.ranks,
max_seq_len=lora_input.max_seq_len,
max_rank=lora_input.max_rank,
out_size=out_size)
else:
xa = mbgmv_a(lora_input.x,
lora_input.a_cache,
adapter_ids=lora_input.adapter_ids,
rank_page_table=lora_input.rank_page_table,
rank_page_start=lora_input.rank_page_start,
ranks=lora_input.ranks,
max_rank=lora_input.max_rank,
rank_step=world_size)
gathered_xa = __gather_xa(xa)
if len(lora_input.ranks) > 1:
gathered_xa = rearange_all_gather(
gathered_xa,
b_start_loc=lora_input.q_start_loc,
b_seq_lens=lora_input.q_seqlens,
adapter_ids=lora_input.adapter_ids,
ranks=lora_input.ranks,
world_size=world_size,
max_seq_len=lora_input.max_seq_len,
output=gathered_xa)
lora_out = mbgmv_b(gathered_xa,
lora_input.b_cache,
adapter_ids=lora_input.adapter_ids,
scaling=lora_input.scaling,
rank_page_table=lora_input.rank_page_table,
rank_page_start=lora_input.rank_page_start,
ranks=lora_input.ranks,
max_rank=lora_input.max_rank,
out_size=out_size)
base_out = self.base_layer(x)
lora_out = lora_out.reshape(base_out.shape)
output = base_out + lora_out
return output
def _lora_forward_tp(self, x):
"""lora forward tp."""
tp_mode = getattr(self, '_tp_mode', None)
if tp_mode == 'rowwise':
return self._lora_forward_tp_rowwise(x)
elif tp_mode == 'colwise':
return self._lora_forward_tp_colwise(x)
else:
            assert tp_mode is None, f'unsupported tp_mode: {tp_mode}.'
return self._lora_forward_local(x)
def _lora_forward(self, x):
"""lora forward."""
if dist.is_initialized():
return self._lora_forward_tp(x)
else:
return self._lora_forward_local(x)
def forward(self, x):
"""forward."""
context = self.context.context
max_rank = context.max_rank
if max_rank == 0:
return self.origin_mod.forward(x)
else:
return self._lora_forward(x)
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
import torch
import torch.nn as nn
from ..kernels.w8a8_triton_kernels import (matmul_kernel_dynamic_quant,
per_channel_quant,
per_token_quant_int8,
rms_norm_dynamic_quant)
@dataclass
class QTensor:
"""A data class representing a Quantized Tensor.
    This class wraps around a regular PyTorch tensor and adds quantization-
specific parameters.
"""
tensor: torch.Tensor
scale: torch.Tensor
zero_point: torch.Tensor = None
def __getattr__(self, name: str):
"""Allows attribute access to be forwarded to the wrapped tensor when
the attribute doesn't exist in QTensor."""
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.tensor, name)
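# Minimal sketch: wrap a dynamically quantized activation in a QTensor so that
# downstream quantized layers can reuse the scale instead of re-quantizing.
# CUDA-only, since per_token_quant_int8 is a Triton kernel; the float16 dtype
# mirrors the call in QLinear.forward below.
def _example_qtensor():
    x = torch.randn(4, 16, dtype=torch.float16, device='cuda')
    x_quant, x_scale = per_token_quant_int8(x, 1e-7)
    qt = QTensor(x_quant, x_scale)
    # attribute access falls through to the wrapped int8 tensor
    return qt.shape, qt.scale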
class QRMSNorm(nn.Module):
"""It performs traditional RMS normalization and then quantizes the output
to 8-bit integers."""
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
@classmethod
def from_float(cls, mod: nn.Module, initialization: bool = True):
"""Class method to create a QRMSNorm instance from a floating-point
module.
`initialization = True` for real init.
`initialization = False` for dummy init.
"""
hidden_size = mod.weight.shape[0]
eps = mod.variance_epsilon
q_mod = cls(hidden_size, eps)
if initialization:
q_mod.weight = nn.Parameter(mod.weight.detach())
return q_mod
def forward(self, hidden_states):
"""Defines the computation performed at every call.
Performs RMS normalization followed by dynamic quantization on
hidden_states. Returns a QTensor which wraps the quantized tensor along
with its scale factor.
"""
hidden_states_quant, rms_scale = rms_norm_dynamic_quant(
hidden_states, self.weight, self.variance_epsilon)
return QTensor(hidden_states_quant, rms_scale)
class QLinear(nn.Module):
"""A Linear layer that operates on quantized inputs and weights.
    It performs matrix multiplication in 8-bit precision and dequantizes the
results back to float.
"""
__constants__ = ['in_features', 'out_features']
in_features: int
out_features: int
weight: torch.Tensor
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer(
'weight',
torch.empty((out_features, in_features),
device=device,
dtype=torch.int8))
self.register_buffer(
'scale',
torch.empty((out_features, 1), device=device, dtype=torch.float32))
if bias:
self.register_buffer('bias',
torch.empty(out_features, **factory_kwargs))
else:
self.register_parameter('bias', None)
@classmethod
def from_float(cls, mod: nn.Module, initialization: bool = True):
"""Class method to create a QLinear instance from a floating-point
module.
`initialization = True` for real init.
`initialization = False` for dummy init.
"""
q_mod = cls(mod.in_features,
mod.out_features,
mod.bias is not None,
device=mod.weight.device,
dtype=mod.weight.dtype)
if initialization:
weight_quant, scale = per_channel_quant(mod.weight.detach(), 8,
torch.int8)
q_mod.weight.data = weight_quant
q_mod.scale = scale
if mod.bias is not None:
q_mod.bias.data = mod.bias.detach()
return q_mod
def forward(self, input):
"""Defines the computation performed at every call.
Performs quantization if the input is a tensor, otherwise it assumes
the input is already quantized (instance of QTensor). Then, it performs
linear transformation using dynamic quantization method, resulting in
an 8-bit integer output. Finally, it dequantizes the result back to a
floating point tensor.
"""
if isinstance(input, torch.Tensor):
input_quant, input_scale = per_token_quant_int8(input, 1e-7)
else:
assert isinstance(input, QTensor)
input_quant, input_scale = input.tensor, input.scale
out = matmul_kernel_dynamic_quant(input_quant,
self.weight,
input_scale,
self.scale,
output_dtype=torch.float16,
bias=self.bias)
return out
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias is not None)
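# Hedged end-to-end sketch of the conversion path: build a QLinear from an
# existing float16 nn.Linear (per-channel int8 weights plus float32 scales)
# and run a quantized matmul. CUDA-only, since the underlying kernels are
# Triton; the shapes are arbitrary.
def _example_qlinear_from_float():
    lin = torch.nn.Linear(64, 128, dtype=torch.float16, device='cuda')
    qlin = QLinear.from_float(lin)
    x = torch.randn(4, 64, dtype=torch.float16, device='cuda')
    return qlin(x)  # dequantized float16 output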