"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e93763d420910003656abe2f951ce5bd91ac2f08"
Unverified Commit c2393cad authored by dewa, committed by GitHub

Added type hints for `Graphormer` pytorch version (#23073)

* Added type hints for `Graphormer` pytorch version

Added type hints for Graphormer's PyTorch code and checked for formatting issues.

* Made the code less bloated
parent ee3be053
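For context, the annotation pattern this PR applies throughout the file, sketched on a hypothetical module (the class below is illustrative and not part of the diff; the conventions mirror the PR: index-valued inputs as torch.LongTensor, float tensors as torch.Tensor, optionals via typing.Optional):

    from typing import Optional, Tuple

    import torch
    import torch.nn as nn


    class ToyGraphHead(nn.Module):
        """Hypothetical module, used only to illustrate the annotation style."""

        def __init__(self, embedding_dim: int, num_classes: int):
            super().__init__()
            self.classifier = nn.Linear(embedding_dim, num_classes)

        def forward(
            self,
            input_nodes: torch.LongTensor,              # integer-valued input
            attn_bias: Optional[torch.Tensor] = None,   # float bias, may be absent
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            logits = self.classifier(input_nodes.float())
            return logits, attn_bias

Annotated this way, a static checker such as mypy can flag call sites that pass the wrong kind of tensor, with no change to runtime behavior.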
@@ -14,16 +14,18 @@
 # limitations under the License.
 """ PyTorch Graphormer model."""
 
 import math
-from typing import Optional, Tuple, Union
+from typing import Iterable, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...modeling_outputs import BaseModelOutputWithNoAttention, SequenceClassifierOutput
+from ...modeling_outputs import (
+    BaseModelOutputWithNoAttention,
+    SequenceClassifierOutput,
+)
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
 from .configuration_graphormer import GraphormerConfig
@@ -42,7 +44,7 @@ GRAPHORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]
 
 
-def quant_noise(module, p, block_size):
+def quant_noise(module: nn.Module, p: float, block_size: int):
     """
     From:
     https://github.com/facebookresearch/fairseq/blob/dd0079bde7f678b0cd0715cbd0ae68d661b7226d/fairseq/modules/quant_noise.py
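A hedged usage note (grounded in the build_fc method further down this diff, where nn is already imported at module scope): quant_noise wraps a module so that quantization noise of strength p is applied to block_size-sized weight blocks during training, e.g.

    # usage sketch; the dimensions are assumed and must be divisible by block_size
    fc = quant_noise(nn.Linear(128, 128), p=0.1, block_size=8)

With p=0 the wrapper is a no-op, as in the fairseq implementation linked above.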
@@ -161,11 +163,11 @@ class LayerDropModuleList(nn.ModuleList):
         modules (iterable, optional): an iterable of modules to add
     """
 
-    def __init__(self, p, modules=None):
+    def __init__(self, p: float, modules: Optional[Iterable[nn.Module]] = None):
         super().__init__(modules)
         self.p = p
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[nn.Module]:
         dropout_probs = torch.empty(len(self)).uniform_()
         for i, m in enumerate(super().__iter__()):
             if not self.training or (dropout_probs[i] > self.p):
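A brief consumption sketch of LayerDropModuleList as defined above (layer sizes are assumed): during training each layer is skipped with probability p, while in eval mode every layer runs — the LayerDrop technique from fairseq.

    import torch
    import torch.nn as nn

    layers = LayerDropModuleList(p=0.2, modules=[nn.Linear(8, 8) for _ in range(6)])
    x = torch.randn(4, 8)

    layers.train()        # training: each layer survives with probability 1 - p
    for layer in layers:  # __iter__ above yields only the surviving layers
        x = layer(x)

    layers.eval()         # eval: the dropout_probs check is bypassed, all layers run
    for layer in layers:
        x = layer(x)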
@@ -177,7 +179,7 @@ class GraphormerGraphNodeFeature(nn.Module):
     Compute node features for each node in the graph.
     """
 
-    def __init__(self, config):
+    def __init__(self, config: GraphormerConfig):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.num_atoms = config.num_atoms
@@ -192,7 +194,12 @@
         self.graph_token = nn.Embedding(1, config.hidden_size)
 
-    def forward(self, input_nodes, in_degree, out_degree):
+    def forward(
+        self,
+        input_nodes: torch.LongTensor,
+        in_degree: torch.LongTensor,
+        out_degree: torch.LongTensor,
+    ) -> torch.Tensor:
         n_graph, n_node = input_nodes.size()[:2]
 
         node_feature = (  # node feature + graph token
@@ -213,7 +220,7 @@ class GraphormerGraphAttnBias(nn.Module):
     Compute attention bias for each head.
     """
 
-    def __init__(self, config):
+    def __init__(self, config: GraphormerConfig):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.multi_hop_max_dist = config.multi_hop_max_dist
@@ -233,7 +240,14 @@
         self.graph_token_virtual_distance = nn.Embedding(1, config.num_attention_heads)
 
-    def forward(self, input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type):
+    def forward(
+        self,
+        input_nodes: torch.LongTensor,
+        attn_bias: torch.Tensor,
+        spatial_pos: torch.LongTensor,
+        input_edges: torch.LongTensor,
+        attn_edge_type: torch.LongTensor,
+    ) -> torch.Tensor:
         n_graph, n_node = input_nodes.size()[:2]
         graph_attn_bias = attn_bias.clone()
 
         graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat(
@@ -289,7 +303,7 @@ class GraphormerMultiheadAttention(nn.Module):
     See "Attention Is All You Need" for more details.
     """
 
-    def __init__(self, config):
+    def __init__(self, config: GraphormerConfig):
         super().__init__()
         self.embedding_dim = config.embedding_dim
         self.kdim = config.kdim if config.kdim is not None else config.embedding_dim
@@ -352,7 +366,7 @@
     def forward(
         self,
-        query,
+        query: torch.LongTensor,
         key: Optional[torch.Tensor],
         value: Optional[torch.Tensor],
         attn_bias: Optional[torch.Tensor],
@@ -458,7 +472,7 @@
             raise AssertionError("The attention generated do not match the expected dimensions.")
 
         attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embedding_dim)
-        attn = self.out_proj(attn)
+        attn: torch.Tensor = self.out_proj(attn)
 
         attn_weights = None
         if need_weights:
@@ -469,12 +483,12 @@
         return attn, attn_weights
 
-    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
+    def apply_sparse_mask(self, attn_weights: torch.Tensor, tgt_len: int, src_len: int, bsz: int) -> torch.Tensor:
         return attn_weights
 
 
 class GraphormerGraphEncoderLayer(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: GraphormerConfig) -> None:
         super().__init__()
 
         # Initialize parameters
@@ -512,7 +526,9 @@
         # layer norm associated with the position wise feed-forward NN
         self.final_layer_norm = nn.LayerNorm(self.embedding_dim)
 
-    def build_fc(self, input_dim, output_dim, q_noise, qn_block_size):
+    def build_fc(
+        self, input_dim: int, output_dim: int, q_noise: float, qn_block_size: int
+    ) -> Union[nn.Module, nn.Linear, nn.Embedding, nn.Conv2d]:
         return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
 
     def forward(
@@ -521,7 +537,7 @@
         self_attn_bias: Optional[torch.Tensor] = None,
         self_attn_mask: Optional[torch.Tensor] = None,
         self_attn_padding_mask: Optional[torch.Tensor] = None,
-    ):
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         """
         nn.LayerNorm is applied either before or after the self-attention/ffn modules similar to the original
         Transformer implementation.
@@ -559,7 +575,7 @@
 class GraphormerGraphEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: GraphormerConfig):
         super().__init__()
 
         self.dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False)
@@ -608,18 +624,18 @@
     def forward(
         self,
-        input_nodes,
-        input_edges,
-        attn_bias,
-        in_degree,
-        out_degree,
-        spatial_pos,
-        attn_edge_type,
+        input_nodes: torch.LongTensor,
+        input_edges: torch.LongTensor,
+        attn_bias: torch.Tensor,
+        in_degree: torch.LongTensor,
+        out_degree: torch.LongTensor,
+        spatial_pos: torch.LongTensor,
+        attn_edge_type: torch.LongTensor,
         perturb=None,
         last_state_only: bool = False,
         token_embeddings: Optional[torch.Tensor] = None,
         attn_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.torch.Tensor, torch.Tensor]:
+    ) -> Tuple[Union[torch.Tensor, List[torch.LongTensor]], torch.Tensor]:
         # compute padding mask. This is needed for multi-head attention
         data_x = input_nodes
         n_graph, n_node = data_x.size()[:2]
@@ -676,14 +692,14 @@
 class GraphormerDecoderHead(nn.Module):
-    def __init__(self, embedding_dim, num_classes):
+    def __init__(self, embedding_dim: int, num_classes: int):
         super().__init__()
         """num_classes should be 1 for regression, or the number of classes for classification"""
         self.lm_output_learned_bias = nn.Parameter(torch.zeros(1))
         self.classifier = nn.Linear(embedding_dim, num_classes, bias=False)
         self.num_classes = num_classes
 
-    def forward(self, input_nodes, **unused):
+    def forward(self, input_nodes: torch.Tensor, **unused) -> torch.Tensor:
         input_nodes = self.classifier(input_nodes)
         input_nodes = input_nodes + self.lm_output_learned_bias
         return input_nodes
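For illustration, the head can be called exactly as typed above (batch size and embedding dimension are assumed values):

    import torch

    head = GraphormerDecoderHead(embedding_dim=768, num_classes=2)
    graph_rep = torch.randn(16, 768)   # pooled representation per graph (assumed shape)
    logits = head(graph_rep)           # torch.Tensor of shape (16, 2)

Per the docstring above, num_classes=1 would make the same head perform regression instead.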
@@ -702,12 +718,12 @@ class GraphormerPreTrainedModel(PreTrainedModel):
     main_input_name_nodes = "input_nodes"
     main_input_name_edges = "input_edges"
 
-    def normal_(self, data):
+    def normal_(self, data: torch.Tensor):
         # with FSDP, module params will be on CUDA, so we cast them back to CPU
         # so that the RNG is consistent with and without FSDP
         data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
 
-    def init_graphormer_params(self, module):
+    def init_graphormer_params(self, module: Union[nn.Linear, nn.Embedding, GraphormerMultiheadAttention]):
         """
         Initialize the weights specific to the Graphormer Model.
         """
@@ -724,7 +740,12 @@
             self.normal_(module.k_proj.weight.data)
             self.normal_(module.v_proj.weight.data)
 
-    def _init_weights(self, module):
+    def _init_weights(
+        self,
+        module: Union[
+            nn.Linear, nn.Conv2d, nn.Embedding, nn.LayerNorm, GraphormerMultiheadAttention, GraphormerGraphEncoder
+        ],
+    ):
         """
         Initialize the weights
         """
@@ -766,7 +787,7 @@ class GraphormerModel(GraphormerPreTrainedModel):
     this model with a downstream model of your choice, following the example in GraphormerForGraphClassification.
     """
 
-    def __init__(self, config):
+    def __init__(self, config: GraphormerConfig):
         super().__init__(config)
         self.max_nodes = config.max_nodes
@@ -789,18 +810,18 @@
     def forward(
         self,
-        input_nodes,
-        input_edges,
-        attn_bias,
-        in_degree,
-        out_degree,
-        spatial_pos,
-        attn_edge_type,
+        input_nodes: torch.LongTensor,
+        input_edges: torch.LongTensor,
+        attn_bias: torch.Tensor,
+        in_degree: torch.LongTensor,
+        out_degree: torch.LongTensor,
+        spatial_pos: torch.LongTensor,
+        attn_edge_type: torch.LongTensor,
         perturb=None,
         masked_tokens=None,
         return_dict: Optional[bool] = None,
         **unused,
-    ):
+    ) -> Union[Tuple[torch.LongTensor], BaseModelOutputWithNoAttention]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         inner_states, graph_rep = self.graph_encoder(
@@ -841,7 +862,7 @@ class GraphormerForGraphClassification(GraphormerPreTrainedModel):
     of integer labels for each graph.
     """
 
-    def __init__(self, config):
+    def __init__(self, config: GraphormerConfig):
         super().__init__(config)
         self.encoder = GraphormerModel(config)
         self.embedding_dim = config.embedding_dim
@@ -854,13 +875,13 @@
     def forward(
         self,
-        input_nodes,
-        input_edges,
-        attn_bias,
-        in_degree,
-        out_degree,
-        spatial_pos,
-        attn_edge_type,
+        input_nodes: torch.LongTensor,
+        input_edges: torch.LongTensor,
+        attn_bias: torch.Tensor,
+        in_degree: torch.LongTensor,
+        out_degree: torch.LongTensor,
+        spatial_pos: torch.LongTensor,
+        attn_edge_type: torch.LongTensor,
         labels: Optional[torch.LongTensor] = None,
         return_dict: Optional[bool] = None,
         **unused,
...