Eugen Hotaj authored
auto_shard.py currently uses torch.fx to build a symbolic DAG of operations, then linearizes that DAG into an nn.Sequential so the model can later be used for offloading. This works in most cases but fails on certain eager-mode features, such as dynamic conditionals and shape-dependent computation. This PR extends auto_shard.py with a preprocessing step that wraps any nn.Module that cannot be traced through. It also adds a test for dynamic conditionals and updates the existing failing test code. Some immediate extensions to this approach are marked as TODO in the code.
7bdf50a3
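
The preprocessing idea in the commit message above can be illustrated with a minimal sketch. This is not the code from auto_shard.py: the names DynamicBranch, find_untraceable, and LeafTracer are hypothetical, only top-level children are inspected, and instead of wrapping the untraceable modules the sketch marks them as leaves for a custom torch.fx.Tracer so that tracing treats them as opaque calls.

```python
import torch
import torch.fx
from torch import nn


class DynamicBranch(nn.Module):
    """Hypothetical module whose forward has a data-dependent conditional."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        # Branching on a tensor value cannot be symbolically traced.
        if x.sum() > 0:
            return self.linear(x)
        return -x


def find_untraceable(model):
    """Return the names of direct children that torch.fx fails to trace."""
    untraceable = set()
    for name, child in model.named_children():
        try:
            torch.fx.symbolic_trace(child)
        except Exception:  # assumption: any tracing failure means "untraceable"
            untraceable.add(name)
    return untraceable


class LeafTracer(torch.fx.Tracer):
    """Tracer that does not trace into the given submodules."""

    def __init__(self, leaf_names):
        super().__init__()
        self.leaf_names = leaf_names

    def is_leaf_module(self, m, module_qualified_name):
        if module_qualified_name in self.leaf_names:
            return True
        return super().is_leaf_module(m, module_qualified_name)


model = nn.Sequential(nn.Linear(4, 4), DynamicBranch(), nn.ReLU())

# Preprocessing step: discover which children must stay opaque.
leaves = find_untraceable(model)

# Trace the full model, recording the opaque children as call_module nodes.
graph = LeafTracer(leaves).trace(model)
traced = torch.fx.GraphModule(model, graph)

x = torch.randn(2, 4)
same_output = torch.allclose(traced(x), model(x))
```

Because the dynamic conditional still runs in eager mode inside the opaque module, the traced graph stays a straight line of module calls, which is the property a linearization into nn.Sequential relies on.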