Unverified Commit accccdd0 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[`Add Mixtral`] Adds support for the Mixtral MoE (#27942)



* up

* up

* test

* logits ok

* up

* up

* few fixes

* conversion script

* up

* nits

* nits

* update

* nuke

* more updates

* nites

* fix many issues

* nit

* scatter

* nit

* nuke megablocks

* nits

* fix conversion script

* nit

* remove

* nits

* nit

* update

* oupsssss

* change

* nits device

* nits

* fixup

* update

* merge

* add copied from

* fix the copy mentions

* update tests

* more fixes

* nits

* conversion script

* add parts of the readme

* Update tests/models/mixtral/test_modeling_mixtral.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* new test + conversion script

* Apply suggestions from code review
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Apply suggestions from code review

* fix

* fix copies

* fix copies

* ooops

* fix config

* Apply suggestions from code review

* fix nits

* nit

* add copies

* add batched tests

* docs

* fix flash attention

* let's add more verbose

* add correct outputs

* support router ouptus

* ignore copies where needed

* fix

* cat list if list is given for now

* nits

* Update docs/source/en/model_doc/mixtral.md

* finish router refactoring

* fix forward

* fix expected values

* nits

* fixup

* fix

* fix bug

* fix

* fix dtype mismatch

* fix

* grrr grrr I support item assignment

* fix CI

* docs

* fixup

* remove some copied form

* fix weird diff

* skip doctest fast on the config and modeling

* mark that is supports flash attention in the doc

* update

* Update src/transformers/models/mixtral/modeling_mixtral.py
Co-authored-by: default avatarLysandre Debut <hi@lysand.re>

* Update docs/source/en/model_doc/mixtral.md
Co-authored-by: default avatarLysandre Debut <hi@lysand.re>

* revert router logits config issue

* update doc accordingly

* Update src/transformers/models/mixtral/convert_mixtral_weights_to_hf.py

* nits

* use torch testing asssert close

* fixup

* doc nits

---------
Co-authored-by: default avataryounesbelkada <younesbelkada@gmail.com>
Co-authored-by: default avatarYounes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: default avatarLysandre Debut <hi@lysand.re>
parent 0676d992
# coding=utf-8
# Copyright 2023 Mixtral AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Mixtral model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"mistral-ai/Mixtral-8x7B": "https://huggingface.co/mistral-ai/Mixtral-8x7B/resolve/main/config.json",
}
class MixtralConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MixtralModel`]. It is used to instantiate an
Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Mixtral-7B-v0.1 or Mixtral-7B-Instruct-v0.1.
[mixtralai/Mixtral-8x7B](https://huggingface.co/mixtralai/Mixtral-8x7B)
[mixtralai/Mixtral-7B-Instruct-v0.1](https://huggingface.co/mixtralai/Mixtral-7B-Instruct-v0.1)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the Mixtral model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`MixtralModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 14336):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention
allows sequence of up to 4096*32 tokens.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
The id of the padding token.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention window size. If not specified, will default to `4096`.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
num_experts_per_tok (`int`, *optional*, defaults to 2):
The number of experts to root per-token, can be also interpreted as the `top-p` routing
parameter
num_local_experts (`int`, *optional*, defaults to 8):
Number of experts per Sparse MLP layer.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabeling this will also
allow the model to output the auxiliary loss. See [here]() for more details
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.
```python
>>> from transformers import MixtralModel, MixtralConfig
>>> # Initializing a Mixtral 7B style configuration
>>> configuration = MixtralConfig()
>>> # Initializing a model from the Mixtral 7B style configuration
>>> model = MixtralModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "mixtral"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=1e6,
sliding_window=4096,
attention_dropout=0.0,
num_experts_per_tok=2,
num_local_experts=8,
output_router_logits=False,
router_aux_loss_coef=0.001,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.num_experts_per_tok = num_experts_per_tok
self.num_local_experts = num_local_experts
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# Copyright 2023 Mistral AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import torch
from transformers import (
MixtralConfig,
MixtralForCausalLM,
)
"""
Sample usage:
```
python src/transformers/models/mixtral/convert_mixtral_weights_to_hf.py \
--input_dir /path/to/downloaded/mixtral/weights --model_size 7B --output_dir /output/path
```
Thereafter, models can be loaded via:
```py
from transformers import MixtralForCausalLM
model = MixtralForCausalLM.from_pretrained("/output/path")
```
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
"""
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def write_json(text, path):
with open(path, "w") as f:
json.dump(text, f)
def write_model(model_path, input_base_path, model_size, safe_serialization=True):
os.makedirs(model_path, exist_ok=True)
params = read_json(os.path.join(input_base_path, "params.json"))
num_shards = 1
# For some reason this is a string in the params.json
sliding_window = int(params["sliding_window"])
n_layers = params["num_hidden_layers"]
n_heads = params["num_attention_heads"]
n_heads_per_shard = n_heads // num_shards
dim = params["hidden_size"]
dims_per_head = dim // n_heads
base = params.get("rope_theta", 10000.0)
max_position_embeddings = 4096 * 8
num_local_experts = params["num_local_experts"]
ffn_dim = params["intermediate_size"]
vocab_size = params["vocab_size"]
if "num_key_value_heads" in params:
num_key_value_heads = params["num_key_value_heads"] # for GQA / MQA
num_local_key_value_heads = num_key_value_heads // num_shards
key_value_dim = dims_per_head * num_local_key_value_heads
else: # compatibility with other checkpoints
num_key_value_heads = n_heads
num_local_key_value_heads = n_heads_per_shard
key_value_dim = dim
# permute for sliced rotary
def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
# Load weights
loaded = [
torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pt"), map_location="cpu") for i in range(8)
]
merged_state_dict = {}
for state_dict in loaded:
merged_state_dict.update(state_dict)
state_dict = {}
for layer_i in range(n_layers):
# Sharded
# Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
# the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
# redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
state_dict.update(
{
f"model.layers.{layer_i}.input_layernorm.weight": merged_state_dict[
f"layers.{layer_i}.attention_norm.weight"
].clone(),
f"model.layers.{layer_i}.post_attention_layernorm.weight": merged_state_dict[
f"layers.{layer_i}.ffn_norm.weight"
].clone(),
}
)
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
merged_state_dict[f"layers.{layer_i}.attention.wq.weight"]
.view(n_heads_per_shard, dims_per_head, dim)
.reshape(dim, dim)
)
state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
merged_state_dict[f"layers.{layer_i}.attention.wk.weight"]
.view(num_local_key_value_heads, dims_per_head, dim)
.reshape(key_value_dim, dim),
num_key_value_heads,
key_value_dim,
dim,
)
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = (
merged_state_dict[f"layers.{layer_i}.attention.wv.weight"]
.view(num_local_key_value_heads, dims_per_head, dim)
.reshape(key_value_dim, dim)
)
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = merged_state_dict[
f"layers.{layer_i}.attention.wo.weight"
]
w1 = merged_state_dict[f"layers.{layer_i}.block_sparse_moe.w1"]
w2 = merged_state_dict[f"layers.{layer_i}.block_sparse_moe.w2"]
w3 = merged_state_dict[f"layers.{layer_i}.block_sparse_moe.w3"]
experts_w1 = [
w1[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].contiguous().clone()
for expert_idx in range(num_local_experts)
]
for idx, expert_block in enumerate(experts_w1):
expert_key = f"model.layers.{layer_i}.block_sparse_moe.experts.{idx}.w1"
state_dict[expert_key + ".weight"] = expert_block.clone()
experts_w2 = [
w2[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].contiguous().clone()
for expert_idx in range(num_local_experts)
]
for idx, expert_block in enumerate(experts_w2):
expert_key = f"model.layers.{layer_i}.block_sparse_moe.experts.{idx}.w2"
state_dict[expert_key + ".weight"] = expert_block.T.clone().contiguous()
experts_w3 = [
w3[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].contiguous().clone()
for expert_idx in range(num_local_experts)
]
for idx, expert_block in enumerate(experts_w3):
expert_key = f"model.layers.{layer_i}.block_sparse_moe.experts.{idx}.w3"
state_dict[expert_key + ".weight"] = expert_block.clone()
state_dict[f"model.layers.{layer_i}.block_sparse_moe.gate.weight"] = merged_state_dict[
f"layers.{layer_i}.block_sparse_moe.gate.weight"
]
state_dict.update(
{
"model.norm.weight": merged_state_dict["norm.weight"],
"model.embed_tokens.weight": merged_state_dict["tok_embeddings.weight"],
"lm_head.weight": merged_state_dict["output.weight"],
}
)
config = MixtralConfig(
hidden_size=dim,
intermediate_size=ffn_dim,
num_attention_heads=params["num_attention_heads"],
num_hidden_layers=params["num_hidden_layers"],
rms_norm_eps=params["rms_norm_eps"],
num_key_value_heads=num_key_value_heads,
vocab_size=vocab_size,
rope_theta=base,
max_position_embeddings=max_position_embeddings,
sliding_window=sliding_window,
num_local_experts=num_local_experts,
)
print("Loading the checkpoint in a Mixtral model.")
with torch.device("meta"):
model = MixtralForCausalLM(config)
# Avoid saving this as part of the config.
del model.config._name_or_path
model.config.torch_dtype = torch.float16
print("Saving in the Transformers format.")
model.load_state_dict(state_dict, strict=True, assign=True)
for n, p in model.named_parameters():
assert p.device.type != "meta", f"{n} has not been loaded!"
model.save_pretrained(model_path, safe_serialization=safe_serialization)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of Mixtral weights, which contains tokenizer.model and model folders",
required=True,
)
parser.add_argument(
"--model_size",
choices=["7B"],
help="'f' models correspond to the finetuned versions, and are specific to the Mixtral official release. For more details on Mixtral, checkout the original repo: https://huggingface.co/mistral-ai",
default="7B",
)
parser.add_argument("--output_dir", help="Location to write HF model", required=True)
parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
args = parser.parse_args()
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
model_size=args.model_size,
safe_serialization=args.safe_serialization,
)
if __name__ == "__main__":
main()
This diff is collapsed.
......@@ -5262,6 +5262,34 @@ class MistralPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class MixtralForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MixtralForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MixtralModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MixtralPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
This diff is collapsed.
......@@ -677,6 +677,8 @@ src/transformers/models/mgp_str/configuration_mgp_str.py
src/transformers/models/mgp_str/modeling_mgp_str.py
src/transformers/models/mistral/configuration_mistral.py
src/transformers/models/mistral/modeling_mistral.py
src/transformers/models/mixtral/configuration_mixtral.py
src/transformers/models/mixtral/modeling_mixtral.py
src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py
src/transformers/models/mobilenet_v1/configuration_mobilenet_v1.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment