Unverified Commit 71f56c77 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Model tests xformers fixes (#5679)

* fix model xformers test

* update
parent 6a89a6c9
......@@ -293,7 +293,16 @@ class ModelTesterMixin:
with torch.no_grad():
output_2 = model(**inputs_dict)[0]
model.set_attn_processor(XFormersAttnProcessor())
assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
output_3 = model(**inputs_dict)[0]
torch.use_deterministic_algorithms(True)
assert torch.allclose(output, output_2, atol=self.base_precision)
assert torch.allclose(output, output_3, atol=self.base_precision)
assert torch.allclose(output_2, output_3, atol=self.base_precision)
@require_torch_gpu
def test_set_attn_processor_for_determinism(self):
......@@ -315,11 +324,6 @@ class ModelTesterMixin:
with torch.no_grad():
output_2 = model(**inputs_dict)[0]
model.enable_xformers_memory_efficient_attention()
assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
model(**inputs_dict)[0]
model.set_attn_processor(AttnProcessor2_0())
assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
with torch.no_grad():
......@@ -330,18 +334,12 @@ class ModelTesterMixin:
with torch.no_grad():
output_5 = model(**inputs_dict)[0]
model.set_attn_processor(XFormersAttnProcessor())
assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
output_6 = model(**inputs_dict)[0]
torch.use_deterministic_algorithms(True)
# make sure that outputs match
assert torch.allclose(output_2, output_1, atol=self.base_precision)
assert torch.allclose(output_2, output_4, atol=self.base_precision)
assert torch.allclose(output_2, output_5, atol=self.base_precision)
assert torch.allclose(output_2, output_6, atol=self.base_precision)
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment