"tests/vscode:/vscode.git/clone" did not exist on "41d47db90fbe9937c0941f2f9cdb2ddd83e49a2e"
Unverified commit 07bf2dff authored by Joseph Enguehard, committed by GitHub

Add TokenClassification for Mistral, Mixtral and Qwen2 (#29878)



* Add MistralForTokenClassification

* Add tests and docs

* Add token classification for Mixtral and Qwen2

* Save llama for token classification draft

* Add token classification support for Llama, Gemma, Persimmon, StableLm and StarCoder2

* Formatting

* Add token classification support for Qwen2Moe model

* Add dropout layer to each ForTokenClassification model

* Add copied from in tests

* Update src/transformers/models/llama/modeling_llama.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Propagate suggested changes

* Style

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent 481a9578
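Each new `<Model>ForTokenClassification` head added in the diff below wraps the existing decoder, applies dropout (taken from `classifier_dropout` or `hidden_dropout` when set, otherwise 0.1), and projects every hidden state to `num_labels` scores. A minimal usage sketch; the checkpoint name, label count, and input text are illustrative placeholders, not part of this commit:

```python
# Hedged example: checkpoint and labels are placeholders; the token-classification
# head is newly initialized and would need fine-tuning before real NER use.
import torch
from transformers import AutoTokenizer, MistralForTokenClassification

checkpoint = "mistralai/Mistral-7B-v0.1"  # any Mistral checkpoint; placeholder choice
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = MistralForTokenClassification.from_pretrained(checkpoint, num_labels=3)

inputs = tokenizer("Joseph works at Hugging Face in Paris", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (batch_size, seq_len, num_labels)
predicted_class_ids = logits.argmax(dim=-1)
```

The same pattern applies to the other architectures touched by this PR (Llama, Gemma, Persimmon, StableLm, Starcoder2, Qwen2, Qwen2Moe, Mixtral), since each head is generated with a `# Copied from` of the Llama implementation.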
@@ -29,7 +29,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from ...modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
+    TokenClassifierOutput,
+)
 from ...modeling_utils import PreTrainedModel
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_persimmon import PersimmonConfig
@@ -1011,3 +1016,88 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """
+    The Persimmon Model transformer with a token classification head on top (a linear layer on top of the hidden-states
+    output) e.g. for Named-Entity-Recognition (NER) tasks.
+    """,
+    PERSIMMON_START_DOCSTRING,
+)
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Persimmon, LLAMA->PERSIMMON
+class PersimmonForTokenClassification(PersimmonPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = PersimmonModel(config)
+        if getattr(config, "classifier_dropout", None) is not None:
+            classifier_dropout = config.classifier_dropout
+        elif getattr(config, "hidden_dropout", None) is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        sequence_output = self.dropout(sequence_output)
+        logits = self.score(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
@@ -45,6 +45,7 @@ else:
         "Qwen2Model",
         "Qwen2PreTrainedModel",
         "Qwen2ForSequenceClassification",
+        "Qwen2ForTokenClassification",
     ]
@@ -69,6 +70,7 @@ if TYPE_CHECKING:
         from .modeling_qwen2 import (
             Qwen2ForCausalLM,
             Qwen2ForSequenceClassification,
+            Qwen2ForTokenClassification,
             Qwen2Model,
             Qwen2PreTrainedModel,
         )
...
@@ -31,7 +31,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from ...modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
+    TokenClassifierOutput,
+)
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_start_docstrings,
@@ -1375,3 +1380,88 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """
+    The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
+    output) e.g. for Named-Entity-Recognition (NER) tasks.
+    """,
+    QWEN2_START_DOCSTRING,
+)
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2
+class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = Qwen2Model(config)
+        if getattr(config, "classifier_dropout", None) is not None:
+            classifier_dropout = config.classifier_dropout
+        elif getattr(config, "hidden_dropout", None) is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        sequence_output = self.dropout(sequence_output)
+        logits = self.score(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
@@ -36,6 +36,7 @@ else:
         "Qwen2MoeModel",
         "Qwen2MoePreTrainedModel",
         "Qwen2MoeForSequenceClassification",
+        "Qwen2MoeForTokenClassification",
     ]
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
         from .modeling_qwen2_moe import (
             Qwen2MoeForCausalLM,
             Qwen2MoeForSequenceClassification,
+            Qwen2MoeForTokenClassification,
            Qwen2MoeModel,
            Qwen2MoePreTrainedModel,
        )
...
@@ -32,7 +32,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
-from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast
+from ...modeling_outputs import (
+    MoeCausalLMOutputWithPast,
+    MoeModelOutputWithPast,
+    SequenceClassifierOutputWithPast,
+    TokenClassifierOutput,
+)
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_start_docstrings,
@@ -1571,3 +1576,88 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """
+    The Qwen2MoE Model transformer with a token classification head on top (a linear layer on top of the hidden-states
+    output) e.g. for Named-Entity-Recognition (NER) tasks.
+    """,
+    QWEN2MOE_START_DOCSTRING,
+)
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE
+class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = Qwen2MoeModel(config)
+        if getattr(config, "classifier_dropout", None) is not None:
+            classifier_dropout = config.classifier_dropout
+        elif getattr(config, "hidden_dropout", None) is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        sequence_output = self.dropout(sequence_output)
+        logits = self.score(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
@@ -36,6 +36,7 @@ else:
         "StableLmModel",
         "StableLmPreTrainedModel",
         "StableLmForSequenceClassification",
+        "StableLmForTokenClassification",
     ]
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
         from .modeling_stablelm import (
             StableLmForCausalLM,
             StableLmForSequenceClassification,
+            StableLmForTokenClassification,
            StableLmModel,
            StableLmPreTrainedModel,
        )
...
@@ -30,7 +30,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from ...modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
+    TokenClassifierOutput,
+)
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_start_docstrings,
@@ -1383,3 +1388,88 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel):
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """
+    The StableLm Model transformer with a token classification head on top (a linear layer on top of the hidden-states
+    output) e.g. for Named-Entity-Recognition (NER) tasks.
+    """,
+    STABLELM_START_DOCSTRING,
+)
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->StableLm, LLAMA->STABLELM
+class StableLmForTokenClassification(StableLmPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = StableLmModel(config)
+        if getattr(config, "classifier_dropout", None) is not None:
+            classifier_dropout = config.classifier_dropout
+        elif getattr(config, "hidden_dropout", None) is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        sequence_output = self.dropout(sequence_output)
+        logits = self.score(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
@@ -36,6 +36,7 @@ else:
         "Starcoder2Model",
         "Starcoder2PreTrainedModel",
         "Starcoder2ForSequenceClassification",
+        "Starcoder2ForTokenClassification",
     ]
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
         from .modeling_starcoder2 import (
             Starcoder2ForCausalLM,
             Starcoder2ForSequenceClassification,
+            Starcoder2ForTokenClassification,
            Starcoder2Model,
            Starcoder2PreTrainedModel,
        )
...
@@ -31,7 +31,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from ...modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
+    TokenClassifierOutput,
+)
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_start_docstrings,
@@ -1357,3 +1362,88 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """
+    The Starcoder2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
+    output) e.g. for Named-Entity-Recognition (NER) tasks.
+    """,
+    STARCODER2_START_DOCSTRING,
+)
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Starcoder2, LLAMA->STARCODER2
+class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = Starcoder2Model(config)
+        if getattr(config, "classifier_dropout", None) is not None:
+            classifier_dropout = config.classifier_dropout
+        elif getattr(config, "hidden_dropout", None) is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        sequence_output = self.dropout(sequence_output)
+        logits = self.score(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
@@ -3692,6 +3692,13 @@ class GemmaForSequenceClassification(metaclass=DummyObject):
         requires_backends(self, ["torch"])

+class GemmaForTokenClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
 class GemmaModel(metaclass=DummyObject):
     _backends = ["torch"]
@@ -4642,6 +4649,13 @@ class LlamaForSequenceClassification(metaclass=DummyObject):
         requires_backends(self, ["torch"])

+class LlamaForTokenClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
 class LlamaModel(metaclass=DummyObject):
     _backends = ["torch"]
@@ -5237,6 +5251,13 @@ class MistralForSequenceClassification(metaclass=DummyObject):
         requires_backends(self, ["torch"])

+class MistralForTokenClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
 class MistralModel(metaclass=DummyObject):
     _backends = ["torch"]
@@ -5265,6 +5286,13 @@ class MixtralForSequenceClassification(metaclass=DummyObject):
         requires_backends(self, ["torch"])

+class MixtralForTokenClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
 class MixtralModel(metaclass=DummyObject):
     _backends = ["torch"]
@@ -6373,6 +6401,13 @@ class PersimmonForSequenceClassification(metaclass=DummyObject):
         requires_backends(self, ["torch"])

+class PersimmonForTokenClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
 class PersimmonModel(metaclass=DummyObject):
     _backends = ["torch"]
@@ -6734,6 +6769,13 @@ class Qwen2ForSequenceClassification(metaclass=DummyObject):
         requires_backends(self, ["torch"])

+class Qwen2ForTokenClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
 class Qwen2Model(metaclass=DummyObject):
     _backends = ["torch"]
@@ -6762,6 +6804,13 @@ class Qwen2MoeForSequenceClassification(metaclass=DummyObject):
         requires_backends(self, ["torch"])

+class Qwen2MoeForTokenClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
 class Qwen2MoeModel(metaclass=DummyObject):
     _backends = ["torch"]
@@ -7793,6 +7842,13 @@ class StableLmForSequenceClassification(metaclass=DummyObject):
         requires_backends(self, ["torch"])

+class StableLmForTokenClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
 class StableLmModel(metaclass=DummyObject):
     _backends = ["torch"]
@@ -7821,6 +7877,13 @@ class Starcoder2ForSequenceClassification(metaclass=DummyObject):
         requires_backends(self, ["torch"])

+class Starcoder2ForTokenClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
 class Starcoder2Model(metaclass=DummyObject):
     _backends = ["torch"]
...
@@ -41,7 +41,13 @@ from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
     import torch

-    from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel, GemmaTokenizer
+    from transformers import (
+        GemmaForCausalLM,
+        GemmaForSequenceClassification,
+        GemmaForTokenClassification,
+        GemmaModel,
+        GemmaTokenizer,
+    )

 class GemmaModelTester:
@@ -284,12 +290,17 @@ class GemmaModelTester:
 @require_torch
 class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
-    all_model_classes = (GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification) if is_torch_available() else ()
+    all_model_classes = (
+        (GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification)
+        if is_torch_available()
+        else ()
+    )
     all_generative_model_classes = (GemmaForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
             "feature-extraction": GemmaModel,
             "text-classification": GemmaForSequenceClassification,
+            "token-classification": GemmaForTokenClassification,
             "text-generation": GemmaForCausalLM,
             "zero-shot": GemmaForSequenceClassification,
         }
@@ -370,6 +381,22 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Gemma,llama->Gemma
+    def test_Gemma_token_classification_model(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
+        model = GemmaForTokenClassification(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
+        self.assertEqual(
+            result.logits.shape,
+            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
+        )
+
     @unittest.skip("Gemma buffers include complex numbers, which breaks this test")
     def test_save_load_fast_init_from_base(self):
         pass
...
@@ -47,6 +47,7 @@ if is_torch_available():
         LlamaForCausalLM,
         LlamaForQuestionAnswering,
         LlamaForSequenceClassification,
+        LlamaForTokenClassification,
         LlamaModel,
         LlamaTokenizer,
     )
@@ -286,7 +287,13 @@ class LlamaModelTester:
 @require_torch
 class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
-        (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForQuestionAnswering)
+        (
+            LlamaModel,
+            LlamaForCausalLM,
+            LlamaForSequenceClassification,
+            LlamaForQuestionAnswering,
+            LlamaForTokenClassification,
+        )
         if is_torch_available()
         else ()
     )
@@ -298,6 +305,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
             "text-generation": LlamaForCausalLM,
             "zero-shot": LlamaForSequenceClassification,
             "question-answering": LlamaForQuestionAnswering,
+            "token-classification": LlamaForTokenClassification,
         }
         if is_torch_available()
         else {}
@@ -370,6 +378,21 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

+    def test_llama_token_classification_model(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
+        model = LlamaForTokenClassification(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
+        self.assertEqual(
+            result.logits.shape,
+            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
+        )
+
     @unittest.skip("Llama buffers include complex numbers, which breaks this test")
     def test_save_load_fast_init_from_base(self):
         pass
...
@@ -46,6 +46,7 @@ if is_torch_available():
     from transformers import (
         MistralForCausalLM,
         MistralForSequenceClassification,
+        MistralForTokenClassification,
         MistralModel,
     )
@@ -288,13 +289,16 @@ class MistralModelTester:
 @require_torch
 class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
-        (MistralModel, MistralForCausalLM, MistralForSequenceClassification) if is_torch_available() else ()
+        (MistralModel, MistralForCausalLM, MistralForSequenceClassification, MistralForTokenClassification)
+        if is_torch_available()
+        else ()
     )
     all_generative_model_classes = (MistralForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
             "feature-extraction": MistralModel,
             "text-classification": MistralForSequenceClassification,
+            "token-classification": MistralForTokenClassification,
             "text-generation": MistralForCausalLM,
             "zero-shot": MistralForSequenceClassification,
         }
@@ -376,6 +380,22 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mistral,llama->Mistral
+    def test_Mistral_token_classification_model(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
+        model = MistralForTokenClassification(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
+        self.assertEqual(
+            result.logits.shape,
+            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
+        )
+
     @unittest.skip("Mistral buffers include complex numbers, which breaks this test")
     def test_save_load_fast_init_from_base(self):
         pass
...
@@ -40,7 +40,12 @@ from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
     import torch

-    from transformers import MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel
+    from transformers import (
+        MixtralForCausalLM,
+        MixtralForSequenceClassification,
+        MixtralForTokenClassification,
+        MixtralModel,
+    )

 class MixtralModelTester:
@@ -287,13 +292,16 @@ class MixtralModelTester:
 # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral
 class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
-        (MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification) if is_torch_available() else ()
+        (MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification, MixtralForTokenClassification)
+        if is_torch_available()
+        else ()
     )
     all_generative_model_classes = (MixtralForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
             "feature-extraction": MixtralModel,
             "text-classification": MixtralForSequenceClassification,
+            "token-classification": MixtralForTokenClassification,
             "text-generation": MixtralForCausalLM,
             "zero-shot": MixtralForSequenceClassification,
         }
@@ -375,6 +383,22 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mixtral,llama->Mixtral
+    def test_Mixtral_token_classification_model(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
+        model = MixtralForTokenClassification(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
+        self.assertEqual(
+            result.logits.shape,
+            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
+        )
+
     @unittest.skip("Mixtral buffers include complex numbers, which breaks this test")
     def test_save_load_fast_init_from_base(self):
         pass
...
@@ -44,6 +44,7 @@ if is_torch_available():
         AutoTokenizer,
         PersimmonForCausalLM,
         PersimmonForSequenceClassification,
+        PersimmonForTokenClassification,
         PersimmonModel,
     )
     from transformers.models.persimmon.modeling_persimmon import (
@@ -283,12 +284,15 @@ class PersimmonModelTester:
 @require_torch
 class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
-        (PersimmonModel, PersimmonForCausalLM, PersimmonForSequenceClassification) if is_torch_available() else ()
+        (PersimmonModel, PersimmonForCausalLM, PersimmonForSequenceClassification, PersimmonForTokenClassification)
+        if is_torch_available()
+        else ()
     )
     pipeline_model_mapping = (
         {
             "feature-extraction": PersimmonModel,
             "text-classification": PersimmonForSequenceClassification,
+            "token-classification": PersimmonForTokenClassification,
             # TODO (ydshieh): check why these two fail. Fix them or skip them in a better way.
             # "text-generation": PersimmonForCausalLM,
             # "zero-shot": PersimmonForSequenceClassification,
@@ -365,6 +369,22 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Persimmon,llama->persimmon
+    def test_persimmon_token_classification_model(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
+        model = PersimmonForTokenClassification(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
+        self.assertEqual(
+            result.logits.shape,
+            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
+        )
+
     @unittest.skip("Persimmon buffers include complex numbers, which breaks this test")
     # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_save_load_fast_init_from_base
     def test_save_load_fast_init_from_base(self):
...
@@ -45,6 +45,7 @@ if is_torch_available():
     from transformers import (
         Qwen2ForCausalLM,
         Qwen2ForSequenceClassification,
+        Qwen2ForTokenClassification,
         Qwen2Model,
     )
@@ -299,12 +300,17 @@ class Qwen2ModelTester:
 @require_torch
 # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2
 class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
-    all_model_classes = (Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification) if is_torch_available() else ()
+    all_model_classes = (
+        (Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification, Qwen2ForTokenClassification)
+        if is_torch_available()
+        else ()
+    )
     all_generative_model_classes = (Qwen2ForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
             "feature-extraction": Qwen2Model,
             "text-classification": Qwen2ForSequenceClassification,
+            "token-classification": Qwen2ForTokenClassification,
             "text-generation": Qwen2ForCausalLM,
             "zero-shot": Qwen2ForSequenceClassification,
         }
@@ -387,6 +393,22 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2,llama->Qwen2
+    def test_Qwen2_token_classification_model(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
+        model = Qwen2ForTokenClassification(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
+        self.assertEqual(
+            result.logits.shape,
+            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
+        )
+
     @unittest.skip("Qwen2 buffers include complex numbers, which breaks this test")
     def test_save_load_fast_init_from_base(self):
         pass
...
@@ -45,6 +45,7 @@ if is_torch_available():
     from transformers import (
         Qwen2MoeForCausalLM,
         Qwen2MoeForSequenceClassification,
+        Qwen2MoeForTokenClassification,
         Qwen2MoeModel,
     )
@@ -327,13 +328,16 @@ class Qwen2MoeModelTester:
 # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2Moe
 class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
-        (Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification) if is_torch_available() else ()
+        (Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, Qwen2MoeForTokenClassification)
+        if is_torch_available()
+        else ()
     )
     all_generative_model_classes = (Qwen2MoeForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
             "feature-extraction": Qwen2MoeModel,
             "text-classification": Qwen2MoeForSequenceClassification,
+            "token-classification": Qwen2MoeForTokenClassification,
             "text-generation": Qwen2MoeForCausalLM,
             "zero-shot": Qwen2MoeForSequenceClassification,
         }
@@ -414,6 +418,22 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2Moe,llama->Qwen2Moe
+    def test_Qwen2Moe_token_classification_model(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
+        model = Qwen2MoeForTokenClassification(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
+        self.assertEqual(
+            result.logits.shape,
+            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
+        )
+
     @unittest.skip("Qwen2Moe buffers include complex numbers, which breaks this test")
     def test_save_load_fast_init_from_base(self):
         pass
...
@@ -43,6 +43,7 @@ if is_torch_available():
         AutoTokenizer,
         StableLmForCausalLM,
         StableLmForSequenceClassification,
+        StableLmForTokenClassification,
         StableLmModel,
     )
     from transformers.models.stablelm.modeling_stablelm import (
@@ -287,12 +288,15 @@ class StableLmModelTester:
 # Copied from transformers.tests.persimmon.test_modeling_persimmon.PersimmonModelTest with Persimmon -> StableLm
 class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
-        (StableLmModel, StableLmForCausalLM, StableLmForSequenceClassification) if is_torch_available() else ()
+        (StableLmModel, StableLmForCausalLM, StableLmForSequenceClassification, StableLmForTokenClassification)
+        if is_torch_available()
+        else ()
     )
     pipeline_model_mapping = (
         {
             "feature-extraction": StableLmModel,
             "text-classification": StableLmForSequenceClassification,
+            "token-classification": StableLmForTokenClassification,
             # TODO (ydshieh): check why these two fail. Fix them or skip them in a better way.
             # "text-generation": StableLmForCausalLM,
             # "zero-shot": StableLmForSequenceClassification,
@@ -356,6 +360,22 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->StableLm,llama->stablelm
+    def test_stablelm_token_classification_model(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
+        model = StableLmForTokenClassification(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
+        self.assertEqual(
+            result.logits.shape,
+            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
+        )
+
     @parameterized.expand([("linear",), ("dynamic",)])
     # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->StableLm
     def test_model_rope_scaling_from_config(self, scaling_type):
...
@@ -43,6 +43,7 @@ if is_torch_available():
         AutoTokenizer,
         Starcoder2ForCausalLM,
         Starcoder2ForSequenceClassification,
+        Starcoder2ForTokenClassification,
         Starcoder2Model,
     )
@@ -290,13 +291,16 @@ class Starcoder2ModelTester:
 # Copied from transformers.tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Starcoder2
 class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
-        (Starcoder2Model, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification) if is_torch_available() else ()
+        (Starcoder2Model, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification, Starcoder2ForTokenClassification)
+        if is_torch_available()
+        else ()
     )
     all_generative_model_classes = (Starcoder2ForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
             "feature-extraction": Starcoder2Model,
             "text-classification": Starcoder2ForSequenceClassification,
+            "token-classification": Starcoder2ForTokenClassification,
             "text-generation": Starcoder2ForCausalLM,
             "zero-shot": Starcoder2ForSequenceClassification,
         }
@@ -370,6 +374,22 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Starcoder2,llama->Starcoder2
+    def test_Starcoder2_token_classification_model(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
+        model = Starcoder2ForTokenClassification(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
+        self.assertEqual(
+            result.logits.shape,
+            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
+        )
+
     @unittest.skip("Starcoder2 buffers include complex numbers, which breaks this test")
     def test_save_load_fast_init_from_base(self):
         pass
...
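The new tests above all assert the same property: token-classification logits have shape `(batch_size, seq_length, num_labels)` and a loss is produced when labels are passed. A standalone sketch of that check with a deliberately tiny, made-up config (not the values used by the shared model testers):

```python
# Tiny illustrative config; the repo's tests use the shared ModelTester fixtures instead.
import torch
from transformers import LlamaConfig, LlamaForTokenClassification

config = LlamaConfig(
    vocab_size=128,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    num_labels=3,
)
model = LlamaForTokenClassification(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 10))  # (batch_size, seq_length)
with torch.no_grad():
    output = model(input_ids, labels=torch.zeros(2, 10, dtype=torch.long))
assert output.logits.shape == (2, 10, config.num_labels)
assert output.loss is not None  # CrossEntropyLoss over the flattened tokens
```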