Unverified commit 0c67a72c authored by Alexander Krauck and committed by GitHub

Fix Dropout Implementation in Graphormer (#24817)

This commit corrects the dropout implementation in Graphormer, aligning it with the original implementation and improving performance. Specifically:

1. The `attention_dropout` value, intended for `GraphormerMultiheadAttention`, was defined in the configuration but never used; the attention weights were dropped out with the regular `dropout` probability instead. The attention module now uses `attention_dropout`.
2. The feed-forward layers had no dedicated `activation_dropout`; the regular `dropout` was applied to their activations. This commit adds `activation_dropout` to the configuration and applies it in the feed-forward layers.

These changes ensure the dropout implementation matches the original Graphormer and delivers empirically better performance.
parent fb7d2469
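For context, here is a minimal, hypothetical sketch of the corrected wiring (layer norms and Graphormer-specific structural biases omitted; this is not the actual Graphormer code): `attention_dropout` acts on the attention weights inside the attention module, `activation_dropout` on the feed-forward activation, and the regular `dropout` on the residual branches. Before this commit, the first two effectively fell back to the regular `dropout` probability.

```python
import torch
from torch import nn


class ToyEncoderLayer(nn.Module):
    """Hypothetical sketch of the corrected dropout wiring; not the actual Graphormer classes."""

    def __init__(self, dim=64, heads=4, dropout=0.1, attention_dropout=0.1, activation_dropout=0.1):
        super().__init__()
        # `attention_dropout` is applied to the attention weights inside the attention module.
        self.attn = nn.MultiheadAttention(dim, heads, dropout=attention_dropout)
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)
        # Regular `dropout` on the residual branches, `activation_dropout` on the feed-forward activation.
        self.dropout_module = nn.Dropout(p=dropout, inplace=False)
        self.activation_dropout_module = nn.Dropout(p=activation_dropout, inplace=False)

    def forward(self, x):  # x: (seq_len, batch, dim)
        attn_out, _ = self.attn(x, x, x)
        x = x + self.dropout_module(attn_out)
        h = self.activation_dropout_module(torch.nn.functional.gelu(self.fc1(x)))
        return x + self.dropout_module(self.fc2(h))


layer = ToyEncoderLayer()
out = layer(torch.randn(10, 2, 64))  # (seq_len, batch, dim)
```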
@@ -79,6 +79,8 @@ class GraphormerConfig(PretrainedConfig):
             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
         attention_dropout (`float`, *optional*, defaults to 0.1):
             The dropout probability for the attention weights.
+        activation_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for the activation of the linear transformer layer.
         layerdrop (`float`, *optional*, defaults to 0.0):
             The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
             for more details.
@@ -150,6 +152,7 @@ class GraphormerConfig(PretrainedConfig):
         num_attention_heads: int = 32,
         dropout: float = 0.1,
         attention_dropout: float = 0.1,
+        activation_dropout: float = 0.1,
         layerdrop: float = 0.0,
         encoder_normalize_before: bool = False,
         pre_layernorm: bool = False,
@@ -188,6 +191,7 @@ class GraphormerConfig(PretrainedConfig):
         self.num_attention_heads = num_attention_heads
         self.dropout = dropout
         self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
         self.layerdrop = layerdrop
         self.encoder_normalize_before = encoder_normalize_before
         self.pre_layernorm = pre_layernorm
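With the configuration change above, the three probabilities can be tuned independently. A quick usage sketch, assuming a `transformers` version that ships `GraphormerConfig`:

```python
from transformers import GraphormerConfig

# Each dropout probability is now an independent knob.
config = GraphormerConfig(
    dropout=0.1,             # fully connected layers / residual branches
    attention_dropout=0.1,   # attention weights
    activation_dropout=0.1,  # feed-forward activations (added by this commit)
)
print(config.attention_dropout, config.activation_dropout)
```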
@@ -311,7 +311,7 @@ class GraphormerMultiheadAttention(nn.Module):
         self.qkv_same_dim = self.kdim == config.embedding_dim and self.vdim == config.embedding_dim

         self.num_heads = config.num_attention_heads
-        self.dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False)
+        self.attention_dropout_module = torch.nn.Dropout(p=config.attention_dropout, inplace=False)

         self.head_dim = config.embedding_dim // config.num_attention_heads
         if not (self.head_dim * config.num_attention_heads == self.embedding_dim):
@@ -463,7 +463,7 @@ class GraphormerMultiheadAttention(nn.Module):

         attn_weights_float = torch.nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights_float.type_as(attn_weights)
-        attn_probs = self.dropout_module(attn_weights)
+        attn_probs = self.attention_dropout_module(attn_weights)

         if v is None:
             raise AssertionError("No value generated")
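The two hunks above rename the attention module's dropout and point it at `config.attention_dropout`, so dropout is applied to the normalized attention weights before they are multiplied with the values. A standalone sketch of that pattern with made-up toy shapes (not the Graphormer code itself):

```python
import torch

# Toy shapes: (batch * heads, tgt_len, src_len) attention scores, (batch * heads, src_len, head_dim) values.
attention_dropout_module = torch.nn.Dropout(p=0.1, inplace=False)

attn_weights = torch.nn.functional.softmax(torch.randn(8, 16, 16), dim=-1)  # normalize over source positions
attn_probs = attention_dropout_module(attn_weights)                         # dropout on the weights themselves
attn = torch.bmm(attn_probs, torch.randn(8, 16, 32))                        # weighted sum of the values
```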
@@ -494,14 +494,13 @@ class GraphormerGraphEncoderLayer(nn.Module):
         # Initialize parameters
         self.embedding_dim = config.embedding_dim
         self.num_attention_heads = config.num_attention_heads
-        self.attention_dropout = config.attention_dropout
         self.q_noise = config.q_noise
         self.qn_block_size = config.qn_block_size
         self.pre_layernorm = config.pre_layernorm

         self.dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False)

-        self.activation_dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False)
+        self.activation_dropout_module = torch.nn.Dropout(p=config.activation_dropout, inplace=False)

         # Initialize blocks
         self.activation_fn = ACT2FN[config.activation_fn]