Unverified commit 046c5ea9 authored by Moreno La Quatra, committed by GitHub

Implemented loss for training AudioFrameClassification (#17513)

* Implemented loss for training AudioFrameClassification

* Applied the changes in the main Wav2Vec2 class and used "make fix-copies" to propagate them to the copied models

* Ran black for code formatting
parent 085321c9
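
With this change, the audio frame-classification heads accept a labels argument and return a training loss. A minimal usage sketch follows (a hedged illustration, not part of the commit: the checkpoint name and label construction are assumptions; the added loss expects one-hot frame labels of shape (batch, num_frames, num_labels), since it recovers class indices via argmax):

    # Usage sketch; checkpoint name and label shapes are illustrative.
    import torch
    from transformers import Wav2Vec2ForAudioFrameClassification

    model = Wav2Vec2ForAudioFrameClassification.from_pretrained(
        "facebook/wav2vec2-base", num_labels=2
    )

    input_values = torch.randn(1, 16000)  # one second of 16 kHz audio
    num_frames = model(input_values).logits.shape[1]

    # One-hot frame-level labels; the argmax in the new loss recovers class indices
    labels = torch.nn.functional.one_hot(
        torch.randint(0, 2, (1, num_frames)), num_classes=2
    )

    outputs = model(input_values, labels=labels)
    outputs.loss.backward()  # cross-entropy over all frames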
@@ -1243,6 +1243,7 @@ class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
         if config.use_weighted_layer_sum:
             self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.num_labels = config.num_labels

         self.init_weights()
@@ -1286,6 +1287,7 @@ class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
         self,
         input_values: Optional[torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -1318,12 +1320,17 @@ class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
         logits = self.classifier(hidden_states)

+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
         if not return_dict:
             output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
             return output

         return TokenClassifierOutput(
-            loss=None,
+            loss=loss,
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
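
The loss block above, in isolation: logits of shape (batch, frames, num_labels) are flattened to (batch * frames, num_labels), and the one-hot labels are collapsed back to class indices with argmax before cross-entropy is applied. A standalone sketch with illustrative shapes:

    import torch
    from torch.nn import CrossEntropyLoss

    batch, frames, num_labels = 2, 49, 3  # illustrative shapes
    logits = torch.randn(batch, frames, num_labels)
    labels = torch.nn.functional.one_hot(
        torch.randint(0, num_labels, (batch, frames)), num_classes=num_labels
    )

    loss_fct = CrossEntropyLoss()
    # Same computation as the diff: flatten the frames, argmax undoes the one-hot
    loss = loss_fct(
        logits.view(-1, num_labels),
        torch.argmax(labels.view(-1, num_labels), axis=1),
    )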
@@ -1611,6 +1611,7 @@ class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel):
         if config.use_weighted_layer_sum:
             self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.num_labels = config.num_labels

         self.init_weights()
@@ -1654,6 +1655,7 @@ class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel):
         self,
         input_values: Optional[torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -1686,12 +1688,17 @@ class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel):
         logits = self.classifier(hidden_states)

+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
         if not return_dict:
             output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
             return output

         return TokenClassifierOutput(
-            loss=None,
+            loss=loss,
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
@@ -1852,6 +1852,7 @@ class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
         if config.use_weighted_layer_sum:
             self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.num_labels = config.num_labels

         self.init_weights()
@@ -1895,6 +1896,7 @@ class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
         self,
         input_values: Optional[torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -1927,12 +1929,17 @@ class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
         logits = self.classifier(hidden_states)

+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
         if not return_dict:
             output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
             return output

         return TokenClassifierOutput(
-            loss=None,
+            loss=loss,
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
@@ -1845,6 +1845,7 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
         if config.use_weighted_layer_sum:
             self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.num_labels = config.num_labels

         self.init_weights()
@@ -1879,6 +1880,7 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
         self,
         input_values: Optional[torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -1911,12 +1913,17 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
         logits = self.classifier(hidden_states)

+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
         if not return_dict:
             output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
             return output

         return TokenClassifierOutput(
-            loss=None,
+            loss=loss,
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
@@ -1545,6 +1545,7 @@ class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
         if config.use_weighted_layer_sum:
             self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.num_labels = config.num_labels

         self.init_weights()
@@ -1588,6 +1589,7 @@ class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
         self,
         input_values: Optional[torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -1620,12 +1622,17 @@ class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
         logits = self.classifier(hidden_states)

+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
         if not return_dict:
             output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
             return output

         return TokenClassifierOutput(
-            loss=None,
+            loss=loss,
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,