Unverified Commit fb3e78ab authored by Bhuminjay Soni's avatar Bhuminjay Soni Committed by GitHub
Browse files

[Feature][CI]: compare `func` & `no_func` outputs in test_functionalization.py (#35481)


Signed-off-by: default avatarBhuminjay <bhuminjaysoni@gmail.com>
Signed-off-by: default avatarBhuminjay Soni <Soni5Happy@gmail.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent fd3bfe74
...@@ -309,12 +309,15 @@ def test_fix_functionalization( ...@@ -309,12 +309,15 @@ def test_fix_functionalization(
model = model_class() model = model_class()
inputs_func = model.example_inputs() inputs_func = model.example_inputs()
inputs_no_func = copy.deepcopy(inputs_func) inputs_no_func = copy.deepcopy(inputs_func)
model_func = model_class() model_func = copy.deepcopy(model)
model_no_func = copy.deepcopy(model_func) model_no_func = copy.deepcopy(model)
model_func = torch.compile(model_func, backend=backend_func) model_func = torch.compile(model_func, backend=backend_func)
model_no_func = torch.compile(model_no_func, backend=backend_no_func) model_no_func = torch.compile(model_no_func, backend=backend_no_func)
model_func(*inputs_func)
model_no_func(*inputs_no_func) # deepcopy inputs to prevent potential in place mutation
outputs_func = model_func(*copy.deepcopy(inputs_func))
outputs_no_func = model_no_func(*copy.deepcopy(inputs_no_func))
torch.testing.assert_close(outputs_func, outputs_no_func)
# check if the functionalization pass is applied # check if the functionalization pass is applied
for op in model.ops_in_model(do_fusion): for op in model.ops_in_model(do_fusion):
...@@ -332,8 +335,3 @@ def test_fix_functionalization( ...@@ -332,8 +335,3 @@ def test_fix_functionalization(
found[op] = True found[op] = True
assert all(found[op] for op in model.ops_in_model(do_fusion)) assert all(found[op] for op in model.ops_in_model(do_fusion))
assert all(not found.get(op) for op in model.ops_not_in_model()) assert all(not found.get(op) for op in model.ops_not_in_model())
# TODO (Rohan138): compare the outputs from model_func and model_no_func
# currently runs into errors while comparing `TestFusedAddRMSNorm`
# Linked issue: https://github.com/vllm-project/vllm/issues/34996
# torch.testing.assert_close(outputs_func, outputs_no_func)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment