Unverified commit 19ade242, authored by Arthur and committed by GitHub

[WIP] `NLLB-MoE` Adds the MoE model (#22024)

* Initial commit

* update modeling code

* update doc

* add functions necessary

* fix imports

* revert changes

* fixup

* more styling to get going

* remove standalone encoder

* update code

* styling

* fix config and model

* update code and some refactoring

* make more tests pass

* Adding NLLB-200 - MoE - 54.5B for No Language Left Behind
Fixes #21300

* fix more common tests

* style

* update testing file

* update

* update

* Router2 doc

* update check config with sparse layer

* add dummy router

* update current conversion script

* create on the fly conversion script

* Fixup

* style

* style 2

* fix empty return

* fix return

* Update default config sparse layers

* easier to create sparse layers

* update

* update conversion script

* update modeling

* add to toctree

* styling

* make ruff happy

* update docstring

* update conversion script

* update, will break tests but implementing top2

* update

* local groups are supported here

* Support for local groups is now removed

This is because it would have to work with model parallelism, which we do not support

* finish simplification

* Fix forward

* style

* fixup

* Update modelling and test, refactoring

* update tests

* remove final layer norm as it is done in the FF

* routing works! Logits test added

* nit in test

* remove top1router

* style

* make sure sparse layers are tested. Had to change route_tokens a little bit

* add support for unslip models when converting

* fixup

* style

* update tests

* update test

* REFACTOR

* encoder outputs match!

* style

* update testing

* 🎉 encoder and decoder logits match 🎉

* styling

* update tests

* cleanup tests

* fix router test and CIs

* cleanup

* cleanup test styling

* fix tests

* Finally the generation tests match!

* cleanup

* update test

* style testing file

* remove script

* cleanup

* more cleanup

* nits

* update

* NLLB tokenizer is wrong and will be fixed soon

* use LongTensors

* update tests

* revert some small changes

* fix second expert sampling and batch prioritized routing

* update tests

* finish last tests

* make ruff happy

* update

* ruff again

* style

* Update docs/source/en/model_doc/nllb-moe.mdx
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Updates based on review

* style and fix import issue

* nit

* more nits

* cleanup

* styling

* update test_seconde_expert_policy

* fix name

* last nit on the markdown examples

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 057e1d74
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os

import torch
from torch import nn

from transformers import NllbMoeConfig, NllbMoeModel
from transformers.modeling_utils import dtype_byte_size
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME


def remove_ignore_keys_(state_dict):
    ignore_keys = [
        "encoder.version",
        "decoder.version",
        "model.encoder.version",
        "model.decoder.version",
        "decoder.output_projection.weight",
        "_float_tensor",
        "encoder.embed_positions._float_tensor",
        "decoder.embed_positions._float_tensor",
    ]
    for k in ignore_keys:
        state_dict.pop(k, None)


def make_linear_from_emb(emb):
    vocab_size, emb_size = emb.weight.shape
    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
    lin_layer.weight.data = emb.weight.data
    return lin_layer


def rename_fairseq_keys(state_dict, expert_idx=None):
    new_dict = {}
    for old_key in state_dict.keys():
        key = old_key
        if "moe_layer.experts." in key:
            if expert_idx is not None:
                key = key.replace("moe_layer.experts.0", f"ffn.experts.expert_{expert_idx}")
            else:
                key = key.replace("moe_layer.experts.", "ffn.experts.expert_")
        if "gate" in key:
            key = key.replace(".moe_layer.gate.wg", ".ffn.router.classifier")
        # Only rename the dense FFN projections; expert weights were handled above.
        if "fc2" in key and "experts" not in key:
            key = key.replace(".fc2.", ".ffn.fc2.")
        if "fc1" in key and "experts" not in key:
            key = key.replace(".fc1.", ".ffn.fc1.")
        if ".encoder_attn." in key:
            key = key.replace(".encoder_attn.", ".cross_attention.")
        if "encoder_attn_layer_norm" in key:
            key = key.replace("encoder_attn_layer_norm", "cross_attention_layer_norm")
        if "final_layer_norm" in key:
            key = key.replace("final_layer_norm", "ff_layer_norm")
        new_dict[key] = state_dict[old_key]
    return new_dict


def shard_on_the_fly(switch_checkpoint_path, dump_path, num_experts, dtype, weights_name: str = WEIGHTS_NAME):
    sharded_state_dicts = []
    total_size = 0
    os.makedirs(dump_path, exist_ok=True)

    for expert in range(num_experts):
        expert_path = switch_checkpoint_path + f"-rank-{expert}.pt"
        if os.path.isfile(expert_path):
            expert_state = torch.load(expert_path)["model"]
            remove_ignore_keys_(expert_state)
            expert_state = rename_fairseq_keys(expert_state, expert)
            save_path = os.path.join(
                dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts)+1:05d}-of-???.bin")
            )
            torch.save(expert_state, save_path)
            sharded_state_dicts.append(expert_state.keys())
            total_size += sum([value.numel() for key, value in expert_state.items()]) * dtype_byte_size(
                expert_state[list(expert_state)[0]].dtype
            )

    # Add the last block
    save_path = os.path.join(dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts)+1:05d}-of-???.bin"))
    shared_weights = torch.load(switch_checkpoint_path + "-shared.pt")["model"]
    remove_ignore_keys_(shared_weights)
    shared_weights = rename_fairseq_keys(shared_weights, None)
    shared_weights["shared.weight"] = shared_weights["decoder.embed_tokens.weight"]
    sharded_state_dicts.append(shared_weights.keys())

    # If we only have the shared weights (dummy model/experts saved on the same file)
    if len(sharded_state_dicts) == 1:
        save_path = os.path.join(dump_path, weights_name)
        torch.save(shared_weights, save_path)
        return {weights_name: sharded_state_dicts[0]}, None
    else:
        torch.save(shared_weights, save_path)

    # Otherwise, let's build the index
    weight_map = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
        temp_filename = os.path.join(dump_path, weights_name.replace(".bin", f"-{idx+1:05d}-of-???.bin"))
        os.rename(temp_filename, os.path.join(dump_path, shard_file))
        for key in shard:
            weight_map[key] = shard_file

    # Add the metadata
    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}

    with open(os.path.join(dump_path, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
        content = json.dumps(index, indent=2, sort_keys=True) + "\n"
        f.write(content)

    return metadata, index


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--nllb_moe_checkpoint_path",
        default="/home/arthur_huggingface_co/fairseq/weights/checkpoints/model_moe_54b/checkpoint_2_300000",
        type=str,
        required=False,
        help="Path to the fairseq NLLB-MoE checkpoint prefix: one '-rank-{i}.pt' file per expert plus a '-shared.pt' file.",
    )
    parser.add_argument("--dtype", default="float32", type=str, required=False, help="dtype of the saved model")
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default="/home/arthur_huggingface_co/fairseq/weights/checkpoints/hf-converted-moe-54b",
        type=str,
        required=False,
        help="Path to the output pytorch model.",
    )
    args = parser.parse_args()
    metadata, index = shard_on_the_fly(
        args.nllb_moe_checkpoint_path,
        args.pytorch_dump_folder_path,
        128,
        args.dtype,
    )

    config = NllbMoeConfig.from_pretrained(
        "facebook/nllb-200-3.3B", encoder_sparse_step=4, decoder_sparse_step=4, num_experts=128
    )

    config.save_pretrained(args.pytorch_dump_folder_path)
    model = NllbMoeModel.from_pretrained(args.pytorch_dump_folder_path)
    print("Done")
    model.save_pretrained(args.pytorch_dump_folder_path)
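
Once the script has written the sharded weights and the config, the dump folder can be loaded like any other seq2seq checkpoint. A minimal usage sketch, assuming the released `facebook/nllb-moe-54b` Hub checkpoint (the local dump folder works the same way) and French (`fra_Latn`) as the target language; at ~54B parameters this is illustrative rather than something to run casually:

```python
from transformers import AutoTokenizer, NllbMoeForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-moe-54b")
model = NllbMoeForConditionalGeneration.from_pretrained("facebook/nllb-moe-54b")

# Translate English -> French: the forced BOS token selects the target language.
inputs = tokenizer("Life is like a box of chocolates.", return_tensors="pt")
generated = model.generate(
    **inputs, forced_bos_token_id=tokenizer.convert_tokens_to_ids("fra_Latn")
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```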
...@@ -111,7 +111,6 @@ class SwitchTransformersConfig(PretrainedConfig):
        num_sparse_decoder_layers=3,
        num_heads=12,
        num_experts=8,
-       router_type="tokens_masked",
        router_bias=False,
        router_jitter_noise=0.01,
        router_dtype="float32",
...@@ -157,7 +156,6 @@ class SwitchTransformersConfig(PretrainedConfig):
        self.decoder_sparse_step = self.num_decoder_layers  # HACK: this will create 0 sparse layers
        self.num_heads = num_heads
-       self.router_type = router_type
        self.num_experts = num_experts
        self.expert_capacity = expert_capacity
        self.router_bias = router_bias
...
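
The hunk above removes `router_type` from `SwitchTransformersConfig`: with only the masked top-1 token router left, the argument no longer selects anything. A minimal sketch of building the config after this change (values are illustrative):

```python
from transformers import SwitchTransformersConfig

# No `router_type` argument: top-1 token routing is the only supported scheme.
config = SwitchTransformersConfig(
    num_experts=8,
    router_bias=False,
    router_jitter_noise=0.01,
    router_dtype="float32",
)
print(config.num_experts, config.router_dtype)
```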
...@@ -4782,6 +4782,44 @@ class NezhaPreTrainedModel(metaclass=DummyObject):
        requires_backends(self, ["torch"])


+NLLB_MOE_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class NllbMoeForConditionalGeneration(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class NllbMoeModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class NllbMoePreTrainedModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class NllbMoeSparseMLP(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class NllbMoeTop2Router(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
NYSTROMFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
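
These `DummyObject` classes are what gets imported when PyTorch is not installed: the names still resolve, but instantiating one calls `requires_backends`, which raises an error telling the user to install the missing backend. A rough illustration of the behaviour in a torch-less environment (error text abbreviated):

```python
# Only meaningful in an environment where torch is NOT installed.
from transformers import NllbMoeModel  # the import succeeds thanks to the dummy object

try:
    model = NllbMoeModel()
except ImportError as err:
    print(err)  # explains that NllbMoeModel requires the PyTorch library
```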
...@@ -57,6 +57,8 @@ PRIVATE_MODELS = [
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
    # models to ignore for not tested
+    "NllbMoeDecoder",
+    "NllbMoeEncoder",
    "LlamaDecoder",  # Building part of bigger (tested) model.
    "Blip2QFormerModel",  # Building part of bigger (tested) model.
    "DetaEncoder",  # Building part of bigger (tested) model.
...