Unverified commit ccdabc56 authored by Yikang Shen, committed by GitHub

Add JetMoE model (#30005)



* init jetmoe code

* update archive maps

* remove flax import

* fix import error

* update README

* ruff fix

* update readme

* fix

* update config

* fix issue

* merge files

* fix model bug

* fix test

* auto fix

* model size

* add comments

* fix form

* add flash attention support

* fix attention head number

* fix init

* fix support list

* sort auto mapping

* fix test

* fix docs

* update test

* fix test

* fix test

* change variable name

* fix config

* fix init

* update format

* clean code

* fix config

* fix config

* change default config

* update config

* fix issues

* update format

* update config argument

* update format

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* change to mixtral aux loss

* change to cache_position

* debug

* fix bugs

* debug

* fix format

* fix format

* fix copy

* fix format

* fix format

* fix sort

* fix sort

* fix sort

* add copy comment

* add copy from

* remove debug code

* revert readme update

* add copy

* debug

* remove debug code

* fix flash attention

* add comments

* clean code

* clean format

* fix format

* fix format

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* change variable name

* add copied from

* fix variable name

* remove deprecated functions

* sync to llama implementation

* fix format

* fix copy

* fix format

* update format

* remove repr

* add comment for moe weight

* fix copy

* Update src/transformers/models/jetmoe/configuration_jetmoe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* add comments and reformat config

* fix format

* fix format

* fix format

* update test

* update doc string in config

* Update src/transformers/models/jetmoe/modeling_jetmoe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update config doc

* update attention cache

* fix format

* fix copy

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent d84f34ad
@@ -386,6 +386,8 @@
title: I-BERT
- local: model_doc/jamba
title: Jamba
- local: model_doc/jetmoe
title: JetMoe
- local: model_doc/jukebox
title: Jukebox
- local: model_doc/led
@@ -166,6 +166,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Informer](model_doc/informer) | ✅ | ❌ | ❌ |
| [InstructBLIP](model_doc/instructblip) | ✅ | ❌ | ❌ |
| [Jamba](model_doc/jamba) | ✅ | ❌ | ❌ |
| [JetMoe](model_doc/jetmoe) | ✅ | ❌ | ❌ |
| [Jukebox](model_doc/jukebox) | ✅ | ❌ | ❌ |
| [KOSMOS-2](model_doc/kosmos-2) | ✅ | ❌ | ❌ |
| [LayoutLM](model_doc/layoutlm) | ✅ | ✅ | ❌ |
<!--Copyright 2024 JetMoe team and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# JetMoe
## Overview
**JetMoe-8B** is an 8B Mixture-of-Experts (MoE) language model developed by [Yikang Shen](https://scholar.google.com.hk/citations?user=qff5rRYAAAAJ) and [MyShell](https://myshell.ai/).
The JetMoe project aims to provide LLaMA2-level performance from an efficient language model trained on a limited budget.
To achieve this goal, JetMoe uses a sparsely activated architecture inspired by [ModuleFormer](https://arxiv.org/abs/2306.04640).
Each JetMoe block consists of two MoE layers: a Mixture of Attention Heads and a Mixture of MLP Experts.
Given the input tokens, it activates a subset of its experts to process them.
This sparse activation scheme enables JetMoe to achieve much higher training throughput than dense models of a similar size.
The training throughput of JetMoe-8B is around 100B tokens per day on a cluster of 96 H100 GPUs with a straightforward 3-way pipeline parallelism strategy.
This model was contributed by [Yikang Shen](https://huggingface.co/YikangS).
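A minimal text-generation sketch (this assumes the [jetmoe/jetmoe-8b](https://huggingface.co/jetmoe/jetmoe-8b) checkpoint, enough memory for an 8B model, and the `accelerate` package for `device_map="auto"`; the prompt is arbitrary):

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b")
>>> model = AutoModelForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")

>>> inputs = tokenizer("JetMoe is a sparsely activated language model that", return_tensors="pt").to(model.device)
>>> output_ids = model.generate(**inputs, max_new_tokens=30)
>>> print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```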
## JetMoeConfig
[[autodoc]] JetMoeConfig
## JetMoeModel
[[autodoc]] JetMoeModel
- forward
## JetMoeForCausalLM
[[autodoc]] JetMoeForCausalLM
- forward
## JetMoeForSequenceClassification
[[autodoc]] JetMoeForSequenceClassification
- forward
@@ -50,6 +50,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj#transformers.GPTJModel)
* [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
@@ -198,6 +199,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
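Since the lists above now include JetMoe for both backends, switching attention implementations is a single argument at load time. A minimal sketch, assuming the `jetmoe/jetmoe-8b` checkpoint, a supported GPU, and that `flash-attn` is installed for the first variant:

```python
import torch
from transformers import AutoModelForCausalLM

# FlashAttention-2 backend (requires the flash-attn package)
model = AutoModelForCausalLM.from_pretrained(
    "jetmoe/jetmoe-8b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

# PyTorch scaled_dot_product_attention (SDPA) backend
model = AutoModelForCausalLM.from_pretrained(
    "jetmoe/jetmoe-8b",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
```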
@@ -448,6 +448,7 @@ _import_structure = {
"InstructBlipVisionConfig",
],
"models.jamba": ["JambaConfig"],
"models.jetmoe": ["JetMoeConfig"],
"models.jukebox": [
"JukeboxConfig",
"JukeboxPriorConfig",
@@ -2202,6 +2203,14 @@ else:
"JambaPreTrainedModel",
]
)
_import_structure["models.jetmoe"].extend(
[
"JetMoeForCausalLM",
"JetMoeForSequenceClassification",
"JetMoeModel",
"JetMoePreTrainedModel",
]
)
_import_structure["models.jukebox"].extend(
[
"JukeboxModel",
@@ -4973,6 +4982,7 @@ if TYPE_CHECKING:
InstructBlipVisionConfig,
)
from .models.jamba import JambaConfig
from .models.jetmoe import JetMoeConfig
from .models.jukebox import (
JukeboxConfig,
JukeboxPriorConfig,
@@ -6591,6 +6601,12 @@ if TYPE_CHECKING:
JambaModel,
JambaPreTrainedModel,
)
from .models.jetmoe import (
JetMoeForCausalLM,
JetMoeForSequenceClassification,
JetMoeModel,
JetMoePreTrainedModel,
)
from .models.jukebox import (
JukeboxModel,
JukeboxPreTrainedModel,
@@ -117,6 +117,7 @@ from . import (
informer,
instructblip,
jamba,
jetmoe,
jukebox,
kosmos2,
layoutlm,
@@ -128,6 +128,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("informer", "InformerConfig"),
("instructblip", "InstructBlipConfig"),
("jamba", "JambaConfig"),
("jetmoe", "JetMoeConfig"),
("jukebox", "JukeboxConfig"),
("kosmos-2", "Kosmos2Config"),
("layoutlm", "LayoutLMConfig"),
@@ -399,6 +400,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("informer", "Informer"),
("instructblip", "InstructBLIP"),
("jamba", "Jamba"),
("jetmoe", "JetMoe"),
("jukebox", "Jukebox"),
("kosmos-2", "KOSMOS-2"),
("layoutlm", "LayoutLM"),
@@ -125,6 +125,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
("jamba", "JambaModel"),
("jetmoe", "JetMoeModel"),
("jukebox", "JukeboxModel"),
("kosmos-2", "Kosmos2Model"),
("layoutlm", "LayoutLMModel"),
@@ -458,6 +459,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
("gptj", "GPTJForCausalLM"),
("jamba", "JambaForCausalLM"),
("jetmoe", "JetMoeForCausalLM"),
("llama", "LlamaForCausalLM"),
("mamba", "MambaForCausalLM"),
("marian", "MarianForCausalLM"),
@@ -860,6 +862,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("gptj", "GPTJForSequenceClassification"),
("ibert", "IBertForSequenceClassification"),
("jamba", "JambaForSequenceClassification"),
("jetmoe", "JetMoeForSequenceClassification"),
("layoutlm", "LayoutLMForSequenceClassification"),
("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
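As an illustration of what these mapping entries enable, here is a small sketch that builds a JetMoe model purely from its registered name; the tiny override values are arbitrary and chosen only to keep the example light:

```python
from transformers import AutoConfig, AutoModelForCausalLM

# "jetmoe" is now resolvable by the auto classes.
config = AutoConfig.for_model(
    "jetmoe",
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_key_value_heads=4,
    kv_channels=8,
    num_local_experts=4,
    num_experts_per_tok=2,
)
model = AutoModelForCausalLM.from_config(config)
print(type(config).__name__, type(model).__name__)  # JetMoeConfig JetMoeForCausalLM
```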
@@ -211,6 +211,13 @@ else:
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"jetmoe",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
("jukebox", ("JukeboxTokenizer", None)),
(
"kosmos-2",
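Per the mapping above, JetMoe reuses the Llama tokenizer classes rather than adding its own. A minimal sketch, assuming the `jetmoe/jetmoe-8b` repository hosts the tokenizer files:

```python
from transformers import AutoTokenizer

# Resolves to LlamaTokenizerFast (or LlamaTokenizer if the `tokenizers` backend is unavailable).
tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b")
print(tokenizer("JetMoe routes each token to a few experts.")["input_ids"][:8])
```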
# Copyright 2024 JetMoe AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_jetmoe": ["JetMoeConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_jetmoe"] = [
"JetMoeForCausalLM",
"JetMoeModel",
"JetMoePreTrainedModel",
"JetMoeForSequenceClassification",
]
if TYPE_CHECKING:
from .configuration_jetmoe import JetMoeConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_jetmoe import (
JetMoeForCausalLM,
JetMoeForSequenceClassification,
JetMoeModel,
JetMoePreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2024 JetMoe AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JetMoe model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class JetMoeConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`JetMoeModel`]. It is used to instantiate a
JetMoe model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a configuration similar to that of the JetMoe-4B.
[jetmoe/jetmoe-8b](https://huggingface.co/jetmoe/jetmoe-8b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the JetMoe model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`JetMoeModel`]
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each key and value in the Transformer encoder.
kv_channels (`int`, *optional*, defaults to 128):
Defines the number of channels for the key and value tensors.
intermediate_size (`int`, *optional*, defaults to 5632):
Dimension of the MLP representations.
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with. JetMoe's attention allows sequences of
up to 4096 tokens.
activation_function (`str`, *optional*, defaults to `"silu"`):
Defines the activation function for the MLP experts.
num_local_experts (`int`, *optional*, defaults to 8):
Defines the number of experts in the MoE and MoA.
num_experts_per_tok (`int`, *optional*, defaults to 2):
The number of experts to route per token, for both the MoE and the MoA.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabling this will also
allow the model to output the auxiliary loss.
aux_loss_coef (`float`, *optional*, defaults to 0.01):
The coefficient for the auxiliary loss.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
initializer_range (`float`, *optional*, defaults to 0.01):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import JetMoeModel, JetMoeConfig
>>> # Initializing a JetMoe 4B style configuration
>>> configuration = JetMoeConfig()
>>> # Initializing a model from the JetMoe 4B style configuration
>>> model = JetMoeModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "jetmoe"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=2048,
num_hidden_layers=12,
num_key_value_heads=16,
kv_channels=128,
intermediate_size=5632,
max_position_embeddings=4096,
activation_function="silu",
num_local_experts=8,
num_experts_per_tok=2,
output_router_logits=False,
aux_loss_coef=0.01,
use_cache=True,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
rms_norm_eps=1e-6,
initializer_range=0.01,
attention_dropout=0.0,
**kwargs,
):
if num_experts_per_tok > num_local_experts:
raise ValueError("`num_experts_per_tok` must be less than or equal to `num_local_experts`")
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
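# Mixture of Attention heads (MoA): each token is routed to `num_experts_per_tok` attention experts,
# and each expert applies `num_key_value_heads` heads, so the effective head count is their product.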
self.num_attention_heads = num_key_value_heads * num_experts_per_tok
self.num_key_value_heads = num_key_value_heads
self.kv_channels = kv_channels
self.intermediate_size = intermediate_size
self.max_position_embeddings = max_position_embeddings
self.activation_function = activation_function
self.num_local_experts = num_local_experts
self.num_experts_per_tok = num_experts_per_tok
self.output_router_logits = output_router_logits
self.aux_loss_coef = aux_loss_coef
self.use_cache = use_cache
self.initializer_range = initializer_range
self.attention_dropout = attention_dropout
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.rope_theta = rope_theta
self.rms_norm_eps = rms_norm_eps
super().__init__(
bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
)
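For a quick check of the derived attention-head count and the routing constraint enforced above, a small sketch using only the default arguments documented in this file:

```python
>>> from transformers import JetMoeConfig

>>> config = JetMoeConfig()
>>> # 16 key/value heads * 2 experts routed per token
>>> config.num_attention_heads
32
```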
@@ -4334,6 +4334,34 @@ class JambaPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class JetMoeForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class JetMoeForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class JetMoeModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class JetMoePreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class JukeboxModel(metaclass=DummyObject):
_backends = ["torch"]