"tests/test_zero/test_init_context.py" did not exist on "21dc54e019d4636a5024e8e41e2a69567cac37dc"
Unverified Commit e859380b authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[fx] support module with bias addition (#1780)

* [autoparallel] refactor tracer to fix bias addition issue

* [fx] support module with bias addition

* create bias_addition_module

* refactor file structure

* polish code

* fix unit test
parent f3f19a5c
from colossalai.fx import ColoTracer
import torch
from torch.fx import GraphModule, Tracer
from colossalai.fx import ColoTracer
def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
data = data_gen()
......@@ -24,8 +25,9 @@ def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwa
fx_out = gm(**data)
if isinstance(fx_out, tuple):
for non_fx, fx in zip(non_fx_out, fx_out):
assert torch.allclose(non_fx,
fx), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
assert torch.allclose(
non_fx, fx, atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
else:
assert torch.allclose(
fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
fx_out, non_fx_out,
atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment