    [`BERT`] Add support for sdpa (#28802) · dfa7b580
    JB (Don) authored
    * Adding SDPA support for BERT
    
    * Using the proper input name for testing model input in inference()
    
    * Adding documentation for SDPA in BERT model page
    
    * Use the stable link for the documentation
    
    * Adding a gate to only call .contiguous() for torch < 2.2.0
    
    * Additions and fixes to the documentation
    
    * Minor updates to documentation
    
    * Adding extra requirements needed for the contiguous() bug
    
    * Adding "Adapted from" in place of the "Copied from"
    
    * Add benchmark speedup tables to the documentation
    
    * Minor fixes to the documentation
    
    * Use ClapText as a replacement for Bert in the "Copied from" references
    
    * Some more fixes for the fix-copies references
    
    * Overriding the test_eager_matches_sdpa_generate in bert tests to not load with low_cpu_mem_usage
    
    [test all]
    
    * Undo changes to separate test
    
    * Refactored SDPA self attention code for KV projections
    
    * Change use_sdpa to attn_implementation
    
    * Fix test_sdpa_can_dispatch_on_flash by preparing input (required for MultipleChoice models)
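Background for the log above: in transformers, "SDPA" refers to PyTorch's fused `torch.nn.functional.scaled_dot_product_attention`. It computes softmax(Q·Kᵀ / √d_k + mask)·V in a single call and can dispatch to flash or memory-efficient kernels. The following is a minimal pure-Python reference of that formula. It is meant only to show what the fused kernel computes. It is not the BERT code from this commit, and the helper names (`matmul`, `transpose`, `softmax`, `scaled_dot_product_attention`) are hypothetical.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # (n x k) @ (k x m) -> (n x m), plain nested loops for clarity
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(m: Matrix) -> Matrix:
    return [list(row) for row in zip(*m)]


def softmax(row: List[float]) -> List[float]:
    # Subtract the row maximum before exponentiating for numerical stability
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(
    q: Matrix, k: Matrix, v: Matrix, mask: Optional[Matrix] = None
) -> Matrix:
    # Hypothetical reference: softmax(Q K^T / sqrt(d_k) + mask) V
    d_k = len(q[0])
    scale = 1.0 / math.sqrt(d_k)
    scores = [[s * scale for s in row] for row in matmul(q, transpose(k))]
    if mask is not None:
        # Additive mask: 0.0 keeps a key position, -inf removes it
        scores = [[s + m for s, m in zip(srow, mrow)] for srow, mrow in zip(scores, mask)]
    weights = [softmax(row) for row in scores]
    return matmul(weights, v)
```

For example, masking out the second key with `[[0.0, float("-inf")]]` makes the single query attend only to the first value row. This additive-mask convention matches how padding masks are typically fed to SDPA, but it is an illustrative choice here.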
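One bullet mentions gating the `.contiguous()` call so that it only runs for torch < 2.2.0. This log does not spell out which tensors are affected or what the extra requirements for that bug are. The sketch below therefore shows only the version-comparison part of such a gate, using stdlib code and hypothetical helper names (`parse_version`, `needs_contiguous_workaround`). Production code would more robustly use `packaging.version.parse`.

```python
from typing import Tuple


def parse_version(version: str) -> Tuple[int, int, int]:
    # Keep only the numeric release part, e.g. "2.1.2+cu121" -> (2, 1, 2).
    # Pre-release suffixes are ignored, so "2.2.0.dev..." compares as 2.2.0.
    release = version.split("+")[0]
    parts = []
    for piece in release.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return (parts[0], parts[1], parts[2])


def needs_contiguous_workaround(torch_version: str) -> bool:
    # Hypothetical gate: apply the .contiguous() workaround only before torch 2.2.0
    return parse_version(torch_version) < (2, 2, 0)
```

In a model's forward pass, a check like this would typically be computed once (for example at import time, from `torch.__version__`) rather than on every call.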
test_modeling_bert.py 29.7 KB