    Extend auto shard capabilities to work around torch.fx edge cases. (#817) · 7bdf50a3
    Eugen Hotaj authored
    auto_shard.py currently uses torch.fx to create a symbolic DAG of
    operations and linearizes that DAG into an nn.Sequential so it can later
    be used for model offloading. This works in most cases but runs into
    issues with certain eager-mode features, such as dynamic conditionals
    and shape-dependent computation.
    
    This PR extends auto_shard.py to first run a preprocessing step that wraps
    any nn.Module which cannot be traced through. It adds a test for dynamic
    conditionals and updates existing test code that was failing.
    
    There are some immediate extensions to this approach which are marked as
    TODO in the code.
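    The failure mode and workaround described above can be sketched with a
    small example. This is a minimal illustration, not the PR's actual
    implementation: DynamicBranch, Model, and LeafTracer are hypothetical
    names, and the "wrapping" is shown here via torch.fx's standard
    is_leaf_module hook, which keeps an untraceable submodule as a single
    opaque call_module node instead of tracing through it.

```python
import torch
import torch.fx
import torch.nn as nn


class DynamicBranch(nn.Module):
    # Hypothetical module: its control flow depends on runtime tensor
    # values, which torch.fx cannot trace symbolically.
    def forward(self, x):
        if x.sum() > 0:  # data-dependent conditional
            return x * 2
        return x


class Model(nn.Module):
    # Hypothetical model mixing a traceable layer with the dynamic one.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.branch = DynamicBranch()

    def forward(self, x):
        return self.branch(self.linear(x))


model = Model()

# Tracing the whole model fails on the data-dependent conditional.
try:
    torch.fx.symbolic_trace(model)
    traced_ok = True
except torch.fx.proxy.TraceError:
    traced_ok = False

# Marking the untraceable submodule as a leaf lets tracing succeed:
# the submodule is recorded as one opaque call_module node, and its
# eager forward() runs unchanged when the GraphModule is called.
class LeafTracer(torch.fx.Tracer):
    def is_leaf_module(self, m, qualname):
        return isinstance(m, DynamicBranch) or super().is_leaf_module(m, qualname)


graph = LeafTracer().trace(model)
gm = torch.fx.GraphModule(model, graph)
```

    The resulting gm has a linear graph of nodes that can be partitioned
    into an nn.Sequential, while the wrapped module's dynamic behavior is
    preserved at runtime.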