"src/vscode:/vscode.git/clone" did not exist on "12d66b47012c9258f9557e6d3a0c13bcd1c72871"
Unverified Commit dcb183f4 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[`MPT`] Add MosaicML's `MPT` model to transformers (#24629)



* draft add new model like

* some cleaning of the config

* nits

* add nested configs

* nits

* update

* update

* added layer norms + triton kernels

* consider only LPLayerNorm for now.

* update

* all keys match.

* Update

* fixing nits here and there

* working forward pass.

* removed einops dependency

* nits

* format

* add alibi

* byebye head mask

* refactor attention

* nits.

* format

* fix nits.

* nuke ande updates

* nuke tokenizer test

* don't reshape query with kv heads

* added a bit of documentation.

* remove unneeded things

* nuke more stuff

* nit

* logits match - same generations

* rm unneeded methods

* 1 remaining failing CI test

* nit

* fix nits

* fix docs

* fix docs

* rm tokenizer

* fixup

* fixup

* fixup and fix tests

* fixed configuration object.

* use correct activation

* few minor fixes

* clarify docs a bit

* logits match à 1e-12

* skip and unskip a test

* added some slow tests.

* fix readme

* add more details

* Update docs/source/en/model_doc/mpt.md
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* fix configuration issues

* more fixes in config

* added more models

* Apply suggestions from code review
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* remove unneeded position ids

* fix some  comments

* Apply suggestions from code review
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* revert suggestion

* mpt alibi + added batched generation

* Update src/transformers/models/mpt/__init__.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* remove init config

* Update src/transformers/models/mpt/configuration_mpt.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* fix nit

* add another slow test

* Apply suggestions from code review
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fits in one line

* some refactor because make fixup doesn't pass

* add ft notebook

* update md

* correct doc path

