Unverified Commit 005b957f authored by Abhi Venigalla, committed by GitHub

Add DBRX Model (#29921)



* wip

* fix __init__.py

* add docs

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* address comments 1

* work on make fixup

* pass configs down

* add sdpa attention

* remove DbrxBlock

* add to configuration_auto

* docstring now passes formatting test

* fix style

* update READMEs

* add dbrx to modeling_auto

* make fix-copies generated this

* add DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP

* config docstring passes formatting test

* rename moe_loss_weight to router_aux_loss_coef

* add to flash-attn documentation

* fix model-path in tests

* Explicitly make `"silu"` the default `ffn_act_fn`
Co-authored-by: Wing Lian <wing.lian@gmail.com>

* default to using router_aux_loss_coef over ffn_config[moe_loss_weight]

* fix _flash_attn_uses_top_left_mask and is_causal

* fix tests path

* don't use token type IDs

* follow Llama and remove token_type_ids from test

* init ConfigTester differently so tests pass

* remove multiple choice test

* remove question + answer test

* remove sequence classification test

* remove token classification test

* copy Llama tests and remove token_type_ids from test inputs

* do not test pruning or headmasking; style code

* add _tied_weights_keys parameter to pass test

* add type hints

* fix type check

* update config tester

* remove masked_lm test

* remove encoder tests

* initialize DbrxModelTester with correct params

* style

* torch_dtype does not rely on torch

* run make fixup, fix-copies

* use https://huggingface.co/v2ray/dbrx-base-fixed/blob/main/modeling_dbrx.py



* add copyright info

* fix imports and DbrxRotaryEmbedding

* update DbrxModel docstring

* use copies

* change model path in docstring

* use config in DbrxFFN

* fix flashattention2, sdpaattention

* input config to DbrxAttention, DbrxNormAttentionNorm

* more fixes

* fix

* fix again!

* add informative comment

* fix ruff?

* remove print statement + style

* change doc-test

* fix doc-test

* fix docstring

* delete commented out text

* make defaults match dbrx-instruct

* replace `router_aux_loss_coef` with `moe_loss_weight`

* is_decoder=True

* remove is_decoder from configtester

* implement sdpa properly

* make is_decoder pass tests

* start on the GenerationTesterMixin tests

* add dbrx to sdpa documentation

* skip weight typing test

* style

* initialize smaller model
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Add DBRX to toctree

* skip test_new_cache_format

* make config defaults smaller again

* add pad_token_id

* remove pad_token_id from config

* Remove all references to DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP

* Update src/transformers/models/dbrx/__init__.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/dbrx/modeling_dbrx.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/model_doc/dbrx.md
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/models/dbrx/configuration_dbrx.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/model_doc/dbrx.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix typo

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update docs, fix configuration_auto.py

* address pr comments

* remove is_decoder flag

* slice

* fix requires grad

* remove grad

* disconnect differently

* remove grad

* enable grads

* patch

* detach expert

* nissan al ghaib

* Update modeling_dbrx.py

* Update src/transformers/models/dbrx/modeling_dbrx.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* replace "Gemma" with "Dbrx"

* remove # type: ignore

* don't hardcode vocab_size

* remove ToDo

* Re-add removed idefics2 line

* Update test to use tiny-random!

* Remove TODO

* Remove one more case of loading the entire dbrx-instruct in the tests

* Update src/transformers/models/dbrx/modeling_dbrx.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* address some comments

* small model

* add dbrx to tokenization_auto

* More docstrings with add_start_docstrings

* Dbrx for now

* add PipelineTesterMixin

* Update src/transformers/models/dbrx/configuration_dbrx.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove flash-attn2 import error

* fix docstring
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add usage example

* put on one line
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix ffn_act_fn
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* change "dbrx" to "DBRX" for display purposes.

* fix __init__.py?

* fix __init__.py

* fix README

* return the aux_loss

* remove extra spaces

* fix configuration_auto.py

* fix format in tokenization_auto

* remove new line

* add more usage examples

---------
Co-authored-by: Abhi Venigalla <abhi.venigalla@databricks.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Eitan Turok <eitan.turok@databricks.com>
Co-authored-by: Eitan Turok <150733043+eitanturok@users.noreply.github.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Eitan Turok <eitanturok@gmail.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 63c5e27e
@@ -77,6 +77,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("data2vec-audio", "Data2VecAudioConfig"),
("data2vec-text", "Data2VecTextConfig"),
("data2vec-vision", "Data2VecVisionConfig"),
("dbrx", "DbrxConfig"),
("deberta", "DebertaConfig"),
("deberta-v2", "DebertaV2Config"),
("decision_transformer", "DecisionTransformerConfig"),
@@ -340,6 +341,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("data2vec-audio", "Data2VecAudio"),
("data2vec-text", "Data2VecText"),
("data2vec-vision", "Data2VecVision"),
("dbrx", "DBRX"),
("deberta", "DeBERTa"),
("deberta-v2", "DeBERTa-v2"),
("decision_transformer", "Decision Transformer"),
......
@@ -75,6 +75,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("data2vec-audio", "Data2VecAudioModel"),
("data2vec-text", "Data2VecTextModel"),
("data2vec-vision", "Data2VecVisionModel"),
("dbrx", "DbrxModel"),
("deberta", "DebertaModel"),
("deberta-v2", "DebertaV2Model"),
("decision_transformer", "DecisionTransformerModel"),
@@ -439,6 +440,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("cpmant", "CpmAntForCausalLM"),
("ctrl", "CTRLLMHeadModel"),
("data2vec-text", "Data2VecTextForCausalLM"),
("dbrx", "DbrxForCausalLM"),
("electra", "ElectraForCausalLM"),
("ernie", "ErnieForCausalLM"),
("falcon", "FalconForCausalLM"),
......
@@ -150,6 +150,7 @@ else:
("ctrl", ("CTRLTokenizer", None)),
("data2vec-audio", ("Wav2Vec2CTCTokenizer", None)),
("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("dbrx", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
(
"deberta-v2",
......
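With the auto-class registrations above, DBRX checkpoints load through the generic Auto APIs. A minimal usage sketch follows; the checkpoint name is the one referenced in this PR's docs, and the full databricks/dbrx-instruct model needs multiple high-memory GPUs, so treat this as illustrative rather than a turnkey recipe:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "databricks/dbrx-instruct"  # referenced in the DBRX docs; very large in practice

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

inputs = tokenizer("DBRX is a mixture-of-experts model that", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```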
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_dbrx": ["DbrxConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_dbrx"] = [
"DbrxForCausalLM",
"DbrxModel",
"DbrxPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_dbrx import DbrxConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_dbrx import DbrxForCausalLM, DbrxModel, DbrxPreTrainedModel
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
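The `_LazyModule` wiring above keeps `import transformers.models.dbrx` cheap: submodules are imported only when one of their attributes is first accessed, and the modeling classes are registered only if torch is available. A small sketch of that behavior, assuming a transformers install that includes this PR:

```python
import importlib

# Importing the package does not yet import configuration_dbrx or modeling_dbrx.
dbrx = importlib.import_module("transformers.models.dbrx")

# The first attribute access triggers the real import of configuration_dbrx.
config = dbrx.DbrxConfig(d_model=256, n_heads=8, n_layers=2, vocab_size=128)
print(type(config).__name__, config.d_model)
```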
# coding=utf-8
# Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" DBRX model configuration """
from typing import Any, Optional
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class DbrxAttentionConfig(PretrainedConfig):
"""Configuration class for Dbrx Attention.
[`DbrxAttention`] class. It is used to instantiate attention layers
according to the specified arguments, defining the layers architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
attn_pdrop (`float`, *optional*, defaults to 0.0):
The dropout probability for the attention layers.
clip_qkv (`float`, *optional*):
If set, clip the queries, keys, and values in the attention layer to this value.
kv_n_heads (`int`, *optional*, defaults to 1): For grouped query attention, the number of key/value heads.
rope_theta (`float`, defaults to 10000.0): The base frequency for rotary position embeddings (RoPE).
"""
def __init__(
self,
attn_pdrop: float = 0.0,
clip_qkv: Optional[float] = None,
kv_n_heads: int = 1,
rope_theta: float = 10000.0,
**kwargs: Any,
):
super().__init__(**kwargs)
self.attn_pdrop = attn_pdrop
self.clip_qkv = clip_qkv
self.kv_n_heads = kv_n_heads
self.rope_theta = rope_theta
for k in ["model_type"]:
if k in kwargs:
kwargs.pop(k)
if len(kwargs) != 0:
raise ValueError(f"Found unknown {kwargs=}")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if config_dict.get("model_type") == "dbrx":
config_dict = config_dict["attn_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
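The two DBRX-specific knobs here are `clip_qkv` and `kv_n_heads`. The snippet below is only an illustration of what they control, not the code from `modeling_dbrx.py`: the fused QKV projection is clamped to `[-clip_qkv, clip_qkv]`, and a `kv_n_heads` smaller than the number of query heads gives grouped query attention, where each key/value head is shared by several query heads.

```python
from typing import Optional

import torch

def clamp_qkv(qkv: torch.Tensor, clip_qkv: Optional[float]) -> torch.Tensor:
    # No-op when clip_qkv is None, otherwise clamp to [-clip_qkv, clip_qkv].
    return qkv if clip_qkv is None else qkv.clamp(min=-clip_qkv, max=clip_qkv)

def repeat_kv(kv: torch.Tensor, n_heads: int, kv_n_heads: int) -> torch.Tensor:
    # (batch, kv_n_heads, seq_len, head_dim) -> (batch, n_heads, seq_len, head_dim)
    return kv.repeat_interleave(n_heads // kv_n_heads, dim=1)

qkv = clamp_qkv(torch.randn(1, 4, 3 * 64), clip_qkv=8.0)
keys = repeat_kv(torch.randn(1, 2, 4, 64), n_heads=8, kv_n_heads=2)
print(qkv.abs().max().item() <= 8.0, keys.shape)
```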
class DbrxFFNConfig(PretrainedConfig):
"""Configuration class for Dbrx FFN.
[`DbrxFFN`] class. It is used to instantiate feedforward layers according to
the specified arguments, defining the layers architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
ffn_act_fn (`dict`, *optional*, defaults to `None`): A dict specifying the activation function for the FFN.
The dict should have a key 'name' with the value being the name of the activation function, along with
any additional keyword arguments. If `None`, it is set to `{"name": "silu"}`.
ffn_hidden_size (`int`, defaults to 3584): The hidden size of the feedforward network.
moe_num_experts (`int`, defaults to 4): The number of experts in the mixture of experts layer.
moe_top_k (`int`, defaults to 1): The number of experts each token is routed to in the mixture of experts layer.
moe_jitter_eps (`float`, *optional*, defaults to `None`): If not `None`, the jitter epsilon used in the mixture of experts layer during training.
moe_loss_weight (`float`, defaults to 0.01): The weight of the router auxiliary (load-balancing) loss.
moe_normalize_expert_weights (`float`, *optional*, defaults to 1.0): The normalization factor for the expert weights.
"""
def __init__(
self,
ffn_act_fn: Optional[dict] = None,
ffn_hidden_size: int = 3584,
moe_num_experts: int = 4,
moe_top_k: int = 1,
moe_jitter_eps: Optional[float] = None,
moe_loss_weight: float = 0.01,
moe_normalize_expert_weights: Optional[float] = 1.0,
**kwargs: Any,
):
super().__init__()
if ffn_act_fn is None:
ffn_act_fn = {"name": "silu"}
self.ffn_act_fn = ffn_act_fn
self.ffn_hidden_size = ffn_hidden_size
self.moe_num_experts = moe_num_experts
self.moe_top_k = moe_top_k
self.moe_jitter_eps = moe_jitter_eps
self.moe_loss_weight = moe_loss_weight
self.moe_normalize_expert_weights = moe_normalize_expert_weights
for k in ["model_type"]:
if k in kwargs:
kwargs.pop(k)
if len(kwargs) != 0:
raise ValueError(f"Found unknown {kwargs=}")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if config_dict.get("model_type") == "dbrx":
config_dict = config_dict["ffn_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
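The `moe_*` arguments describe the router of the mixture-of-experts FFN. The snippet below is a generic top-k routing sketch, not the actual DBRX router implementation: a linear router scores `moe_num_experts` experts per token, the top `moe_top_k` are kept, and their weights are renormalized with a p-norm when `moe_normalize_expert_weights` is set.

```python
from typing import Optional

import torch

def route_top_k(hidden: torch.Tensor, router: torch.nn.Linear, moe_top_k: int,
                moe_normalize_expert_weights: Optional[float] = 1.0):
    scores = router(hidden).softmax(dim=-1)            # (num_tokens, moe_num_experts)
    weights, experts = scores.topk(moe_top_k, dim=-1)  # keep the top-k experts per token
    if moe_normalize_expert_weights is not None:
        weights = weights / weights.norm(p=moe_normalize_expert_weights, dim=-1, keepdim=True)
    return weights, experts

router = torch.nn.Linear(16, 4, bias=False)            # hidden size 16, moe_num_experts=4
weights, experts = route_top_k(torch.randn(3, 16), router, moe_top_k=2)
print(weights.shape, experts.shape)                    # torch.Size([3, 2]) torch.Size([3, 2])
```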
class DbrxConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DbrxModel`]. It is used to instantiate a Dbrx model according to the
specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a different configuration from that of the [databricks/dbrx-instruct](https://huggingface.co/databricks/dbrx-instruct) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
d_model (`int`, *optional*, defaults to 2048):
Dimensionality of the embeddings and hidden states.
n_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
n_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer decoder.
max_seq_len (`int`, *optional*, defaults to 2048):
The maximum sequence length of the model.
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
the `input_ids` passed when calling [`DbrxModel`].
resid_pdrop (`float`, *optional*, defaults to 0.0):
The dropout probability applied to the attention output before it is combined with the residual.
emb_pdrop (`float`, *optional*, defaults to 0.0):
The dropout probability for the embedding layer.
attn_config (`dict`, *optional*):
A dictionary used to configure the model's attention module.
ffn_config (`dict`, *optional*):
A dictionary used to configure the model's FFN module.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabling this will also
allow the model to output the auxiliary load-balancing loss.
Example:
```python
>>> from transformers import DbrxConfig, DbrxModel
>>> # Initializing a Dbrx configuration
>>> configuration = DbrxConfig(n_layers=2, d_model=256, n_heads=8, vocab_size=128)
>>> # Initializing a model (with random weights) from the configuration
>>> model = DbrxModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "dbrx"
attribute_map = {
"num_attention_heads": "n_heads",
"hidden_size": "d_model",
"num_hidden_layers": "n_layers",
"max_position_embeddings": "max_seq_len",
}
def __init__(
self,
d_model: int = 2048,
n_heads: int = 16,
n_layers: int = 24,
max_seq_len: int = 2048,
vocab_size: int = 32000,
resid_pdrop: float = 0.0,
emb_pdrop: float = 0.0,
attn_config: Optional[DbrxAttentionConfig] = None,
ffn_config: Optional[DbrxFFNConfig] = None,
use_cache: bool = True,
initializer_range: float = 0.02,
output_router_logits: bool = False,
**kwargs: Any,
):
if attn_config is None:
self.attn_config = DbrxAttentionConfig()
elif isinstance(attn_config, dict):
self.attn_config = DbrxAttentionConfig(**attn_config)
else:
self.attn_config = attn_config
if ffn_config is None:
self.ffn_config = DbrxFFNConfig()
elif isinstance(ffn_config, dict):
self.ffn_config = DbrxFFNConfig(**ffn_config)
else:
self.ffn_config = ffn_config
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
self.resid_pdrop = resid_pdrop
self.emb_pdrop = emb_pdrop
self.use_cache = use_cache
self.initializer_range = initializer_range
self.output_router_logits = output_router_logits
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
if tie_word_embeddings:
raise ValueError("tie_word_embeddings is not supported for DBRX models.")
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
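Because `attn_config` and `ffn_config` are accepted either as sub-config objects or as plain dicts, and `attribute_map` aliases the standard Transformers names onto DBRX's own, a configuration can be built and inspected as sketched below (toy sizes, not the dbrx-instruct defaults):

```python
from transformers import DbrxConfig

config = DbrxConfig(
    d_model=256,
    n_heads=8,
    n_layers=2,
    vocab_size=128,
    attn_config={"kv_n_heads": 2, "clip_qkv": 8.0},
    ffn_config={"ffn_hidden_size": 512, "moe_num_experts": 4, "moe_top_k": 2},
)

# attribute_map aliases: hidden_size -> d_model, num_hidden_layers -> n_layers, etc.
print(config.hidden_size == config.d_model)  # True
print(config.attn_config.kv_n_heads)         # 2
print(config.ffn_config.moe_top_k)           # 2
```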
This diff is collapsed.
@@ -2457,6 +2457,27 @@ class Data2VecVisionPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class DbrxForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DbrxModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DbrxPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
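The dummy classes above exist so that `from transformers import DbrxModel` still succeeds when torch is not installed, and the failure is deferred to the point of use with an informative message. A generic sketch of the pattern, mirroring the `DummyObject`/`requires_backends` idea rather than importing the real utilities:

```python
class DummyObject(type):
    # Any non-private attribute access on the class raises instead of returning.
    def __getattribute__(cls, key):
        if key.startswith("_"):
            return super().__getattribute__(key)
        raise ImportError(f"{cls.__name__} requires the PyTorch backend to be installed.")

class DbrxModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        raise ImportError("DbrxModel requires the PyTorch backend to be installed.")

try:
    DbrxModel()  # fails only when the class is actually used
except ImportError as err:
    print(err)
```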
@@ -97,8 +97,8 @@ src/transformers/models/<model_name>/configuration_<model_name>.py
src/transformers/models/<model_name>/modeling_<model_name>.py
src/transformers/models/<model_name>/modeling_tf_<model_name>.py
src/transformers/models/<model_name>/tokenization_<model_name>.py
tests/test_modeling_<model_name>.py → tests/models/<model_name>/test_modeling_<model_name>.py
tests/test_modeling_tf_<model_name>.py → tests/models/<model_name>/test_modeling_tf_<model_name>.py
```
You can run the tests to ensure that they all pass:
......
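With the updated paths above, the new DBRX tests live under `tests/models/dbrx/`. A hedged example of running just that suite from a transformers checkout (requires pytest and the test extras to be installed); invoking pytest on the file from the command line is equivalent:

```python
import pytest

# Run only the DBRX modeling tests, quietly; returns a pytest exit code.
exit_code = pytest.main(["tests/models/dbrx/test_modeling_dbrx.py", "-q"])
print("pytest exit code:", int(exit_code))
```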
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch DBRX model. """
import unittest
from parameterized import parameterized
from transformers import DbrxConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import DbrxForCausalLM, DbrxModel
class DbrxModelTester:
def __init__(
self,
parent,
hidden_size=32,
ffn_hidden_size=32,
num_attention_heads=4,
kv_n_heads=4,
num_hidden_layers=5,
max_position_embeddings=512,
type_vocab_size=16,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
use_cache=True,
type_sequence_label_size=2,
num_labels=3,
num_choices=4,
scope=None,
clip_qkv=8,
rope_theta=500000,
attn_config_model_type="",
emb_pdrop=0.0,
moe_jitter_eps=0,
moe_loss_weight=0.05,
moe_num_experts=16,
moe_top_k=4,
ffn_config_model_type="",
ffn_act_fn_name="gelu",
initializer_range=0.02,
output_router_logits=False,
resid_pdrop=0.0,
tie_word_embeddings=False,
torch_dtype="bfloat16",
vocab_size=99,
is_decoder=True,
pad_token_id=0,
):
# Parameters unique to testing
self.batch_size = batch_size
self.seq_length = seq_length
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.parent = parent
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
# attn_config params
self.clip_qkv = clip_qkv
self.kv_n_heads = kv_n_heads
self.rope_theta = rope_theta
self.attn_config_model_type = attn_config_model_type
# ffn_config params
self.ffn_hidden_size = ffn_hidden_size
self.moe_jitter_eps = moe_jitter_eps
self.moe_loss_weight = moe_loss_weight
self.moe_num_experts = moe_num_experts
self.moe_top_k = moe_top_k
self.ffn_config_model_type = ffn_config_model_type
self.ffn_act_fn_name = ffn_act_fn_name
# Other model params
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.max_position_embeddings = max_position_embeddings
self.vocab_size = vocab_size
self.use_cache = use_cache
self.initializer_range = initializer_range
self.emb_pdrop = emb_pdrop
self.output_router_logits = output_router_logits
self.resid_pdrop = resid_pdrop
self.tie_word_embeddings = tie_word_embeddings
self.torch_dtype = torch_dtype
self.is_decoder = is_decoder
self.pad_token_id = pad_token_id
# Make the dictionaries
self.ffn_config = {
"ffn_hidden_size": self.ffn_hidden_size,
"moe_jitter_eps": self.moe_jitter_eps,
"moe_loss_weight": self.moe_loss_weight,
"moe_num_experts": self.moe_num_experts,
"moe_top_k": self.moe_top_k,
"model_type": self.ffn_config_model_type,
"ffn_act_fn": {"name": self.ffn_act_fn_name},
}
self.attn_config = {
"clip_qkv": self.clip_qkv,
"kv_n_heads": self.kv_n_heads,
"model_type": self.attn_config_model_type,
"rope_theta": self.rope_theta,
}
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
# Behind the scenes, `DbrxConfig` maps the parameters `hidden_size`, `num_hidden_layers`,
# `num_attention_heads`, `max_position_embeddings` to the parameters `d_model`, `n_layers`,
# `n_heads`, `max_seq_len` respectively. We use the first group of parameters because
# other tests expect every model to have these parameters with these specific names.
config = DbrxConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size, # mapped to `d_model`
num_hidden_layers=self.num_hidden_layers, # mapped to `n_layers`
num_attention_heads=self.num_attention_heads, # mapped to `n_heads`
max_position_embeddings=self.max_position_embeddings, # mapped to `max_seq_len`
attn_config=self.attn_config,
ffn_config=self.ffn_config,
resid_pdrop=self.resid_pdrop,
emb_pdrop=self.emb_pdrop,
use_cache=self.use_cache,
initializer_range=self.initializer_range,
output_router_logits=self.output_router_logits,
is_decoder=self.is_decoder,
pad_token_id=self.pad_token_id,
)
return config
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Dbrx
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = DbrxModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Dbrx
def create_and_check_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = DbrxModel(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
)
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Dbrx
def create_and_check_for_causal_lm(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = DbrxForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.is_decoder = True
config.add_cross_attention = True
model = DbrxForCausalLM(config=config)
model.to(torch_device)
model.eval()
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical multiple next tokens and extend to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append to next input_ids and attention mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->Dbrx
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (DbrxModel, DbrxForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (DbrxForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = {"text-generation": DbrxForCausalLM} if is_torch_available() else {}
test_headmasking = False
test_pruning = False
def setUp(self):
self.model_tester = DbrxModelTester(self)
self.config_tester = ConfigTester(self, config_class=DbrxConfig, d_model=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
model_name = "eitanturok/dbrx-tiny"
model = DbrxModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@unittest.skip("Dbrx models have weight tying disabled.")
def test_tied_weights_keys(self):
pass
@unittest.skip("TODO @gante fix this for Llama")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@require_torch
class DbrxModelIntegrationTest(unittest.TestCase):
@slow
def test_tiny_model_logits(self):
model = DbrxForCausalLM.from_pretrained("Rocketknight1/dbrx-tiny-random")
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0]
vocab_size = model.vocab_size
expected_shape = torch.Size((1, 6, vocab_size))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[
[
[-1.6300e-04, 5.0118e-04, 2.5437e-04],
[2.0422e-05, 2.7210e-04, -1.5125e-04],
[-1.5105e-04, 4.6879e-04, 3.3309e-04],
]
]
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
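The same tiny random checkpoint used in the integration test lends itself to a quick generation smoke test outside the test suite; a sketch, assuming the checkpoint above is reachable:

```python
import torch

from transformers import DbrxForCausalLM

model = DbrxForCausalLM.from_pretrained("Rocketknight1/dbrx-tiny-random")
model.eval()

input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
with torch.no_grad():
    generated = model.generate(input_ids, max_new_tokens=5, do_sample=False)
print(generated.shape)  # expected: torch.Size([1, 11])
```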