Unverified commit bc44e947, authored by Yih-Dar, committed by GitHub
Browse files

Update `Graphormer` and fix its `torchscript` test failures (#21380)



* fix
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 19d67bfe
......@@ -798,9 +798,11 @@ class GraphormerModel(GraphormerPreTrainedModel):
attn_edge_type,
perturb=None,
masked_tokens=None,
return_dict: Optional[bool] = True,
return_dict: Optional[bool] = None,
**unused
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
inner_states, graph_rep = self.graph_encoder(
input_nodes, input_edges, attn_bias, in_degree, out_degree, spatial_pos, attn_edge_type, perturb=perturb
)
......@@ -819,7 +821,7 @@ class GraphormerModel(GraphormerPreTrainedModel):
input_nodes = torch.nn.functional.linear(input_nodes, self.graph_encoder.embed_tokens.weight)
if not return_dict:
return (input_nodes, inner_states)
return tuple(x for x in [input_nodes, inner_states] if x is not None)
return BaseModelOutputWithNoAttention(last_hidden_state=input_nodes, hidden_states=inner_states)
def max_nodes(self):
......@@ -860,9 +862,11 @@ class GraphormerForGraphClassification(GraphormerPreTrainedModel):
spatial_pos,
attn_edge_type,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = True,
return_dict: Optional[bool] = None,
**unused,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_outputs = self.encoder(
input_nodes,
input_edges,
......@@ -871,12 +875,14 @@ class GraphormerForGraphClassification(GraphormerPreTrainedModel):
out_degree,
spatial_pos,
attn_edge_type,
return_dict=True,
)
outputs, hidden_states = encoder_outputs["last_hidden_state"], encoder_outputs["hidden_states"]
head_outputs = self.classifier(outputs)
logits = head_outputs[:, 0, :].contiguous()
loss = None
if labels is not None:
mask = ~torch.isnan(labels)
......@@ -891,5 +897,5 @@ class GraphormerForGraphClassification(GraphormerPreTrainedModel):
loss = loss_fct(logits[mask], labels[mask])
if not return_dict:
return (loss, logits, hidden_states)
return tuple(x for x in [loss, logits, hidden_states] if x is not None)
return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=hidden_states, attentions=None)
......@@ -17,13 +17,15 @@
import copy
import inspect
import os
import tempfile
import unittest
from transformers import GraphormerConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
if is_torch_available():
......@@ -255,6 +257,92 @@ class GraphormerModelTest(ModelTesterMixin, unittest.TestCase):
self.model_tester = GraphormerModelTester(self)
self.config_tester = ConfigTester(self, config_class=GraphormerConfig, has_text_modality=False)
# overwrite from common as `Graphormer` requires more input arguments
def _create_and_check_torchscript(self, config, inputs_dict):
    """Trace each model class with TorchScript and verify a save/load round trip.

    For every class in ``self.all_model_classes`` the model is built from a
    zero-initialised copy of ``config`` with ``torchscript`` enabled, traced
    with the positional inputs listed in ``required_keys``, saved to a
    temporary file and loaded back. The loaded module must expose the same
    state-dict keys as the eager model (extra keys are tolerated only when
    their tensors equal one of the eager model's buffers) and identical
    values for every shared key.

    Does nothing when ``self.test_torchscript`` is falsy.
    """
    if not self.test_torchscript:
        return

    # Zero-initialised config: makes sure no NaN ends up in the weights.
    zero_init_config = _config_zero_init(config)
    zero_init_config.torchscript = True

    # The model is called and traced with these inputs, positionally, in this order.
    required_keys = (
        "input_nodes",
        "input_edges",
        "attn_bias",
        "in_degree",
        "out_degree",
        "spatial_pos",
        "attn_edge_type",
    )

    for model_class in self.all_model_classes:
        model = model_class(config=zero_init_config)
        model.to(torch_device)
        model.eval()
        inputs = self._prepare_for_class(inputs_dict, model_class)

        try:
            required_inputs = tuple(inputs[name] for name in required_keys)
            model(*required_inputs)
            traced_model = torch.jit.trace(model, required_inputs)
        except RuntimeError:
            self.fail("Couldn't trace module.")

        # Round-trip the traced module through a file on disk.
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

            try:
                torch.jit.save(traced_model, pt_file_name)
            except Exception:
                self.fail("Couldn't save module.")

            try:
                loaded_model = torch.jit.load(pt_file_name)
            except Exception:
                self.fail("Couldn't load module.")

        model.to(torch_device)
        model.eval()

        loaded_model.to(torch_device)
        loaded_model.eval()

        reference_state = model.state_dict()
        loaded_state = loaded_model.state_dict()

        # Keys present only in the loaded module are set aside (the original
        # code calls them non-persistent buffers) and checked separately below.
        extra_buffers = {
            name: tensor for name, tensor in loaded_state.items() if name not in reference_state
        }
        loaded_state = {
            name: tensor for name, tensor in loaded_state.items() if name not in extra_buffers
        }

        self.assertEqual(set(reference_state.keys()), set(loaded_state.keys()))

        # Each extra tensor must equal a distinct buffer of the eager model;
        # a matched buffer is removed so it cannot be matched twice.
        remaining_buffers = list(model.buffers())
        for extra_buffer in extra_buffers.values():
            match_index = None
            for index, candidate in enumerate(remaining_buffers):
                if torch.equal(extra_buffer, candidate):
                    match_index = index
                    break

            self.assertTrue(match_index is not None)
            remaining_buffers.pop(match_index)

        # Every shared entry must hold exactly the same values in both models.
        mismatched_layers = [
            layer_name
            for layer_name, eager_tensor in reference_state.items()
            if layer_name in loaded_state and eager_tensor.data.ne(loaded_state[layer_name].data).sum() > 0
        ]
        self.assertTrue(not mismatched_layers)

        # Avoid memory leak. Without this, each call increases RAM usage by ~20MB.
        # (Even with this call, there is still a memory leak of ~0.04MB.)
        self.clear_torch_jit_class_registry()
def test_config(self):
self.config_tester.run_common_tests()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment