Unverified Commit c2393cad authored by dewa, committed by GitHub

Added type hints for `Graphormer` pytorch version (#23073)

* Added type hints for `Graphormer` pytorch version

Added type hints for Graphormer's PyTorch modules and checked formatting issues.

* made the code less bloated
parent ee3be053
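For context, a minimal sketch of the annotation style this change applies throughout the file: tensor arguments are annotated as `torch.Tensor`/`torch.LongTensor` and return types are spelled out, so type checkers and IDEs can flag mismatched inputs. The module below is an illustrative toy (the `ToyGraphLayer` name and `padding_mask` argument are made up for the example), not Graphormer code.

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn


class ToyGraphLayer(nn.Module):
    """Illustrative only: mirrors the hint style the diff introduces, not part of transformers."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        input_nodes: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden = self.proj(input_nodes)
        if padding_mask is not None:
            # Zero out padded node positions; padding_mask is assumed to be boolean with shape (batch, nodes).
            hidden = hidden.masked_fill(padding_mask.unsqueeze(-1), 0.0)
        return hidden, padding_mask
```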
@@ -14,16 +14,18 @@
# limitations under the License.
""" PyTorch Graphormer model."""
import math
from typing import Optional, Tuple, Union
from typing import Iterable, Iterator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithNoAttention, SequenceClassifierOutput
from ...modeling_outputs import (
BaseModelOutputWithNoAttention,
SequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_graphormer import GraphormerConfig
@@ -42,7 +44,7 @@ GRAPHORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
def quant_noise(module, p, block_size):
def quant_noise(module: nn.Module, p: float, block_size: int):
"""
From:
https://github.com/facebookresearch/fairseq/blob/dd0079bde7f678b0cd0715cbd0ae68d661b7226d/fairseq/modules/quant_noise.py
@@ -161,11 +163,11 @@ class LayerDropModuleList(nn.ModuleList):
modules (iterable, optional): an iterable of modules to add
"""
def __init__(self, p, modules=None):
def __init__(self, p: float, modules: Optional[Iterable[nn.Module]] = None):
super().__init__(modules)
self.p = p
def __iter__(self):
def __iter__(self) -> Iterator[nn.Module]:
dropout_probs = torch.empty(len(self)).uniform_()
for i, m in enumerate(super().__iter__()):
if not self.training or (dropout_probs[i] > self.p):
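The `__iter__` above is the LayerDrop mechanism: during training each layer in the list is skipped with probability `p`, while at eval time every layer runs. A self-contained sketch of the same idea, with a small usage example:

```python
import torch
import torch.nn as nn


class LayerDropList(nn.ModuleList):
    """Sketch of the LayerDropModuleList pattern shown in the diff."""

    def __init__(self, p: float, modules=None):
        super().__init__(modules)
        self.p = p

    def __iter__(self):
        # One uniform draw per layer; a layer survives if its draw exceeds p, or if we are in eval mode.
        dropout_probs = torch.empty(len(self)).uniform_()
        for i, layer in enumerate(super().__iter__()):
            if not self.training or (dropout_probs[i] > self.p):
                yield layer


layers = LayerDropList(p=0.2, modules=[nn.Linear(8, 8) for _ in range(4)])
layers.train()
x = torch.randn(2, 8)
for layer in layers:  # each layer is dropped with probability 0.2 on this pass
    x = layer(x)
```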
@@ -177,7 +179,7 @@ class GraphormerGraphNodeFeature(nn.Module):
Compute node features for each node in the graph.
"""
def __init__(self, config):
def __init__(self, config: GraphormerConfig):
super().__init__()
self.num_heads = config.num_attention_heads
self.num_atoms = config.num_atoms
@@ -192,7 +194,12 @@ class GraphormerGraphNodeFeature(nn.Module):
self.graph_token = nn.Embedding(1, config.hidden_size)
def forward(self, input_nodes, in_degree, out_degree):
def forward(
self,
input_nodes: torch.LongTensor,
in_degree: torch.LongTensor,
out_degree: torch.LongTensor,
) -> torch.Tensor:
n_graph, n_node = input_nodes.size()[:2]
node_feature = ( # node feature + graph token
@@ -213,7 +220,7 @@ class GraphormerGraphAttnBias(nn.Module):
Compute attention bias for each head.
"""
def __init__(self, config):
def __init__(self, config: GraphormerConfig):
super().__init__()
self.num_heads = config.num_attention_heads
self.multi_hop_max_dist = config.multi_hop_max_dist
@@ -233,7 +240,14 @@ class GraphormerGraphAttnBias(nn.Module):
self.graph_token_virtual_distance = nn.Embedding(1, config.num_attention_heads)
def forward(self, input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type):
def forward(
self,
input_nodes: torch.LongTensor,
attn_bias: torch.Tensor,
spatial_pos: torch.LongTensor,
input_edges: torch.LongTensor,
attn_edge_type: torch.LongTensor,
) -> torch.Tensor:
n_graph, n_node = input_nodes.size()[:2]
graph_attn_bias = attn_bias.clone()
graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat(
@@ -289,7 +303,7 @@ class GraphormerMultiheadAttention(nn.Module):
See "Attention Is All You Need" for more details.
"""
def __init__(self, config):
def __init__(self, config: GraphormerConfig):
super().__init__()
self.embedding_dim = config.embedding_dim
self.kdim = config.kdim if config.kdim is not None else config.embedding_dim
@@ -352,7 +366,7 @@ class GraphormerMultiheadAttention(nn.Module):
def forward(
self,
query,
query: torch.LongTensor,
key: Optional[torch.Tensor],
value: Optional[torch.Tensor],
attn_bias: Optional[torch.Tensor],
@@ -458,7 +472,7 @@ class GraphormerMultiheadAttention(nn.Module):
raise AssertionError("The attention generated do not match the expected dimensions.")
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embedding_dim)
attn = self.out_proj(attn)
attn: torch.Tensor = self.out_proj(attn)
attn_weights = None
if need_weights:
@@ -469,12 +483,12 @@ class GraphormerMultiheadAttention(nn.Module):
return attn, attn_weights
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
def apply_sparse_mask(self, attn_weights: torch.Tensor, tgt_len: int, src_len: int, bsz: int) -> torch.Tensor:
return attn_weights
class GraphormerGraphEncoderLayer(nn.Module):
def __init__(self, config) -> None:
def __init__(self, config: GraphormerConfig) -> None:
super().__init__()
# Initialize parameters
@@ -512,7 +526,9 @@ class GraphormerGraphEncoderLayer(nn.Module):
# layer norm associated with the position wise feed-forward NN
self.final_layer_norm = nn.LayerNorm(self.embedding_dim)
def build_fc(self, input_dim, output_dim, q_noise, qn_block_size):
def build_fc(
self, input_dim: int, output_dim: int, q_noise: float, qn_block_size: int
) -> Union[nn.Module, nn.Linear, nn.Embedding, nn.Conv2d]:
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
def forward(
@@ -521,7 +537,7 @@ class GraphormerGraphEncoderLayer(nn.Module):
self_attn_bias: Optional[torch.Tensor] = None,
self_attn_mask: Optional[torch.Tensor] = None,
self_attn_padding_mask: Optional[torch.Tensor] = None,
):
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
nn.LayerNorm is applied either before or after the self-attention/ffn modules similar to the original
Transformer implementation.
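As a reminder of the two placements that docstring refers to, here is a small illustrative sketch (not the Graphormer code itself): post-LN normalizes after the residual addition, pre-LN normalizes the sublayer input.

```python
import torch
import torch.nn as nn


def post_ln(x: torch.Tensor, sublayer: nn.Module, norm: nn.LayerNorm) -> torch.Tensor:
    # Original Transformer ordering: sublayer, residual add, then LayerNorm.
    return norm(x + sublayer(x))


def pre_ln(x: torch.Tensor, sublayer: nn.Module, norm: nn.LayerNorm) -> torch.Tensor:
    # Pre-LN ordering: LayerNorm first, then sublayer and residual add.
    return x + sublayer(norm(x))
```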
@@ -559,7 +575,7 @@ class GraphormerGraphEncoderLayer(nn.Module):
class GraphormerGraphEncoder(nn.Module):
def __init__(self, config):
def __init__(self, config: GraphormerConfig):
super().__init__()
self.dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False)
@@ -608,18 +624,18 @@ class GraphormerGraphEncoder(nn.Module):
def forward(
self,
input_nodes,
input_edges,
attn_bias,
in_degree,
out_degree,
spatial_pos,
attn_edge_type,
input_nodes: torch.LongTensor,
input_edges: torch.LongTensor,
attn_bias: torch.Tensor,
in_degree: torch.LongTensor,
out_degree: torch.LongTensor,
spatial_pos: torch.LongTensor,
attn_edge_type: torch.LongTensor,
perturb=None,
last_state_only: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.torch.Tensor, torch.Tensor]:
) -> Tuple[Union[torch.Tensor, List[torch.LongTensor]], torch.Tensor]:
# compute padding mask. This is needed for multi-head attention
data_x = input_nodes
n_graph, n_node = data_x.size()[:2]
@@ -676,14 +692,14 @@ class GraphormerGraphEncoder(nn.Module):
class GraphormerDecoderHead(nn.Module):
def __init__(self, embedding_dim, num_classes):
def __init__(self, embedding_dim: int, num_classes: int):
super().__init__()
"""num_classes should be 1 for regression, or the number of classes for classification"""
self.lm_output_learned_bias = nn.Parameter(torch.zeros(1))
self.classifier = nn.Linear(embedding_dim, num_classes, bias=False)
self.num_classes = num_classes
def forward(self, input_nodes, **unused):
def forward(self, input_nodes: torch.Tensor, **unused) -> torch.Tensor:
input_nodes = self.classifier(input_nodes)
input_nodes = input_nodes + self.lm_output_learned_bias
return input_nodes
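The decoder head boils down to a bias-free linear projection plus a single learned output bias. A hedged usage sketch (the shapes below are illustrative assumptions, not taken from the diff):

```python
import torch
import torch.nn as nn


class ToyDecoderHead(nn.Module):
    """Sketch of the pattern above: Linear without bias plus a learned scalar output bias."""

    def __init__(self, embedding_dim: int, num_classes: int):
        super().__init__()
        self.lm_output_learned_bias = nn.Parameter(torch.zeros(1))
        self.classifier = nn.Linear(embedding_dim, num_classes, bias=False)

    def forward(self, input_nodes: torch.Tensor) -> torch.Tensor:
        return self.classifier(input_nodes) + self.lm_output_learned_bias


head = ToyDecoderHead(embedding_dim=32, num_classes=1)  # num_classes=1 for regression, per the docstring
scores = head(torch.randn(4, 32))  # -> shape (4, 1)
```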
@@ -702,12 +718,12 @@ class GraphormerPreTrainedModel(PreTrainedModel):
main_input_name_nodes = "input_nodes"
main_input_name_edges = "input_edges"
def normal_(self, data):
def normal_(self, data: torch.Tensor):
# with FSDP, module params will be on CUDA, so we cast them back to CPU
# so that the RNG is consistent with and without FSDP
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
def init_graphormer_params(self, module):
def init_graphormer_params(self, module: Union[nn.Linear, nn.Embedding, GraphormerMultiheadAttention]):
"""
Initialize the weights specific to the Graphormer Model.
"""
@@ -724,7 +740,12 @@ class GraphormerPreTrainedModel(PreTrainedModel):
self.normal_(module.k_proj.weight.data)
self.normal_(module.v_proj.weight.data)
def _init_weights(self, module):
def _init_weights(
self,
module: Union[
nn.Linear, nn.Conv2d, nn.Embedding, nn.LayerNorm, GraphormerMultiheadAttention, GraphormerGraphEncoder
],
):
"""
Initialize the weights
"""
@@ -766,7 +787,7 @@ class GraphormerModel(GraphormerPreTrainedModel):
this model with a downstream model of your choice, following the example in GraphormerForGraphClassification.
"""
def __init__(self, config):
def __init__(self, config: GraphormerConfig):
super().__init__(config)
self.max_nodes = config.max_nodes
@@ -789,18 +810,18 @@ class GraphormerModel(GraphormerPreTrainedModel):
def forward(
self,
input_nodes,
input_edges,
attn_bias,
in_degree,
out_degree,
spatial_pos,
attn_edge_type,
input_nodes: torch.LongTensor,
input_edges: torch.LongTensor,
attn_bias: torch.Tensor,
in_degree: torch.LongTensor,
out_degree: torch.LongTensor,
spatial_pos: torch.LongTensor,
attn_edge_type: torch.LongTensor,
perturb=None,
masked_tokens=None,
return_dict: Optional[bool] = None,
**unused,
):
) -> Union[Tuple[torch.LongTensor], BaseModelOutputWithNoAttention]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
inner_states, graph_rep = self.graph_encoder(
@@ -841,7 +862,7 @@ class GraphormerForGraphClassification(GraphormerPreTrainedModel):
of integer labels for each graph.
"""
def __init__(self, config):
def __init__(self, config: GraphormerConfig):
super().__init__(config)
self.encoder = GraphormerModel(config)
self.embedding_dim = config.embedding_dim
@@ -854,13 +875,13 @@ class GraphormerForGraphClassification(GraphormerPreTrainedModel):
def forward(
self,
input_nodes,
input_edges,
attn_bias,
in_degree,
out_degree,
spatial_pos,
attn_edge_type,
input_nodes: torch.LongTensor,
input_edges: torch.LongTensor,
attn_bias: torch.Tensor,
in_degree: torch.LongTensor,
out_degree: torch.LongTensor,
spatial_pos: torch.LongTensor,
attn_edge_type: torch.LongTensor,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
**unused,
......