Unverified Commit 07bf2dff authored by Joseph Enguehard, committed by GitHub

Add TokenClassification for Mistral, Mixtral and Qwen2 (#29878)



* Add MistralForTokenClassification

* Add tests and docs

* Add token classification for Mixtral and Qwen2

* Save Llama for token classification draft

* Add token classification support for Llama, Gemma, Persimmon, StableLm and StarCoder2

* Formatting

* Add token classification support for Qwen2Moe model

* Add dropout layer to each ForTokenClassification model

* Add copied from in tests

* Update src/transformers/models/llama/modeling_llama.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Propagate suggested changes

* Style

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent 481a9578
......@@ -60,6 +60,11 @@ This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [
[[autodoc]] GemmaForSequenceClassification
- forward
## GemmaForTokenClassification
[[autodoc]] GemmaForTokenClassification
- forward
## FlaxGemmaModel
[[autodoc]] FlaxGemmaModel
......
......@@ -121,6 +121,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] LlamaForQuestionAnswering
- forward
## LlamaForTokenClassification
[[autodoc]] LlamaForTokenClassification
- forward
## FlaxLlamaModel
[[autodoc]] FlaxLlamaModel
......
......@@ -203,6 +203,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] MistralForSequenceClassification
- forward
## MistralForTokenClassification
[[autodoc]] MistralForTokenClassification
- forward
## FlaxMistralModel
[[autodoc]] FlaxMistralModel
......
......@@ -204,3 +204,8 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] MixtralForSequenceClassification
- forward
## MixtralForTokenClassification
[[autodoc]] MixtralForTokenClassification
- forward
......@@ -96,3 +96,8 @@ The `LlamaTokenizer` is used as it is a standard wrapper around sentencepiece. T
[[autodoc]] PersimmonForSequenceClassification
- forward
## PersimmonForTokenClassification
[[autodoc]] PersimmonForTokenClassification
- forward
......@@ -80,3 +80,8 @@ In the following, we demonstrate how to use `Qwen2-7B-Chat-beta` for the inferen
[[autodoc]] Qwen2ForSequenceClassification
- forward
## Qwen2ForTokenClassification
[[autodoc]] Qwen2ForTokenClassification
- forward
......@@ -75,3 +75,8 @@ In the following, we demonstrate how to use `Qwen1.5-MoE-A2.7B-Chat` for the inf
[[autodoc]] Qwen2MoeForSequenceClassification
- forward
## Qwen2MoeForTokenClassification
[[autodoc]] Qwen2MoeForTokenClassification
- forward
......@@ -104,3 +104,8 @@ Now, to run the model with Flash Attention 2, refer to the snippet below:
[[autodoc]] StableLmForSequenceClassification
- forward
## StableLmForTokenClassification
[[autodoc]] StableLmForTokenClassification
- forward
......@@ -66,3 +66,8 @@ These ready-to-use checkpoints can be downloaded and used via the HuggingFace Hu
[[autodoc]] Starcoder2ForSequenceClassification
- forward
## Starcoder2ForTokenClassification
[[autodoc]] Starcoder2ForTokenClassification
- forward
......@@ -2031,6 +2031,7 @@ else:
[
"GemmaForCausalLM",
"GemmaForSequenceClassification",
"GemmaForTokenClassification",
"GemmaModel",
"GemmaPreTrainedModel",
]
......@@ -2288,6 +2289,7 @@ else:
"LlamaForCausalLM",
"LlamaForQuestionAnswering",
"LlamaForSequenceClassification",
"LlamaForTokenClassification",
"LlamaModel",
"LlamaPreTrainedModel",
]
......@@ -2435,12 +2437,19 @@ else:
[
"MistralForCausalLM",
"MistralForSequenceClassification",
"MistralForTokenClassification",
"MistralModel",
"MistralPreTrainedModel",
]
)
_import_structure["models.mixtral"].extend(
["MixtralForCausalLM", "MixtralForSequenceClassification", "MixtralModel", "MixtralPreTrainedModel"]
[
"MixtralForCausalLM",
"MixtralForSequenceClassification",
"MixtralForTokenClassification",
"MixtralModel",
"MixtralPreTrainedModel",
]
)
_import_structure["models.mobilebert"].extend(
[
......@@ -2714,6 +2723,7 @@ else:
[
"PersimmonForCausalLM",
"PersimmonForSequenceClassification",
"PersimmonForTokenClassification",
"PersimmonModel",
"PersimmonPreTrainedModel",
]
......@@ -2810,6 +2820,7 @@ else:
[
"Qwen2ForCausalLM",
"Qwen2ForSequenceClassification",
"Qwen2ForTokenClassification",
"Qwen2Model",
"Qwen2PreTrainedModel",
]
......@@ -2818,6 +2829,7 @@ else:
[
"Qwen2MoeForCausalLM",
"Qwen2MoeForSequenceClassification",
"Qwen2MoeForTokenClassification",
"Qwen2MoeModel",
"Qwen2MoePreTrainedModel",
]
......@@ -3066,6 +3078,7 @@ else:
[
"StableLmForCausalLM",
"StableLmForSequenceClassification",
"StableLmForTokenClassification",
"StableLmModel",
"StableLmPreTrainedModel",
]
......@@ -3074,6 +3087,7 @@ else:
[
"Starcoder2ForCausalLM",
"Starcoder2ForSequenceClassification",
"Starcoder2ForTokenClassification",
"Starcoder2Model",
"Starcoder2PreTrainedModel",
]
......@@ -6489,6 +6503,7 @@ if TYPE_CHECKING:
from .models.gemma import (
GemmaForCausalLM,
GemmaForSequenceClassification,
GemmaForTokenClassification,
GemmaModel,
GemmaPreTrainedModel,
)
......@@ -6686,6 +6701,7 @@ if TYPE_CHECKING:
LlamaForCausalLM,
LlamaForQuestionAnswering,
LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaModel,
LlamaPreTrainedModel,
)
......@@ -6801,12 +6817,14 @@ if TYPE_CHECKING:
from .models.mistral import (
MistralForCausalLM,
MistralForSequenceClassification,
MistralForTokenClassification,
MistralModel,
MistralPreTrainedModel,
)
from .models.mixtral import (
MixtralForCausalLM,
MixtralForSequenceClassification,
MixtralForTokenClassification,
MixtralModel,
MixtralPreTrainedModel,
)
......@@ -7025,6 +7043,7 @@ if TYPE_CHECKING:
from .models.persimmon import (
PersimmonForCausalLM,
PersimmonForSequenceClassification,
PersimmonForTokenClassification,
PersimmonModel,
PersimmonPreTrainedModel,
)
......@@ -7099,12 +7118,14 @@ if TYPE_CHECKING:
from .models.qwen2 import (
Qwen2ForCausalLM,
Qwen2ForSequenceClassification,
Qwen2ForTokenClassification,
Qwen2Model,
Qwen2PreTrainedModel,
)
from .models.qwen2_moe import (
Qwen2MoeForCausalLM,
Qwen2MoeForSequenceClassification,
Qwen2MoeForTokenClassification,
Qwen2MoeModel,
Qwen2MoePreTrainedModel,
)
......@@ -7306,12 +7327,14 @@ if TYPE_CHECKING:
from .models.stablelm import (
StableLmForCausalLM,
StableLmForSequenceClassification,
StableLmForTokenClassification,
StableLmModel,
StableLmPreTrainedModel,
)
from .models.starcoder2 import (
Starcoder2ForCausalLM,
Starcoder2ForSequenceClassification,
Starcoder2ForTokenClassification,
Starcoder2Model,
Starcoder2PreTrainedModel,
)
......
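The paired `_import_structure` and `TYPE_CHECKING` entries above register each new head with the lazy import machinery, so the classes are importable directly from the top-level package. A minimal sanity check, assuming a transformers build that includes this commit:

```python
# Quick check that the new heads are exposed at the top level of the package
# (requires a transformers build that includes this commit).
from transformers import LlamaForTokenClassification, LlamaPreTrainedModel

print(issubclass(LlamaForTokenClassification, LlamaPreTrainedModel))  # True

# The other heads follow the same pattern, e.g.:
from transformers import (
    GemmaForTokenClassification,
    MistralForTokenClassification,
    MixtralForTokenClassification,
    Qwen2ForTokenClassification,
)
```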
......@@ -1038,6 +1038,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("flaubert", "FlaubertForTokenClassification"),
("fnet", "FNetForTokenClassification"),
("funnel", "FunnelForTokenClassification"),
("gemma", "GemmaForTokenClassification"),
("gpt-sw3", "GPT2ForTokenClassification"),
("gpt2", "GPT2ForTokenClassification"),
("gpt_bigcode", "GPTBigCodeForTokenClassification"),
......@@ -1048,11 +1049,14 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
("layoutlmv3", "LayoutLMv3ForTokenClassification"),
("lilt", "LiltForTokenClassification"),
("llama", "LlamaForTokenClassification"),
("longformer", "LongformerForTokenClassification"),
("luke", "LukeForTokenClassification"),
("markuplm", "MarkupLMForTokenClassification"),
("mega", "MegaForTokenClassification"),
("megatron-bert", "MegatronBertForTokenClassification"),
("mistral", "MistralForTokenClassification"),
("mixtral", "MixtralForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"),
("mpnet", "MPNetForTokenClassification"),
("mpt", "MptForTokenClassification"),
......@@ -1060,15 +1064,20 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("mt5", "MT5ForTokenClassification"),
("nezha", "NezhaForTokenClassification"),
("nystromformer", "NystromformerForTokenClassification"),
("persimmon", "PersimmonForTokenClassification"),
("phi", "PhiForTokenClassification"),
("phi3", "Phi3ForTokenClassification"),
("qdqbert", "QDQBertForTokenClassification"),
("qwen2", "Qwen2ForTokenClassification"),
("qwen2_moe", "Qwen2MoeForTokenClassification"),
("rembert", "RemBertForTokenClassification"),
("roberta", "RobertaForTokenClassification"),
("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
("roc_bert", "RoCBertForTokenClassification"),
("roformer", "RoFormerForTokenClassification"),
("squeezebert", "SqueezeBertForTokenClassification"),
("stablelm", "StableLmForTokenClassification"),
("starcoder2", "Starcoder2ForTokenClassification"),
("t5", "T5ForTokenClassification"),
("umt5", "UMT5ForTokenClassification"),
("xlm", "XLMForTokenClassification"),
......
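With the `MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES` entries above, `AutoModelForTokenClassification` resolves these decoder-only configs to the new heads. A small offline sketch; the config sizes are arbitrary toy values and no pretrained weights are loaded (in real use, `from_pretrained` on a base checkpoint attaches a randomly initialized head that still needs fine-tuning):

```python
# The auto class now dispatches on the new config types; sizes below are toy values.
from transformers import AutoModelForTokenClassification, MistralConfig, Qwen2Config

mistral_cfg = MistralConfig(
    hidden_size=32, intermediate_size=64, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=2, num_labels=7,
)
print(type(AutoModelForTokenClassification.from_config(mistral_cfg)).__name__)
# MistralForTokenClassification

qwen2_cfg = Qwen2Config(
    hidden_size=32, intermediate_size=64, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=2, num_labels=7,
)
print(type(AutoModelForTokenClassification.from_config(qwen2_cfg)).__name__)
# Qwen2ForTokenClassification
```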
......@@ -55,6 +55,7 @@ else:
"GemmaModel",
"GemmaPreTrainedModel",
"GemmaForSequenceClassification",
"GemmaForTokenClassification",
]
try:
......@@ -98,6 +99,7 @@ if TYPE_CHECKING:
from .modeling_gemma import (
GemmaForCausalLM,
GemmaForSequenceClassification,
GemmaForTokenClassification,
GemmaModel,
GemmaPreTrainedModel,
)
......
......@@ -30,7 +30,12 @@ from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
)
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from ...utils import (
......@@ -1346,3 +1351,88 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
"""
The Gemma Model transformer with a token classification head on top (a linear layer on top of the hidden-states
output) e.g. for Named-Entity-Recognition (NER) tasks.
""",
GEMMA_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Gemma, LLAMA->GEMMA
class GemmaForTokenClassification(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = GemmaModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
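Each new head picks its dropout rate with the same fallback chain: an explicit `classifier_dropout` wins, then the config's `hidden_dropout` (where the architecture defines one), then a 0.1 default. A standalone sketch of that resolution order (not library code):

```python
from types import SimpleNamespace


def resolve_classifier_dropout(config) -> float:
    # Mirrors the fallback used in the __init__ of the new *ForTokenClassification heads.
    if getattr(config, "classifier_dropout", None) is not None:
        return config.classifier_dropout
    if getattr(config, "hidden_dropout", None) is not None:
        return config.hidden_dropout
    return 0.1


print(resolve_classifier_dropout(SimpleNamespace()))                        # 0.1 (default)
print(resolve_classifier_dropout(SimpleNamespace(hidden_dropout=0.0)))      # 0.0
print(resolve_classifier_dropout(SimpleNamespace(classifier_dropout=0.2)))  # 0.2
```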
......@@ -55,6 +55,7 @@ else:
"LlamaPreTrainedModel",
"LlamaForSequenceClassification",
"LlamaForQuestionAnswering",
"LlamaForTokenClassification",
]
try:
......@@ -95,6 +96,7 @@ if TYPE_CHECKING:
LlamaForCausalLM,
LlamaForQuestionAnswering,
LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaModel,
LlamaPreTrainedModel,
)
......
......@@ -36,6 +36,7 @@ from ...modeling_outputs import (
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
......@@ -1516,3 +1517,87 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states
output) e.g. for Named-Entity-Recognition (NER) tasks.
""",
LLAMA_START_DOCSTRING,
)
class LlamaForTokenClassification(LlamaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = LlamaModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
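An end-to-end sketch of the new head on a tiny, randomly initialized `LlamaForTokenClassification`; the config sizes and label count below are arbitrary toy values, not a real checkpoint:

```python
import torch
from transformers import LlamaConfig, LlamaForTokenClassification

config = LlamaConfig(
    vocab_size=128, hidden_size=32, intermediate_size=64,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
    num_labels=5,
)
model = LlamaForTokenClassification(config)

input_ids = torch.randint(0, config.vocab_size, (2, 10))  # (batch_size, sequence_length)
labels = torch.randint(0, config.num_labels, (2, 10))     # one label id per token

outputs = model(input_ids=input_ids, labels=labels)
print(outputs.loss)           # scalar cross-entropy averaged over all labelled tokens
print(outputs.logits.shape)   # torch.Size([2, 10, 5])
```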
......@@ -32,6 +32,7 @@ else:
"MistralModel",
"MistralPreTrainedModel",
"MistralForSequenceClassification",
"MistralForTokenClassification",
]
try:
......@@ -59,6 +60,7 @@ if TYPE_CHECKING:
from .modeling_mistral import (
MistralForCausalLM,
MistralForSequenceClassification,
MistralForTokenClassification,
MistralModel,
MistralPreTrainedModel,
)
......
......@@ -31,7 +31,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
......@@ -1366,3 +1371,88 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
"""
The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states
output) e.g. for Named-Entity-Recognition (NER) tasks.
""",
MISTRAL_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mistral, LLAMA->MISTRAL
class MistralForTokenClassification(MistralPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = MistralModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
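For inference with a fine-tuned checkpoint, the per-token logits map back to tag names through `config.id2label`. A hedged sketch; `your-org/mistral-ner` is a placeholder identifier, not a real model:

```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

checkpoint = "your-org/mistral-ner"  # placeholder for any MistralForTokenClassification fine-tune
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForTokenClassification.from_pretrained(checkpoint).eval()

inputs = tokenizer("Joseph Enguehard works at Hugging Face", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits                    # (1, sequence_length, num_labels)

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
for token, label_id in zip(tokens, logits.argmax(dim=-1)[0].tolist()):
    print(token, model.config.id2label[label_id])      # one predicted tag per sub-word token
```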
......@@ -36,6 +36,7 @@ else:
"MixtralModel",
"MixtralPreTrainedModel",
"MixtralForSequenceClassification",
"MixtralForTokenClassification",
]
......@@ -51,6 +52,7 @@ if TYPE_CHECKING:
from .modeling_mixtral import (
MixtralForCausalLM,
MixtralForSequenceClassification,
MixtralForTokenClassification,
MixtralModel,
MixtralPreTrainedModel,
)
......
......@@ -38,6 +38,7 @@ from ...modeling_outputs import (
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
......@@ -1582,3 +1583,88 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel):
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
"""
The Mixtral Model transformer with a token classification head on top (a linear layer on top of the hidden-states
output) e.g. for Named-Entity-Recognition (NER) tasks.
""",
MIXTRAL_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL
class MixtralForTokenClassification(MixtralPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = MixtralModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
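The loss in every new head is a plain `torch.nn.CrossEntropyLoss` over the flattened logits, so its default `ignore_index=-100` applies: tokens labelled `-100` (padding, special tokens, or sub-word continuations) are excluded from the loss. A minimal, model-independent sketch:

```python
import torch
from torch.nn import CrossEntropyLoss

num_labels = 3
logits = torch.randn(1, 6, num_labels)                 # (batch_size, seq_len, num_labels)
labels = torch.tensor([[0, 2, 1, -100, -100, -100]])   # last three positions are padding

loss_fct = CrossEntropyLoss()                          # default ignore_index=-100
loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
print(loss)                                            # averaged over the three labelled tokens only
```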
......@@ -36,6 +36,7 @@ else:
"PersimmonModel",
"PersimmonPreTrainedModel",
"PersimmonForSequenceClassification",
"PersimmonForTokenClassification",
]
......@@ -51,6 +52,7 @@ if TYPE_CHECKING:
from .modeling_persimmon import (
PersimmonForCausalLM,
PersimmonForSequenceClassification,
PersimmonForTokenClassification,
PersimmonModel,
PersimmonPreTrainedModel,
)
......