Commit d6bf0387, authored by Songqing Zhang, committed by GitHub

[examples] fix Graphormer as key in state_dict has changed (#6806)


Co-authored-by: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
parent 898af658
@@ -24,7 +24,7 @@ How to run
 ----------
 ```bash
-accelerate launch --multi_gpu --mixed_precision=fp16 train.py
+accelerate launch --multi_gpu --mixed_precision=fp16 main.py
 ```
 > **_NOTE:_** The script will automatically download weights pre-trained on PCQM4Mv2. To reproduce the same result, set the total batch size to 64.
......
@@ -47,7 +47,7 @@ class Graphormer(nn.Module):
         self.spatial_encoder = SpatialEncoder(
             max_dist=num_spatial, num_heads=num_attention_heads
         )
-        self.graph_token_virtual_dist = nn.Embedding(1, num_attention_heads)
+        self.graph_token_virtual_distance = nn.Embedding(1, num_attention_heads)
         self.emb_layer_norm = nn.LayerNorm(self.embedding_dim)
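
PyTorch derives `state_dict()` keys from attribute names, so the rename in the hunk above changes this parameter's key from `graph_token_virtual_dist.weight` to `graph_token_virtual_distance.weight`. For a checkpoint that was saved under the old name, one possible workaround is to remap its keys before loading. The following is only a sketch of that idea: `rename_state_dict_prefix` is a hypothetical helper, not part of this commit or of DGL, and placeholder strings stand in for tensors so the snippet runs without PyTorch.

```python
from collections import OrderedDict


def rename_state_dict_prefix(state_dict, old_prefix, new_prefix):
    """Return a copy of state_dict with old_prefix replaced by new_prefix.

    Hypothetical helper for illustration. Only keys equal to old_prefix or
    starting with old_prefix + "." are renamed; key order is preserved.
    """
    renamed = OrderedDict()
    for key, value in state_dict.items():
        if key == old_prefix or key.startswith(old_prefix + "."):
            key = new_prefix + key[len(old_prefix):]
        renamed[key] = value
    return renamed


# Placeholder values stand in for tensors (assumed toy checkpoint).
old_checkpoint = OrderedDict(
    [
        ("graph_token_virtual_dist.weight", "virtual-distance-tensor"),
        ("emb_layer_norm.weight", "layer-norm-tensor"),
    ]
)

new_checkpoint = rename_state_dict_prefix(
    old_checkpoint,
    "graph_token_virtual_dist",
    "graph_token_virtual_distance",
)
```

With real tensors, the remapped dictionary could then be passed to `model.load_state_dict(...)`. Matching on the prefix followed by a dot leaves already-renamed keys untouched, so applying the helper twice is harmless.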
@@ -112,7 +112,9 @@ class Graphormer(nn.Module):
         attn_bias[:, 1:, 1:, :] = path_encoding + spatial_encoding
         # spatial encoding of the virtual node
-        t = self.graph_token_virtual_dist.weight.reshape(1, 1, self.num_heads)
+        t = self.graph_token_virtual_distance.weight.reshape(
+            1, 1, self.num_heads
+        )
         # Since the virtual node comes first, the spatial encodings between it
         # and other nodes will fill the 1st row and 1st column (omit num_graphs
         # and num_heads dimensions) of attn_bias matrix by broadcasting.
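
The lines that apply `t` are elided from this diff, so the following is only a sketch of the broadcasting that the comment describes, not the file's actual code. It assumes `attn_bias` has shape `(num_graphs, num_nodes + 1, num_nodes + 1, num_heads)` with the virtual node at index 0, and it uses toy sizes and a standalone embedding instead of the real model.

```python
import torch
import torch.nn as nn

num_graphs, num_nodes, num_heads = 2, 3, 4  # assumed toy sizes

# One learnable bias per attention head for the virtual node.
graph_token_virtual_distance = nn.Embedding(1, num_heads)
t = graph_token_virtual_distance.weight.reshape(1, 1, num_heads)

# Index 0 along both node axes is the virtual node (assumed layout).
attn_bias = torch.zeros(num_graphs, num_nodes + 1, num_nodes + 1, num_heads)

# no_grad keeps this demo free of autograd bookkeeping; training code
# would leave gradients enabled so the embedding can be learned.
with torch.no_grad():
    # First column, real-node rows only, so the corner is not added twice.
    attn_bias[:, 1:, 0, :] += t
    # First row, including the corner entry.
    attn_bias[:, 0, :, :] += t
```

Because `t` has shape `(1, 1, num_heads)`, it broadcasts over the graph axis and the remaining node axis, so every entry in the first row and first column receives the same per-head value while the block for real node pairs is left unchanged.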
......