Unverified Commit 76d0d41e authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Wav2Vec2] Make sure that gradient checkpointing is only run if needed (#14407)

* [Wav2Vec2] Make sure that gradient checkpointing is only run if needed

* make fix-copies
parent 9fd937ea
...@@ -291,20 +291,22 @@ class HubertFeatureExtractor(nn.Module): ...@@ -291,20 +291,22 @@ class HubertFeatureExtractor(nn.Module):
) )
self.conv_layers = nn.ModuleList(conv_layers) self.conv_layers = nn.ModuleList(conv_layers)
self.gradient_checkpointing = False self.gradient_checkpointing = False
self._requires_grad = True
def _freeze_parameters(self): def _freeze_parameters(self):
for param in self.parameters(): for param in self.parameters():
param.requires_grad = False param.requires_grad = False
self._requires_grad = False
def forward(self, input_values): def forward(self, input_values):
hidden_states = input_values[:, None] hidden_states = input_values[:, None]
# make sure hidden_states require grad for gradient_checkpointing # make sure hidden_states require grad for gradient_checkpointing
if self.training: if self._requires_grad and self.training:
hidden_states.requires_grad = True hidden_states.requires_grad = True
for conv_layer in self.conv_layers: for conv_layer in self.conv_layers:
if self.gradient_checkpointing and self.training: if self._requires_grad and self.gradient_checkpointing and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
......
...@@ -308,20 +308,22 @@ class SEWFeatureExtractor(nn.Module): ...@@ -308,20 +308,22 @@ class SEWFeatureExtractor(nn.Module):
) )
self.conv_layers = nn.ModuleList(conv_layers) self.conv_layers = nn.ModuleList(conv_layers)
self.gradient_checkpointing = False self.gradient_checkpointing = False
self._requires_grad = True
def _freeze_parameters(self): def _freeze_parameters(self):
for param in self.parameters(): for param in self.parameters():
param.requires_grad = False param.requires_grad = False
self._requires_grad = False
def forward(self, input_values): def forward(self, input_values):
hidden_states = input_values[:, None] hidden_states = input_values[:, None]
# make sure hidden_states require grad for gradient_checkpointing # make sure hidden_states require grad for gradient_checkpointing
if self.training: if self._requires_grad and self.training:
hidden_states.requires_grad = True hidden_states.requires_grad = True
for conv_layer in self.conv_layers: for conv_layer in self.conv_layers:
if self.gradient_checkpointing and self.training: if self._requires_grad and self.gradient_checkpointing and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
......
...@@ -394,20 +394,22 @@ class SEWDFeatureExtractor(nn.Module): ...@@ -394,20 +394,22 @@ class SEWDFeatureExtractor(nn.Module):
) )
self.conv_layers = nn.ModuleList(conv_layers) self.conv_layers = nn.ModuleList(conv_layers)
self.gradient_checkpointing = False self.gradient_checkpointing = False
self._requires_grad = True
def _freeze_parameters(self): def _freeze_parameters(self):
for param in self.parameters(): for param in self.parameters():
param.requires_grad = False param.requires_grad = False
self._requires_grad = False
def forward(self, input_values): def forward(self, input_values):
hidden_states = input_values[:, None] hidden_states = input_values[:, None]
# make sure hidden_states require grad for gradient_checkpointing # make sure hidden_states require grad for gradient_checkpointing
if self.training: if self._requires_grad and self.training:
hidden_states.requires_grad = True hidden_states.requires_grad = True
for conv_layer in self.conv_layers: for conv_layer in self.conv_layers:
if self.gradient_checkpointing and self.training: if self._requires_grad and self.gradient_checkpointing and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
......
...@@ -361,20 +361,22 @@ class UniSpeechFeatureExtractor(nn.Module): ...@@ -361,20 +361,22 @@ class UniSpeechFeatureExtractor(nn.Module):
) )
self.conv_layers = nn.ModuleList(conv_layers) self.conv_layers = nn.ModuleList(conv_layers)
self.gradient_checkpointing = False self.gradient_checkpointing = False
self._requires_grad = True
def _freeze_parameters(self): def _freeze_parameters(self):
for param in self.parameters(): for param in self.parameters():
param.requires_grad = False param.requires_grad = False
self._requires_grad = False
def forward(self, input_values): def forward(self, input_values):
hidden_states = input_values[:, None] hidden_states = input_values[:, None]
# make sure hidden_states require grad for gradient_checkpointing # make sure hidden_states require grad for gradient_checkpointing
if self.training: if self._requires_grad and self.training:
hidden_states.requires_grad = True hidden_states.requires_grad = True
for conv_layer in self.conv_layers: for conv_layer in self.conv_layers:
if self.gradient_checkpointing and self.training: if self._requires_grad and self.gradient_checkpointing and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
......
...@@ -362,20 +362,22 @@ class UniSpeechSatFeatureExtractor(nn.Module): ...@@ -362,20 +362,22 @@ class UniSpeechSatFeatureExtractor(nn.Module):
) )
self.conv_layers = nn.ModuleList(conv_layers) self.conv_layers = nn.ModuleList(conv_layers)
self.gradient_checkpointing = False self.gradient_checkpointing = False
self._requires_grad = True
def _freeze_parameters(self): def _freeze_parameters(self):
for param in self.parameters(): for param in self.parameters():
param.requires_grad = False param.requires_grad = False
self._requires_grad = False
def forward(self, input_values): def forward(self, input_values):
hidden_states = input_values[:, None] hidden_states = input_values[:, None]
# make sure hidden_states require grad for gradient_checkpointing # make sure hidden_states require grad for gradient_checkpointing
if self.training: if self._requires_grad and self.training:
hidden_states.requires_grad = True hidden_states.requires_grad = True
for conv_layer in self.conv_layers: for conv_layer in self.conv_layers:
if self.gradient_checkpointing and self.training: if self._requires_grad and self.gradient_checkpointing and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
......
...@@ -399,20 +399,22 @@ class Wav2Vec2FeatureExtractor(nn.Module): ...@@ -399,20 +399,22 @@ class Wav2Vec2FeatureExtractor(nn.Module):
) )
self.conv_layers = nn.ModuleList(conv_layers) self.conv_layers = nn.ModuleList(conv_layers)
self.gradient_checkpointing = False self.gradient_checkpointing = False
self._requires_grad = True
def _freeze_parameters(self): def _freeze_parameters(self):
for param in self.parameters(): for param in self.parameters():
param.requires_grad = False param.requires_grad = False
self._requires_grad = False
def forward(self, input_values): def forward(self, input_values):
hidden_states = input_values[:, None] hidden_states = input_values[:, None]
# make sure hidden_states require grad for gradient_checkpointing # make sure hidden_states require grad for gradient_checkpointing
if self.training: if self._requires_grad and self.training:
hidden_states.requires_grad = True hidden_states.requires_grad = True
for conv_layer in self.conv_layers: for conv_layer in self.conv_layers:
if self.gradient_checkpointing and self.training: if self._requires_grad and self.gradient_checkpointing and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment