Unverified Commit 11f3ec72 authored by Ali Hassani's avatar Ali Hassani Committed by GitHub
Browse files

Add LayerScale to NAT/DiNAT (#20325)



* Add LayerScale to NAT/DiNAT.

Completely dropped the ball on LayerScale in the original PR (#20219).
This is just an optional argument in both models, and is only activated for larger variants in order to provide training stability.

* Add LayerScale to NAT/DiNAT.

Minor error fixed.
Co-authored-by: default avatarAli Hassani <ahassanijr@gmail.com>
parent d28448c5
......@@ -70,6 +70,8 @@ class DinatConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
layer_scale_init_value (`float`, *optional*, defaults to 0.0):
The initial value for the layer scale. Disabled if <=0.
Example:
......@@ -110,6 +112,7 @@ class DinatConfig(PretrainedConfig):
patch_norm=True,
initializer_range=0.02,
layer_norm_eps=1e-5,
layer_scale_init_value=0.0,
**kwargs
):
super().__init__(**kwargs)
......@@ -134,3 +137,4 @@ class DinatConfig(PretrainedConfig):
# we set the hidden_size attribute in order to make Dinat work with VisionEncoderDecoderModel
# this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.layer_scale_init_value = layer_scale_init_value
......@@ -462,6 +462,11 @@ class DinatLayer(nn.Module):
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.intermediate = DinatIntermediate(config, dim)
self.output = DinatOutput(config, dim)
self.layer_scale_parameters = (
nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True)
if config.layer_scale_init_value > 0
else None
)
def maybe_pad(self, hidden_states, height, width):
window_size = self.window_size
......@@ -496,11 +501,18 @@ class DinatLayer(nn.Module):
if was_padded:
attention_output = attention_output[:, :height, :width, :].contiguous()
if self.layer_scale_parameters is not None:
attention_output = self.layer_scale_parameters[0] * attention_output
hidden_states = shortcut + self.drop_path(attention_output)
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = hidden_states + self.output(layer_output)
layer_output = self.output(self.intermediate(layer_output))
if self.layer_scale_parameters is not None:
layer_output = self.layer_scale_parameters[1] * layer_output
layer_output = hidden_states + self.drop_path(layer_output)
layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
return layer_outputs
......
......@@ -12,7 +12,6 @@
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......
......@@ -68,6 +68,8 @@ class NatConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
layer_scale_init_value (`float`, *optional*, defaults to 0.0):
The initial value for the layer scale. Disabled if <=0.
Example:
......@@ -107,6 +109,7 @@ class NatConfig(PretrainedConfig):
patch_norm=True,
initializer_range=0.02,
layer_norm_eps=1e-5,
layer_scale_init_value=0.0,
**kwargs
):
super().__init__(**kwargs)
......@@ -130,3 +133,4 @@ class NatConfig(PretrainedConfig):
# we set the hidden_size attribute in order to make Nat work with VisionEncoderDecoderModel
# this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.layer_scale_init_value = layer_scale_init_value
......@@ -445,6 +445,11 @@ class NatLayer(nn.Module):
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.intermediate = NatIntermediate(config, dim)
self.output = NatOutput(config, dim)
self.layer_scale_parameters = (
nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True)
if config.layer_scale_init_value > 0
else None
)
def maybe_pad(self, hidden_states, height, width):
window_size = self.kernel_size
......@@ -479,11 +484,18 @@ class NatLayer(nn.Module):
if was_padded:
attention_output = attention_output[:, :height, :width, :].contiguous()
if self.layer_scale_parameters is not None:
attention_output = self.layer_scale_parameters[0] * attention_output
hidden_states = shortcut + self.drop_path(attention_output)
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = hidden_states + self.output(layer_output)
layer_output = self.output(self.intermediate(layer_output))
if self.layer_scale_parameters is not None:
layer_output = self.layer_scale_parameters[1] * layer_output
layer_output = hidden_states + self.drop_path(layer_output)
layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
return layer_outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment