Extend auto shard capabilities to work around torch.fx edge cases. (#817)
auto_shard.py currently uses torch.fx to create a symbolic DAG of operations and linearizes that DAG into an nn.Sequential so it can later be used for model offloading. This works in most cases but runs into issues for certain eager mode features, such as dynamic conditionals, shape-dependent computation, etc. This PR extends auto_shard.py to first run a preprocessing step which wraps any nn.Module which cannot be traced through. It adds a test for dynamic conditionals and updates existing failing test code. There are some immediate extensions to this approach which are marked as TODO in the code.
Showing
Please register or sign in to comment