Unverified commit 476890e9, authored by Jonny Li and committed by GitHub
Browse files

Fix DeepSpeed compatibility with weight_norm (#30881) (#31018)

parent aada568f
...@@ -295,8 +295,14 @@ class HubertPositionalConvEmbedding(nn.Module): ...@@ -295,8 +295,14 @@ class HubertPositionalConvEmbedding(nn.Module):
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2) self.conv = weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v) if hasattr(self.conv, "parametrizations"):
deepspeed.zero.register_external_parameter(self, self.conv.weight_g) weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else: else:
self.conv = weight_norm(self.conv, name="weight", dim=2) self.conv = weight_norm(self.conv, name="weight", dim=2)
......
...@@ -325,8 +325,14 @@ class SeamlessM4TConformerPositionalConvEmbedding(nn.Module): ...@@ -325,8 +325,14 @@ class SeamlessM4TConformerPositionalConvEmbedding(nn.Module):
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2) self.conv = weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v) if hasattr(self.conv, "parametrizations"):
deepspeed.zero.register_external_parameter(self, self.conv.weight_g) weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else: else:
self.conv = weight_norm(self.conv, name="weight", dim=2) self.conv = weight_norm(self.conv, name="weight", dim=2)
......
...@@ -294,8 +294,14 @@ class SEWPositionalConvEmbedding(nn.Module): ...@@ -294,8 +294,14 @@ class SEWPositionalConvEmbedding(nn.Module):
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v) if hasattr(self.conv, "parametrizations"):
deepspeed.zero.register_external_parameter(self, self.conv.weight_g) weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else: else:
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
......
...@@ -354,8 +354,14 @@ class SEWDPositionalConvEmbedding(nn.Module): ...@@ -354,8 +354,14 @@ class SEWDPositionalConvEmbedding(nn.Module):
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v) if hasattr(self.conv, "parametrizations"):
deepspeed.zero.register_external_parameter(self, self.conv.weight_g) weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else: else:
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
......
...@@ -368,8 +368,14 @@ class SpeechT5PositionalConvEmbedding(nn.Module): ...@@ -368,8 +368,14 @@ class SpeechT5PositionalConvEmbedding(nn.Module):
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2) self.conv = weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v) if hasattr(self.conv, "parametrizations"):
deepspeed.zero.register_external_parameter(self, self.conv.weight_g) weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else: else:
self.conv = weight_norm(self.conv, name="weight", dim=2) self.conv = weight_norm(self.conv, name="weight", dim=2)
......
...@@ -330,8 +330,14 @@ class UniSpeechPositionalConvEmbedding(nn.Module): ...@@ -330,8 +330,14 @@ class UniSpeechPositionalConvEmbedding(nn.Module):
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2) self.conv = weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v) if hasattr(self.conv, "parametrizations"):
deepspeed.zero.register_external_parameter(self, self.conv.weight_g) weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else: else:
self.conv = weight_norm(self.conv, name="weight", dim=2) self.conv = weight_norm(self.conv, name="weight", dim=2)
......
...@@ -347,8 +347,14 @@ class UniSpeechSatPositionalConvEmbedding(nn.Module): ...@@ -347,8 +347,14 @@ class UniSpeechSatPositionalConvEmbedding(nn.Module):
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2) self.conv = weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v) if hasattr(self.conv, "parametrizations"):
deepspeed.zero.register_external_parameter(self, self.conv.weight_g) weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else: else:
self.conv = weight_norm(self.conv, name="weight", dim=2) self.conv = weight_norm(self.conv, name="weight", dim=2)
......
...@@ -398,8 +398,14 @@ class Wav2Vec2PositionalConvEmbedding(nn.Module): ...@@ -398,8 +398,14 @@ class Wav2Vec2PositionalConvEmbedding(nn.Module):
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2) self.conv = weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v) if hasattr(self.conv, "parametrizations"):
deepspeed.zero.register_external_parameter(self, self.conv.weight_g) weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else: else:
self.conv = weight_norm(self.conv, name="weight", dim=2) self.conv = weight_norm(self.conv, name="weight", dim=2)
......
...@@ -361,8 +361,14 @@ class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module): ...@@ -361,8 +361,14 @@ class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2) self.conv = weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v) if hasattr(self.conv, "parametrizations"):
deepspeed.zero.register_external_parameter(self, self.conv.weight_g) weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else: else:
self.conv = weight_norm(self.conv, name="weight", dim=2) self.conv = weight_norm(self.conv, name="weight", dim=2)
......
...@@ -287,8 +287,14 @@ class WavLMPositionalConvEmbedding(nn.Module): ...@@ -287,8 +287,14 @@ class WavLMPositionalConvEmbedding(nn.Module):
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2) self.conv = weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v) if hasattr(self.conv, "parametrizations"):
deepspeed.zero.register_external_parameter(self, self.conv.weight_g) weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else: else:
self.conv = weight_norm(self.conv, name="weight", dim=2) self.conv = weight_norm(self.conv, name="weight", dim=2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment