Unverified Commit 005b957f authored by Abhi Venigalla, committed by GitHub

Add DBRX Model (#29921)



* wip

* fix __init__.py

* add docs

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* address comments 1

* work on make fixup

* pass configs down

* add sdpa attention

* remove DbrxBlock

* add to configuration_auto

* docstring now passes formatting test

* fix style

* update READMEs

* add dbrx to modeling_auto

* make fix-copies generated this

* add DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP

* config docstring passes formatting test

* rename moe_loss_weight to router_aux_loss_coef

* add to flash-attn documentation

* fix model-path in tests

* Explicitly make `"silu"` the default `ffn_act_fn`
Co-authored-by: Wing Lian <wing.lian@gmail.com>

* default to using router_aux_loss_coef over ffn_config[moe_loss_weight]

* fix _flash_attn_uses_top_left_mask and is_causal

* fix tests path

* don't use token type IDs

* follow Llama and remove token_type_ids from test

* init ConfigTester differently so tests pass

* remove multiple choice test

* remove question + answer test

* remove sequence classification test

* remove token classification test

* copy Llama tests and remove token_type_ids from test inputs

* do not test pruning or headmasking; style code

* add _tied_weights_keys parameter to pass test

* add type hints

* fix type check

* update config tester

* remove masked_lm test

* remove encoder tests

* initialize DbrxModelTester with correct params

* style

* torch_dtype does not rely on torch

* run make fixup, fix-copies

* use https://huggingface.co/v2ray/dbrx-base-fixed/blob/main/modeling_dbrx.py



* add copyright info

* fix imports and DbrxRotaryEmbedding

* update DbrxModel docstring

* use copies

* change model path in docstring

* use config in DbrxFFN

* fix flashattention2, sdpaattention

* input config to DbrxAttention, DbrxNormAttentionNorm

* more fixes

* fix

* fix again!

* add informative comment

* fix ruff?

* remove print statement + style

* change doc-test

* fix doc-test

* fix docstring

* delete commented out text

* make defaults match dbrx-instruct

* replace `router_aux_loss_coef` with `moe_loss_weight`

* is_decoder=True

* remove is_decoder from configtester

* implement sdpa properly

* make is_decoder pass tests

* start on the GenerationTesterMixin tests

* add dbrx to sdpa documentation

* skip weight typing test

* style

* initialize smaller model
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Add DBRX to toctree

* skip test_new_cache_format

* make config defaults smaller again

* add pad_token_id

* remove pad_token_id from config

* Remove all references to DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP

* Update src/transformers/models/dbrx/__init__.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/dbrx/modeling_dbrx.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/model_doc/dbrx.md
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/models/dbrx/configuration_dbrx.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/model_doc/dbrx.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix typo

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update docs, fix configuration_auto.py

* address pr comments

* remove is_decoder flag

* slice

* fix requires grad

* remove grad

* disconnect differently

* remove grad

* enable grads

* patch

* detach expert

* nissan al ghaib

* Update modeling_dbrx.py

* Update src/transformers/models/dbrx/modeling_dbrx.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* replace "Gemma" with "Dbrx"

* remove # type: ignore

* don't hardcode vocab_size

* remove ToDo

* Re-add removed idefics2 line

* Update test to use tiny-random!

* Remove TODO

* Remove one more case of loading the entire dbrx-instruct in the tests

* Update src/transformers/models/dbrx/modeling_dbrx.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* address some comments

* small model

* add dbrx to tokenization_auto

* More docstrings with add_start_docstrings

* Dbrx for now

* add PipelineTesterMixin

* Update src/transformers/models/dbrx/configuration_dbrx.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove flash-attn2 import error

* fix docstring
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add usage example

* put on one line
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix ffn_act_fn
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* change "dbrx" to "DBRX" for display purposes.

* fix __init__.py?

* fix __init__.py

* fix README

* return the aux_loss

* remove extra spaces

* fix configuration_auto.py

* fix format in tokenization_auto

* remove new line

* add more usage examples

---------
Co-authored-by: Abhi Venigalla <abhi.venigalla@databricks.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Eitan Turok <eitan.turok@databricks.com>
Co-authored-by: Eitan Turok <150733043+eitanturok@users.noreply.github.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Eitan Turok <eitanturok@gmail.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 63c5e27e
@@ -77,6 +77,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("data2vec-audio", "Data2VecAudioConfig"),
("data2vec-text", "Data2VecTextConfig"),
("data2vec-vision", "Data2VecVisionConfig"),
("dbrx", "DbrxConfig"),
("deberta", "DebertaConfig"),
("deberta-v2", "DebertaV2Config"),
("decision_transformer", "DecisionTransformerConfig"),
@@ -340,6 +341,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("data2vec-audio", "Data2VecAudio"),
("data2vec-text", "Data2VecText"),
("data2vec-vision", "Data2VecVision"),
("dbrx", "DBRX"),
("deberta", "DeBERTa"),
("deberta-v2", "DeBERTa-v2"),
("decision_transformer", "Decision Transformer"),
@@ -75,6 +75,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("data2vec-audio", "Data2VecAudioModel"),
("data2vec-text", "Data2VecTextModel"),
("data2vec-vision", "Data2VecVisionModel"),
("dbrx", "DbrxModel"),
("deberta", "DebertaModel"),
("deberta-v2", "DebertaV2Model"),
("decision_transformer", "DecisionTransformerModel"),
@@ -439,6 +440,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("cpmant", "CpmAntForCausalLM"),
("ctrl", "CTRLLMHeadModel"),
("data2vec-text", "Data2VecTextForCausalLM"),
("dbrx", "DbrxForCausalLM"),
("electra", "ElectraForCausalLM"),
("ernie", "ErnieForCausalLM"),
("falcon", "FalconForCausalLM"),
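With these mappings in place, the auto classes resolve `"dbrx"` end to end. A minimal sketch, assuming a transformers build that includes this commit (the tiny sizes are arbitrary, chosen only to keep the model cheap to build):

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Build a tiny random-weight DBRX model purely through the auto API.
config = AutoConfig.for_model("dbrx", n_layers=2, d_model=256, n_heads=8, vocab_size=128)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # DbrxForCausalLM
```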
@@ -150,6 +150,7 @@ else:
("ctrl", ("CTRLTokenizer", None)),
("data2vec-audio", ("Wav2Vec2CTCTokenizer", None)),
("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("dbrx", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
(
"deberta-v2",
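Since DBRX maps onto the GPT-2 tokenizer classes, `AutoTokenizer` resolves it directly. A small sketch; `databricks/dbrx-instruct` is the checkpoint referenced elsewhere in this PR and may require accepting the license / authenticating on the Hub:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct")
ids = tokenizer("Hello, DBRX!").input_ids
print(tokenizer.decode(ids))  # Hello, DBRX!
```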
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_dbrx": ["DbrxConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_dbrx"] = [
"DbrxForCausalLM",
"DbrxModel",
"DbrxPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_dbrx import DbrxConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_dbrx import DbrxForCausalLM, DbrxModel, DbrxPreTrainedModel
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
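The lazy-module pattern above means importing the config never pulls in torch; a quick sketch of the deferred loading:

```python
from transformers.models import dbrx

# The package object is a _LazyModule: attributes are imported on first access.
print(type(dbrx).__name__)  # _LazyModule

config = dbrx.DbrxConfig()  # resolves configuration_dbrx lazily, no torch needed
print(config.model_type)  # dbrx
```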
# coding=utf-8
# Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" DBRX model configuration """
from typing import Any, Optional
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class DbrxAttentionConfig(PretrainedConfig):
"""Configuration class for Dbrx Attention.
[`DbrxAttention`] class. It is used to instantiate attention layers
according to the specified arguments, defining the layers architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
attn_pdrop (`float`, *optional*, defaults to 0.0):
The dropout probability for the attention layers.
clip_qkv (`float`, *optional*):
If set, clip the queries, keys, and values in the attention layer to this value.
kv_n_heads (`Optional[int]`, defaults to 1): For grouped_query_attention only, allow user to specify number of kv heads.
rope_theta (`float`, defaults to 10000.0): The base frequency for rope.
"""
def __init__(
self,
attn_pdrop: float = 0.0,
clip_qkv: Optional[float] = None,
kv_n_heads: int = 1,
rope_theta: float = 10000.0,
**kwargs: Any,
):
super().__init__(**kwargs)
self.attn_pdrop = attn_pdrop
self.clip_qkv = clip_qkv
self.kv_n_heads = kv_n_heads
self.rope_theta = rope_theta
for k in ["model_type"]:
if k in kwargs:
kwargs.pop(k)
if len(kwargs) != 0:
raise ValueError(f"Found unknown {kwargs=}")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if config_dict.get("model_type") == "dbrx":
config_dict = config_dict["attn_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
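# A usage sketch for this sub-config (illustrative comment, not executed as part of the module):
#
#     attn_config = DbrxAttentionConfig(kv_n_heads=8, clip_qkv=8.0, rope_theta=500000.0)
#
#     # The `from_pretrained` override above can also extract the nested `attn_config`
#     # from a full DBRX checkpoint config (the public checkpoint may be gated):
#     attn_config = DbrxAttentionConfig.from_pretrained("databricks/dbrx-instruct")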
class DbrxFFNConfig(PretrainedConfig):
"""Configuration class for Dbrx FFN.
[`DbrxFFN`] class. It is used to instantiate feedforward layers according to
the specified arguments, defining the layers architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
ffn_act_fn (`dict`, *optional*, defaults to `None`): A dict specifying activation function for the FFN.
The dict should have a key 'name' with the value being the name of the activation function along with
any additional keyword arguments. If `None`, then set to `{"name": "silu"}`.
ffn_hidden_size (`int`, defaults to 3584): The hidden size of the feedforward network.
moe_num_experts (`int`, defaults to 4): The number of experts in the mixture of experts layer.
moe_top_k (`int`, defaults to 1): The number of experts to use in the mixture of experts layer.
moe_jitter_eps (`float`, *optional*, defaults to `None`): If not `None`, the jitter epsilon for the mixture of experts layer.
moe_loss_weight (`float`, defaults to 0.01): The loss weight for the mixture of experts layer.
moe_normalize_expert_weights (`float`, *optional*, defaults to 1.0): The normalization factor for the expert weights.
"""
def __init__(
self,
ffn_act_fn: Optional[dict] = None,
ffn_hidden_size: int = 3584,
moe_num_experts: int = 4,
moe_top_k: int = 1,
moe_jitter_eps: Optional[float] = None,
moe_loss_weight: float = 0.01,
moe_normalize_expert_weights: Optional[float] = 1.0,
**kwargs: Any,
):
super().__init__()
if ffn_act_fn is None:
ffn_act_fn = {"name": "silu"}
self.ffn_act_fn = ffn_act_fn
self.ffn_hidden_size = ffn_hidden_size
self.moe_num_experts = moe_num_experts
self.moe_top_k = moe_top_k
self.moe_jitter_eps = moe_jitter_eps
self.moe_loss_weight = moe_loss_weight
self.moe_normalize_expert_weights = moe_normalize_expert_weights
for k in ["model_type"]:
if k in kwargs:
kwargs.pop(k)
if len(kwargs) != 0:
raise ValueError(f"Found unknown {kwargs=}")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if config_dict.get("model_type") == "dbrx":
config_dict = config_dict["ffn_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
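# A usage sketch for this sub-config (illustrative comment; "gelu" is an arbitrary choice,
# the default activation is {"name": "silu"}):
#
#     ffn_config = DbrxFFNConfig(
#         ffn_act_fn={"name": "gelu"},
#         ffn_hidden_size=1024,
#         moe_num_experts=8,
#         moe_top_k=2,
#     )
#
#     # Like the attention sub-config, it can be pulled out of a full DBRX config:
#     ffn_config = DbrxFFNConfig.from_pretrained("databricks/dbrx-instruct")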
class DbrxConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DbrxModel`]. It is used to instantiate a Dbrx model according to the
specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a different configuration from that of the [databricks/dbrx-instruct](https://huggingface.co/databricks/dbrx-instruct) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
d_model (`int`, *optional*, defaults to 2048):
Dimensionality of the embeddings and hidden states.
n_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
n_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer decoder.
max_seq_len (`int`, *optional*, defaults to 2048):
The maximum sequence length of the model.
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
the `input_ids` passed when calling [`DbrxModel`].
resid_pdrop (`float`, *optional*, defaults to 0.0):
The dropout probability applied to the attention output before combining with residual.
emb_pdrop (`float`, *optional*, defaults to 0.0):
The dropout probability for the embedding layer.
attn_config (`dict`, *optional*):
A dictionary used to configure the model's attention module.
ffn_config (`dict`, *optional*):
A dictionary used to configure the model's FFN module.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabling this will also
allow the model to output the auxiliary loss.
Example:
```python
>>> from transformers import DbrxConfig, DbrxModel
>>> # Initializing a Dbrx configuration
>>> configuration = DbrxConfig(n_layers=2, d_model=256, n_heads=8, vocab_size=128)
>>> # Initializing a model (with random weights) from the configuration
>>> model = DbrxModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "dbrx"
attribute_map = {
"num_attention_heads": "n_heads",
"hidden_size": "d_model",
"num_hidden_layers": "n_layers",
"max_position_embeddings": "max_seq_len",
}
def __init__(
self,
d_model: int = 2048,
n_heads: int = 16,
n_layers: int = 24,
max_seq_len: int = 2048,
vocab_size: int = 32000,
resid_pdrop: float = 0.0,
emb_pdrop: float = 0.0,
attn_config: Optional[DbrxAttentionConfig] = None,
ffn_config: Optional[DbrxFFNConfig] = None,
use_cache: bool = True,
initializer_range: float = 0.02,
output_router_logits: bool = False,
**kwargs: Any,
):
if attn_config is None:
self.attn_config = DbrxAttentionConfig()
elif isinstance(attn_config, dict):
self.attn_config = DbrxAttentionConfig(**attn_config)
else:
self.attn_config = attn_config
if ffn_config is None:
self.ffn_config = DbrxFFNConfig()
elif isinstance(ffn_config, dict):
self.ffn_config = DbrxFFNConfig(**ffn_config)
else:
self.ffn_config = ffn_config
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
self.resid_pdrop = resid_pdrop
self.emb_pdrop = emb_pdrop
self.use_cache = use_cache
self.initializer_range = initializer_range
self.output_router_logits = output_router_logits
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
if tie_word_embeddings:
raise ValueError("tie_word_embeddings is not supported for DBRX models.")
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
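A short sketch of the two behaviors wired up above, the `attribute_map` aliases and the `tie_word_embeddings` guard (sizes are arbitrary):

```python
from transformers import DbrxConfig

config = DbrxConfig(n_layers=2, d_model=256, n_heads=8)

# `attribute_map` exposes the DBRX-native names under the standard Transformers names.
assert config.hidden_size == config.d_model == 256
assert config.num_hidden_layers == config.n_layers == 2

# Weight tying is rejected outright.
try:
    DbrxConfig(tie_word_embeddings=True)
except ValueError as err:
    print(err)  # tie_word_embeddings is not supported for DBRX models.
```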
@@ -2457,6 +2457,27 @@ class Data2VecVisionPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class DbrxForCausalLM(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class DbrxModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class DbrxPreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
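For context, a rough sketch of how these dummy objects behave, based on the general `DummyObject`/`requires_backends` pattern in the library (not specific to this diff):

```python
from transformers.utils import is_torch_available

if not is_torch_available():
    # In a torch-less environment, the DBRX model symbols resolve to these dummies,
    # and instantiating one raises an ImportError with installation instructions.
    from transformers.utils.dummy_pt_objects import DbrxModel

    DbrxModel()  # raises ImportError telling the user to install torch
```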
@@ -25,7 +25,7 @@ Jump to the [Add new model like section](#add-new-model-like-command) to learn h
## Cookiecutter Templates
Using the `cookiecutter` utility requires having all of the `dev` dependencies installed. Let's first clone the
repository and install it in our environment:
```shell script
@@ -53,20 +53,20 @@ This should launch the `cookiecutter` package which should prompt you to fill in
The `modelname` should be cased according to the plain text casing, i.e., BERT, RoBERTa, DeBERTa.
```
modelname [<ModelNAME>]:
uppercase_modelname [<MODEL_NAME>]:
lowercase_modelname [<model_name>]:
camelcase_modelname [<ModelName>]:
```
Fill in the `authors` with your team members:
```
authors [The HuggingFace Team]:
```
The checkpoint identifier is the checkpoint that will be used in the examples across the files. Put the name you wish,
as it will appear on the modelhub. Do not forget to include the organisation.
```
checkpoint_identifier [organisation/<model_name>-base-cased]:
```
The tokenizer should be based on BERT if it behaves exactly like the BERT tokenizer, or be a standalone one otherwise.
@@ -74,19 +74,19 @@ The tokenizer should either be based on BERT if it behaves exactly like the BERT
Select tokenizer_type:
1 - Based on BERT
2 - Standalone
Choose from 1, 2 [1]:
```
<!---
Choose if your model is an encoder-decoder, or an encoder-only architecture.
If your model is an encoder-only architecture, the generated architecture will be based on the BERT model.
If your model is an encoder-decoder architecture, the generated architecture will be based on the BART model. You can,
of course, edit the files once the generation is complete.
```
Select is_encoder_decoder_model:
1 - True
2 - False
Choose from 1, 2 [1]:
```
-->
@@ -97,8 +97,8 @@ src/transformers/models/<model_name>/configuration_<model_name>.py
src/transformers/models/<model_name>/modeling_<model_name>.py
src/transformers/models/<model_name>/modeling_tf_<model_name>.py
src/transformers/models/<model_name>/tokenization_<model_name>.py
tests/models/<model_name>/test_modeling_<model_name>.py
tests/models/<model_name>/test_modeling_tf_<model_name>.py
```
You can run the tests to ensure that they all pass:
@@ -107,9 +107,9 @@ You can run the tests to ensure that they all pass:
python -m pytest ./tests/test_*<model_name>*.py
```
Feel free to modify each file to mimic the behavior of your model.
⚠️ You should be careful about the classes preceded by the following line:
```python
# Copied from transformers.[...]
@@ -119,8 +119,8 @@ This line ensures that the copy does not diverge from the source. If it *should*
is different, this line needs to be deleted. If you don't delete this line and run `make fix-copies`,
your changes will be overwritten.
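For instance, a marker of this shape ties a local definition to its upstream source (the class and source path below are hypothetical, shown only to illustrate the syntax):

```python
import torch.nn as nn


# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Dbrx
class DbrxRotaryEmbedding(nn.Module):
    ...  # body kept in sync with the source by `make fix-copies`
```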
Once you have edited the files to fit your architecture, simply re-run the tests (and edit them if a change
is needed!) afterwards to make sure everything works as expected.
Once the files are generated and you are happy with your changes, here's a checklist to ensure that your contribution
will be merged quickly:
@@ -251,7 +251,7 @@ Once you're done, you can run the tests to ensure that they all pass:
python -m pytest ./tests/test_*<model_name>*.py
```
⚠️ You should be careful about the classes preceded by the following line:
```python
# Copied from transformers.[...]
@@ -261,8 +261,8 @@ This line ensures that the copy does not diverge from the source. If it *should*
is different, this line needs to be deleted. If you don't delete this line and run `make fix-copies`,
your changes will be overwritten.
Once you have edited the files to fit your architecture, simply re-run the tests (and edit them if a change
is needed!) afterwards to make sure everything works as expected.
Once the files are generated and you are happy with your changes, here's a checklist to ensure that your contribution
will be merged quickly:
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch DBRX model. """
import unittest
from parameterized import parameterized
from transformers import DbrxConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import DbrxForCausalLM, DbrxModel
class DbrxModelTester:
def __init__(
self,
parent,
hidden_size=32,
ffn_hidden_size=32,
num_attention_heads=4,
kv_n_heads=4,
num_hidden_layers=5,
max_position_embeddings=512,
type_vocab_size=16,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
use_cache=True,
type_sequence_label_size=2,
num_labels=3,
num_choices=4,
scope=None,
clip_qkv=8,
rope_theta=500000,
attn_config_model_type="",
emb_pdrop=0.0,
moe_jitter_eps=0,
moe_loss_weight=0.05,
moe_num_experts=16,
moe_top_k=4,
ffn_config_model_type="",
ffn_act_fn_name="gelu",
initializer_range=0.02,
output_router_logits=False,
resid_pdrop=0.0,
tie_word_embeddings=False,
torch_dtype="bfloat16",
vocab_size=99,
is_decoder=True,
pad_token_id=0,
):
# Parameters unique to testing
self.batch_size = batch_size
self.seq_length = seq_length
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.parent = parent
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
# attn_config params
self.clip_qkv = clip_qkv
self.kv_n_heads = kv_n_heads
self.rope_theta = rope_theta
self.attn_config_model_type = attn_config_model_type
# ffn_config params
self.ffn_hidden_size = ffn_hidden_size
self.moe_jitter_eps = moe_jitter_eps
self.moe_loss_weight = moe_loss_weight
self.moe_num_experts = moe_num_experts
self.moe_top_k = moe_top_k
self.ffn_config_model_type = ffn_config_model_type
self.ffn_act_fn_name = ffn_act_fn_name
# Other model params
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.max_position_embeddings = max_position_embeddings
self.vocab_size = vocab_size
self.use_cache = use_cache
self.initializer_range = initializer_range
self.emb_pdrop = emb_pdrop
self.output_router_logits = output_router_logits
self.resid_pdrop = resid_pdrop
self.tie_word_embeddings = tie_word_embeddings
self.torch_dtype = torch_dtype
self.is_decoder = is_decoder
self.pad_token_id = pad_token_id
# Make the dictionaries
self.ffn_config = {
"ffn_hidden_size": self.ffn_hidden_size,
"moe_jitter_eps": self.moe_jitter_eps,
"moe_loss_weight": self.moe_loss_weight,
"moe_num_experts": self.moe_num_experts,
"moe_top_k": self.moe_top_k,
"model_type": self.ffn_config_model_type,
"ffn_act_fn": {"name": self.ffn_act_fn_name},
}
self.attn_config = {
"clip_qkv": self.clip_qkv,
"kv_n_heads": self.kv_n_heads,
"model_type": self.attn_config_model_type,
"rope_theta": self.rope_theta,
}
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
# Behind the scenes, `DbrxConfig` maps the parameters `hidden_size`, `num_hidden_layers`,
# `num_attention_heads`, `max_position_embeddings` to the parameters `d_model`, `n_layers`,
# `n_heads`, `max_seq_len` respectively. We use the first group of parameters because
# other tests expect every model to have these parameters with these specific names.
config = DbrxConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size, # mapped to `d_model`
num_hidden_layers=self.num_hidden_layers, # mapped to `n_layers`
num_attention_heads=self.num_attention_heads, # mapped to `n_heads`
max_position_embeddings=self.max_position_embeddings, # mapped to `max_seq_len`
attn_config=self.attn_config,
ffn_config=self.ffn_config,
resid_pdrop=self.resid_pdrop,
emb_pdrop=self.emb_pdrop,
use_cache=self.use_cache,
initializer_range=self.initializer_range,
output_router_logits=self.output_router_logits,
is_decoder=self.is_decoder,
pad_token_id=self.pad_token_id,
)
return config
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Dbrx
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = DbrxModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Dbrx
def create_and_check_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = DbrxModel(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
)
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Dbrx
def create_and_check_for_causal_lm(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = DbrxForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.is_decoder = True
config.add_cross_attention = True
model = DbrxForCausalLM(config=config)
model.to(torch_device)
model.eval()
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical next tokens and extend next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append the new tokens to input_ids and the new mask to the attention mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->Dbrx
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (DbrxModel, DbrxForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (DbrxForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = {"text-generation": DbrxForCausalLM} if is_torch_available() else {}
test_headmasking = False
test_pruning = False
def setUp(self):
self.model_tester = DbrxModelTester(self)
self.config_tester = ConfigTester(self, config_class=DbrxConfig, d_model=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for embedding_type in ["absolute", "relative_key", "relative_key_query"]:
    config_and_inputs[0].position_embedding_type = embedding_type
self.model_tester.create_and_check_model(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
model_name = "eitanturok/dbrx-tiny"
model = DbrxModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@unittest.skip("Dbrx models have weight tying disabled.")
def test_tied_weights_keys(self):
pass
@unittest.skip("TODO @gante fix this for Llama")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@require_torch
class DbrxModelIntegrationTest(unittest.TestCase):
@slow
def test_tiny_model_logits(self):
model = DbrxForCausalLM.from_pretrained("Rocketknight1/dbrx-tiny-random")
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0]
vocab_size = model.vocab_size
expected_shape = torch.Size((1, 6, vocab_size))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[
[
[-1.6300e-04, 5.0118e-04, 2.5437e-04],
[2.0422e-05, 2.7210e-04, -1.5125e-04],
[-1.5105e-04, 4.6879e-04, 3.3309e-04],
]
]
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
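As a companion to the integration check above, a minimal sketch of exercising the same tiny random checkpoint outside the test harness (it downloads the weights from the Hub):

```python
import torch

from transformers import DbrxForCausalLM

model = DbrxForCausalLM.from_pretrained("Rocketknight1/dbrx-tiny-random")
model.eval()
with torch.no_grad():
    logits = model(torch.tensor([[0, 1, 2, 3, 4, 5]])).logits
print(logits.shape)  # torch.Size([1, 6, vocab_size])
```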