---------
Co-authored-by: default avataryounesbelkada <younesbelkada@gmail.com>
Co-authored-by: default avatarYounes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 1dbc1440
# coding=utf-8
# Copyright 2023 HuggingFace Inc. team and MosaicML NLP team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Mpt configuration"""
import copy
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
pass
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
MPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"mosaicml/mpt-7b": "https://huggingface.co/mosaicml/mpt-7b/resolve/main/config.json",
}
class MptAttentionConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`MptAttention`] class. It is used to instantiate
attention layers according to the specified arguments, defining the layers architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the MPT
[mosaicml/mpt-7b](https://huggingface.co/mosaicml/mpt-7b) architecture. Most of the arguments are kept for backward
compatibility with previous MPT models that are hosted on the Hub (previously with `trust_remote_code=True`).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
attn_type (`str`, *optional*, defaults to `"multihead_attention"`):
type of attention to use. Options: `"multihead_attention"`, `"multiquery_attention"`.
attn_pdrop (`float`, *optional*, defaults to 0.0):
The dropout probability for the attention layers.
attn_impl (`str`, *optional*, defaults to `"torch"`):
The attention implementation to use. One of `"torch"`, `"flash"`, or `"triton"`.
clip_qkv (`float`, *optional*):
If not `None`, clip the queries, keys, and values in the attention layer to this value.
softmax_scale (`float`, *optional*, defaults to `None`):
If not `None`, scale the softmax in the attention layer by this value. If `None`, will default to
`1/sqrt(hidden_size)`.
prefix_lm (`bool`, *optional*, defaults to `False`)):
Whether the model should operate as a Prefix LM. This requires passing an extra `prefix_mask` argument
which indicates which tokens belong to the prefix. Tokens in the prefix can attend to one another
bi-directionally. Tokens outside the prefix use causal attention.
qk_ln (`bool`, *optional*, defaults to `False`):
Whether to apply layer normalization to the queries and keys in the attention layer.
attn_uses_sequence_id (`bool`, *optional*, defaults to `False`)):
Whether to restrict attention to tokens that have the same token_type_ids. When the model is in `train`
mode, this requires passing an extra *token_type_ids* argument which indicates which sub-sequence each
token belongs to. Defaults to `False` meaning any provided *token_type_ids* will be ignored.
alibi (`bool`, *optional*, defaults to `True`):
Whether or not to use the alibi bias instead of positional embedding.
alibi_bias_max (`int`, *optional*, defaults to 8):
The maximum value of the alibi bias.
"""
def __init__(
self,
attn_type="multihead_attention",
attn_pdrop=0,
attn_impl="torch",
clip_qkv=None,
softmax_scale=None,
prefix_lm=False,
qk_ln=False,
attn_uses_sequence_id=False,
alibi=True,
alibi_bias_max=8,
**kwargs,
):
super().__init__()
self.attn_type = attn_type
self.attn_pdrop = attn_pdrop
self.attn_impl = attn_impl
self.clip_qkv = clip_qkv
self.softmax_scale = softmax_scale
self.prefix_lm = prefix_lm
self.attn_uses_sequence_id = attn_uses_sequence_id
self.alibi = alibi
self.qk_ln = qk_ln
self.alibi_bias_max = alibi_bias_max
if attn_type not in ["multihead_attention", "multiquery_attention"]:
raise ValueError(
f"`attn_type` has to be either `multihead_attention` or `multiquery_attention`. Received: {attn_type}"
)
class MptConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`MptModel`]. It is used to instantiate a Mpt model
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to the Mpt-7b architecture
[mosaicml/mpt-7b](https://huggingface.co/mosaicml/mpt-7b).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
d_model (`int`, *optional*, defaults to 2048):
Dimensionality of the embeddings and hidden states.
n_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
n_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
expansion_ratio (`int`, *optional*, defaults to 4):
The ratio of the up/down scale in the MLP.
max_seq_len (`int`, *optional*, defaults to 2048):
The maximum sequence length of the model.
vocab_size (`int`, *optional*, defaults to 50368):
Vocabulary size of the Mpt model. Defines the maximum number of different tokens that can be represented by
the `inputs_ids` passed when calling [`MptModel`]. Check [this
discussion](https://huggingface.co/bigscience/mpt/discussions/120#633d28389addb8530b406c2a) on how the
`vocab_size` has been defined.
resid_pdrop (`float`, *optional*, defaults to 0.1):
The dropout probability applied to the attention output before combining with residual.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
The epsilon to use in the layer normalization layers.
emb_pdrop (`float`, *optional*, defaults to 0.1):
The dropout probability for the embedding layer.
learned_pos_emb (`bool`, *optional*, defaults to `False`):
Whether to use learned positional embeddings.
attn_config (`dict`, *optional*):
A dictionary used to configure the model's attention module.
init_device (`str`, *optional*):
The device to use for parameter initialization. Defined for backward compatibility
logit_scale (`float`, *optional*):
If not None, scale the logits by this value.
no_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in all linear layers.
verbose (`int`, *optional*, defaults to 0):
The verbosity level to use for logging. Used in the previous versions of MPT models for logging. This
argument is deprecated.
embedding_fraction (`float`, *optional*, defaults to 1.0):
The fraction to scale the gradients of the embedding layer by.
norm_type (`str`, *optional*, defaults to `"low_precision_layernorm"`):
Type of layer norm to use. All MPT models uses the same layer norm implementation. Defined for backward
compatibility.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Example:
```python
>>> from transformers import MptConfig, MptModel
>>> # Initializing a Mpt configuration
>>> configuration = MptConfig()
>>> # Initializing a model (with random weights) from the configuration
>>> model = MptModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "mpt"
attribute_map = {
"num_attention_heads": "n_heads",
"hidden_size": "d_model",
"num_hidden_layers": "n_layers",
}
def __init__(
self,
d_model: int = 2048,
n_heads: int = 16,
n_layers: int = 24,
expansion_ratio: int = 4,
max_seq_len: int = 2048,
vocab_size: int = 50368,
resid_pdrop: float = 0.0,
layer_norm_epsilon: float = 1e-5,
emb_pdrop: float = 0.0,
learned_pos_emb: bool = True,
attn_config: MptAttentionConfig = None,
init_device: str = "cpu",
logit_scale: Optional[Union[float, str]] = None,
no_bias: bool = True,
verbose: int = 0,
embedding_fraction: float = 1.0,
norm_type: str = "low_precision_layernorm",
use_cache: bool = False,
initializer_range=0.02,
**kwargs,
):
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.expansion_ratio = expansion_ratio
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
self.resid_pdrop = resid_pdrop
self.emb_pdrop = emb_pdrop
self.learned_pos_emb = learned_pos_emb
self.init_device = init_device
self.logit_scale = logit_scale
self.no_bias = no_bias
self.verbose = verbose
self.embedding_fraction = embedding_fraction
self.norm_type = norm_type
self.layer_norm_epsilon = layer_norm_epsilon
self.use_cache = use_cache
self.initializer_range = initializer_range
if attn_config is None:
self.attn_config = MptAttentionConfig()
elif isinstance(attn_config, dict):
self.attn_config = MptAttentionConfig(**attn_config)
elif isinstance(attn_config, MptAttentionConfig):
self.attn_config = attn_config
else:
raise ValueError(
f"`attn_config` has to be either a `MptAttentionConfig` or a dictionary. Received: {type(attn_config)}"
)
super().__init__(**kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["attn_config"] = (
self.attn_config.to_dict() if not isinstance(self.attn_config, dict) else self.attn_config
)
output["model_type"] = self.__class__.model_type
return output
# coding=utf-8
# Copyright 2023 HuggingFace Inc. team and MosaicML NLP team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MPT model."""
import math
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_mpt import MptConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "mosaicml/mpt-7b"
_CONFIG_FOR_DOC = "MptConfig"
MPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"mosaicml/mpt-7b",
"mosaicml/mpt-7b-storywriter",
"mosaicml/mpt-7b-instruct",
"mosaicml/mpt-7b-8k",
"mosaicml/mpt-7b-8k-instruct",
"mosaicml/mpt-7b-8k-chat",
"mosaicml/mpt-30b",
"mosaicml/mpt-30b-instruct",
"mosaicml/mpt-30b-chat"
# See all MPT models at https://huggingface.co/models?filter=mpt
]
# Copied from transformers.models.bloom.modeling_bloom._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
"""
Make causal mask used for self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
# ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
seq_ids = torch.arange(target_length, device=device)
mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
if past_key_values_length > 0:
mask[:, :past_key_values_length] = False
expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
return expanded_mask
# Copied from transformers.models.bloom.modeling_bloom._expand_mask
def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
"""
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
"""
batch_size, src_length = mask.shape
tgt_length = tgt_length if tgt_length is not None else src_length
expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8, device=None):
r"""
Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
"""
alibi = torch.arange(1 - sequence_length, 1, dtype=torch.int32, device=device).view(1, 1, 1, sequence_length)
num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads))
base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.float32, device=device)
base = base * (alibi_bias_max / num_heads_power_of_2)
slopes = 1.0 / torch.pow(2, base)
slopes = slopes.view(1, num_heads, 1, 1)
if num_heads_power_of_2 != num_heads:
slopes = torch.concat([slopes[1::2], slopes[::2]])[:num_heads]
alibi = alibi * slopes
return alibi.squeeze(0)
class MptAttention(nn.Module):
"""Multi-head self attention.
Using torch or triton attention implemetation enables user to also use additive bias.
"""
def __init__(self, config: MptConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.n_heads = config.n_heads
self.max_seq_length = config.max_seq_len
self.head_dim = self.hidden_size // self.n_heads
self.softmax_scale = config.attn_config.softmax_scale
if self.softmax_scale is None:
self.softmax_scale = 1 / math.sqrt(self.hidden_size / self.n_heads)
self.attn_dropout_p = config.attn_config.attn_pdrop
self.Wqkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
position_bias: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
):
batch_size, seq_length = hidden_states.shape[:2]
mixed_qkv = self.Wqkv(hidden_states)
query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
if past_key_value is not None:
if len(past_key_value) != 0:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states)
else:
past_key_value = (key_states, value_states)
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
query_length = seq_length
if past_key_value is not None:
query_length += past_key_value[0].shape[2]
if position_bias is not None:
if len(position_bias.shape) != 3:
raise ValueError(f"Expecting position_bias shape to be 3 dimensions, got {len(position_bias.shape)}")
key_length = key_states.shape[-2]
position_bias_query_index = max(0, position_bias.size(1) - query_length)
position_bias_key_index = max(0, position_bias.size(2) - key_length)
position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:]
attention_scores = attention_scores + position_bias
if attention_mask is not None:
attention_scores = attention_scores.masked_fill(attention_mask, torch.finfo(query_states.dtype).min)
# (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).to(value_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attn_dropout_p, training=self.training)
context_states = torch.matmul(attn_weights, value_states)
context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
attn_output = self.out_proj(context_states)
return attn_output, attn_weights, past_key_value
class MptMLP(nn.Module):
def __init__(self, config: MptConfig):
super().__init__()
hidden_size = config.hidden_size
self.up_proj = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
self.act = nn.GELU(approximate="none")
self.down_proj = nn.Linear(4 * hidden_size, hidden_size, bias=False)
self.hidden_dropout = config.attn_config.attn_pdrop
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
hidden_states = self.act(self.up_proj(hidden_states))
intermediate_output = self.down_proj(hidden_states)
output = F.dropout(intermediate_output, p=self.hidden_dropout, training=self.training)
output = output + residual
return output
class MptBlock(nn.Module):
def __init__(self, config: MptConfig):
super().__init__()
hidden_size = config.hidden_size
self.norm_1 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
# backward compatibility with weights on the Hub
self.norm_1.bias = None
self.num_heads = config.n_heads
self.attn = MptAttention(config)
self.norm_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
# backward compatibility with weights on the Hub
self.norm_2.bias = None
self.ffn = MptMLP(config)
self.dropout_rate = config.attn_config.attn_pdrop
self.resid_attn_dropout = nn.Dropout(self.dropout_rate)
def forward(
self,
hidden_states: torch.Tensor,
position_bias: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
# hidden_states: [batch_size, seq_length, hidden_size]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.norm_1(hidden_states)
residual = hidden_states
# Self attention.
attn_outputs, attn_weights, past_key_value = self.attn(
layernorm_output,
position_bias=position_bias,
attention_mask=attention_mask,
past_key_value=layer_past,
)
hidden_states = self.resid_attn_dropout(attn_outputs) + residual
layernorm_output = self.norm_2(hidden_states)
# Get residual
residual = hidden_states
# MLP.
output = self.ffn(layernorm_output, residual)
outputs = (output,)
if use_cache:
outputs += (past_key_value,)
if output_attentions:
outputs += (attn_weights,)
return outputs # hidden_states, present, attentions
class MptPreTrainedModel(PreTrainedModel):
config_class = MptConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
_no_split_modules = ["MptBlock"]
_keys_to_ignore_on_load_missing = [r"lm_head.*."]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(self, module: nn.Module):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, LayerNorm):
if module.bias is not None:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
if isinstance(module, MptModel):
module.gradient_checkpointing = value
@staticmethod
def _convert_to_mpt_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Converts the cache to the format expected by Mpt, i.e. to tuple(tuple([batch_size * num_heads, ...]))
"""
batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
batch_size_times_num_heads = batch_size * num_heads
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].reshape(batch_size_times_num_heads, head_dim, seq_length),
layer_past[1].reshape(batch_size_times_num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)
MPT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`MptConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
MPT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
(`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
`input_ids`.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
their past given to this model should not be passed as `input_ids` as they have already been computed.
Each element of `past_key_values` is a tuple (past_key, past_value):
- past_key: [batch_size * num_heads, head_dim, kv_length]
- past_value: [batch_size * num_heads, kv_length, head_dim]
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
`past_key_values`).
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare Mpt Model transformer outputting raw hidden-states without any specific head on top.",
MPT_START_DOCSTRING,
)
class MptModel(MptPreTrainedModel):
def __init__(self, config: MptConfig):
super().__init__(config)
self.hidden_size = config.hidden_size
self.num_heads = config.n_heads
# Embedding + LN Embedding
self.wte = nn.Embedding(config.vocab_size, self.hidden_size)
# Transformer blocks
self.blocks = nn.ModuleList([MptBlock(config) for _ in range(config.n_layers)])
# Final Layer Norm
self.norm_f = LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
# backward compatibility with weights on the Hub
self.norm_f.bias = None
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.wte
def _prepare_attn_mask(
self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
) -> torch.BoolTensor:
# create causal mask
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
raise ValueError(
"Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
f" {past_key_values_length}."
)
combined_attention_mask = None
device = attention_mask.device
_, src_length = input_shape
if src_length > 1:
combined_attention_mask = _make_causal_mask(
input_shape, device=device, past_key_values_length=past_key_values_length
)
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
)
return combined_attention_mask
def set_input_embeddings(self, new_embeddings: torch.Tensor):
self.wte = new_embeddings
@add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past_key_values is None:
past_key_values = tuple([None] * len(self.blocks))
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
hidden_states = inputs_embeds
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# Compute alibi tensor: check build_alibi_tensor documentation
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
alibi = build_mpt_alibi_tensor(self.num_heads, self.config.max_seq_len, device=hidden_states.device)
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
for i, (block, layer_past) in enumerate(zip(self.blocks, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
alibi,
causal_mask,
layer_past,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
use_cache=use_cache,
output_attentions=output_attentions,
position_bias=alibi,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# Add last hidden state
hidden_states = self.norm_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
@add_start_docstrings(
"""
The MPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
embeddings).
""",
MPT_START_DOCSTRING,
)
class MptForCausalLM(MptPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: MptConfig):
super().__init__(config)
self.transformer = MptModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings: torch.Tensor):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
**kwargs,
) -> dict:
# only last token for input_ids if past is not None
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values, # NITS should it be layer_past?
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
return model_inputs
@add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def _reorder_cache(
self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
Output shares the same memory storage as `past`.
"""
# Get a copy of `beam_idx` on all the devices where we need those indices.
device_to_beam_idx = {
past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
}
reordered_past = tuple(
(
layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
)
for layer_past in past
)
return reordered_past
@add_start_docstrings(
"""
The MPT Model transformer with a sequence classification head on top (linear layer).
[`MptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-1) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
MPT_START_DOCSTRING,
)
class MptForSequenceClassification(MptPreTrainedModel):
def __init__(self, config: MptConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = MptModel(config)
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
"""
MPT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
MPT_START_DOCSTRING,
)
class MptForTokenClassification(MptPreTrainedModel):
def __init__(self, config: MptConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = MptModel(config)
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
classifier_dropout = config.classifier_dropout
elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
batch_size, seq_length = labels.shape
loss_fct = CrossEntropyLoss()
loss = loss_fct(
logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
)
if not return_dict:
output = (logits,) + transformer_outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
"""
The MPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD
(a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
MPT_START_DOCSTRING,
)
class MptForQuestionAnswering(MptPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.transformer = MptModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -5101,6 +5101,51 @@ class MPNetPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
MPT_PRETRAINED_MODEL_ARCHIVE_LIST = None
class MptForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MptForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MptForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MptForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MptModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MptPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
MRA_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
import unittest
from transformers import MptConfig, is_torch_available
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
MPT_PRETRAINED_MODEL_ARCHIVE_LIST,
AutoTokenizer,
MptForCausalLM,
MptForQuestionAnswering,
MptForSequenceClassification,
MptForTokenClassification,
MptModel,
)
@require_torch
class MptModelTester:
def __init__(
self,
parent,
batch_size=14,
seq_length=7,
is_training=True,
use_token_type_ids=False,
use_input_mask=True,
use_labels=True,
use_mc_token_ids=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_token_type_ids = use_token_type_ids
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.use_mc_token_ids = use_mc_token_ids
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_dropout_prob = attention_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = None
self.bos_token_id = vocab_size - 1
self.eos_token_id = vocab_size - 1
self.pad_token_id = vocab_size - 1
def get_large_model_config(self):
return MptConfig.from_pretrained("mosaicml/mpt-7")
def prepare_config_and_inputs(self, gradient_checkpointing=False):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
sequence_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config(gradient_checkpointing=gradient_checkpointing)
return (config, input_ids, input_mask, sequence_labels)
def get_config(self, gradient_checkpointing=False):
return MptConfig(
vocab_size=self.vocab_size,
seq_length=self.seq_length,
hidden_size=self.hidden_size,
n_layers=self.num_hidden_layers,
n_heads=self.num_attention_heads,
hidden_dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_dropout_prob,
n_positions=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
use_cache=True,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
num_labels=self.num_labels,
gradient_checkpointing=gradient_checkpointing,
dtype="float32",
)
def create_and_check_mpt_model(self, config, input_ids, input_mask, *args):
model = MptModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(len(result.past_key_values), config.n_layers)
def create_and_check_mpt_model_past(self, config, input_ids, input_mask, *args):
model = MptModel(config=config)
model.to(torch_device)
model.eval()
# first forward pass
outputs = model(input_ids, attention_mask=torch.ones_like(input_ids), use_cache=True)
outputs_use_cache_conf = model(input_ids, attention_mask=torch.ones_like(input_ids))
outputs_no_past = model(input_ids, use_cache=False, attention_mask=torch.ones_like(input_ids))
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
past = outputs["past_key_values"]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and token_type_ids
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_mpt_model_attention_mask_past(self, config, input_ids, input_mask, *args):
model = MptModel(config=config)
model.to(torch_device)
model.eval()
# create attention mask
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
half_seq_length = self.seq_length // 2
attn_mask[:, half_seq_length:] = 0
# first forward pass
output, past = model(input_ids, attention_mask=attn_mask).to_tuple()
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
dim=1,
)
# get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past, attention_mask=attn_mask)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_mpt_model_past_large_inputs(self, config, input_ids, input_mask, *args):
model = MptModel(config=config)
model.to(torch_device)
model.eval()
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical multiple next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
output_hidden_states=True,
)
hidden_states_from_no_past = output_from_no_past["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)
hidden_states_from_past = output_from_past["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), hidden_states_from_past.shape[-1]).item()
output_from_no_past_slice = hidden_states_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = hidden_states_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_lm_head_model(self, config, input_ids, input_mask, *args):
model = MptForCausalLM(config)
model.to(torch_device)
model.eval()
result = model(input_ids, labels=input_ids)
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_sequence_classification_model(self, config, input_ids, input_mask, *args):
config.num_labels = self.num_labels
model = MptForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_token_classification_model(self, config, input_ids, input_mask, *args):
model = MptForTokenClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_question_answering_model(self, config, input_ids, input_mask, *args):
model = MptForQuestionAnswering(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, *args, gradient_checkpointing=False
):
model = MptForCausalLM(config)
model.to(torch_device)
if gradient_checkpointing:
model.gradient_checkpointing_enable()
result = model(input_ids, labels=input_ids)
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
result.loss.backward()
def create_and_check_mpt_weight_initialization(self, config, *args):
model = MptModel(config)
model_std = model.config.initializer_range / math.sqrt(2 * model.config.n_layers)
for key in model.state_dict().keys():
if "c_proj" in key and "weight" in key:
self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001)
self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, input_mask, sequence_labels = config_and_inputs
inputs_dict = {"input_ids": input_ids}
return config, inputs_dict
@require_torch
class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
MptModel,
MptForCausalLM,
MptForSequenceClassification,
MptForTokenClassification,
MptForQuestionAnswering,
)
if is_torch_available()
else ()
)
all_generative_model_classes = (MptForCausalLM,) if is_torch_available() else ()
fx_compatible = False
test_missing_keys = False
test_pruning = False
test_torchscript = False
test_head_masking = False
pipeline_model_mapping = (
{"feature-extraction": MptModel, "text-generation": MptForCausalLM} if is_torch_available() else {}
)
def setUp(self):
self.model_tester = MptModelTester(self)
self.config_tester = ConfigTester(self, config_class=MptConfig, n_embd=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_mpt_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mpt_model(*config_and_inputs)
def test_mpt_model_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mpt_model_past(*config_and_inputs)
def test_mpt_model_att_mask_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mpt_model_attention_mask_past(*config_and_inputs)
def test_mpt_model_past_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mpt_model_past_large_inputs(*config_and_inputs)
def test_mpt_lm_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
def test_mpt_sequence_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_sequence_classification_model(*config_and_inputs)
def test_mpt_token_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_token_classification_model(*config_and_inputs)
def test_mpt_gradient_checkpointing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
def test_mpt_weight_initialization(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mpt_weight_initialization(*config_and_inputs)
@unittest.skip("For backward compatibility the lm_head is not in the model's state dict on the Hub.")
def test_model_weights_reload_no_missing_tied_weights(self):
pass
@slow
def test_model_from_pretrained(self):
for model_name in MPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = MptModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@slow
@require_torch_gpu
class MptIntegrationTests(unittest.TestCase):
def test_generation_8k(self):
model_id = "mosaicml/mpt-7b-8k"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load in 4bit to fit the daily CI runner GPU RAM
model = MptForCausalLM.from_pretrained(
model_id, torch_dtype=torch.bfloat16, device_map={"": 0}, load_in_4bit=True
)
input_text = "Hello"
expected_output = "Hello my name is [name] and I am a [type] at [company]. I have a [number]"
inputs = tokenizer(input_text, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
self.assertEqual(decoded_output, expected_output)
def test_generation(self):
model_id = "mosaicml/mpt-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load in 4bit to fit the daily CI runner GPU RAM
model = MptForCausalLM.from_pretrained(
model_id, torch_dtype=torch.bfloat16, device_map={"": 0}, load_in_4bit=True
)
input_text = "Hello"
expected_output = "Hello my name is Kaitlyn and I am a senior at the University of Wisconsin-Stout. I am major"
inputs = tokenizer(input_text, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
self.assertEqual(decoded_output, expected_output)
def test_generation_batched(self):
model_id = "mosaicml/mpt-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load in 4bit to fit the daily CI runner GPU RAM
model = MptForCausalLM.from_pretrained(
model_id, torch_dtype=torch.bfloat16, device_map={"": 0}, load_in_4bit=True
)
input_texts = ["Hello my name is", "Today I am going at the gym and"]
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(torch_device)
expected_output = [
"Hello my name is Tiffany and I am a mother of two beautiful children. I have been a nanny for over",
"Today I am going at the gym and then I am going to go to the grocery store. I am going to get some food and then",
]
outputs = model.generate(**inputs, max_new_tokens=20)
decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
for i, predicted_output in enumerate(decoded_outputs):
self.assertEqual(predicted_output, expected_output[i])
def test_model_logits(self):
model_id = "mosaicml/mpt-7b"
# Load in 4bit to fit the daily CI runner GPU RAM
model = MptForCausalLM.from_pretrained(
model_id, torch_dtype=torch.bfloat16, device_map={"": 0}, load_in_4bit=True
)
dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device)
outputs = model(dummy_input, output_hidden_states=True)
expected_slice = torch.Tensor([-0.2559, -0.2197, -0.2480]).to(torch_device, torch.bfloat16)
predicted_slice = outputs.hidden_states[-1][0, 0, :3]
self.assertTrue(torch.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3))
......@@ -98,6 +98,9 @@ SPECIAL_CASES_TO_ALLOW.update(
"LayoutLMv2Config": True,
"MaskFormerSwinConfig": True,
"MT5Config": True,
# For backward compatibility with trust remote code models
"MptConfig": True,
"MptAttentionConfig": True,
"NatConfig": True,
"OneFormerConfig": True,
"PerceiverConfig": True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment