    Optimize Token Classification models for TPU (#13096) · eae7a96b
    Ibraheem Moosa authored
    * Optimize Token Classification models for TPU
    
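    According to the XLA documentation, XLA does not handle masked indexing well, because the shape of
    a masked-indexing result depends on the mask values. The token classification models for BERT
    and several other architectures therefore use an implementation based on `torch.where`, which
    works well on TPU.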
    
    The ALBERT token classification model uses masked indexing, which causes performance issues
    on TPU. This PR fixes the issue by following the BERT implementation.
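    The two loss patterns can be contrasted in a minimal sketch. The function names
    `masked_indexing_loss` and `where_loss` are hypothetical and are not part of the transformers
    API. The sketch assumes logits of shape (batch, seq_len, num_labels), integer labels, and a 0/1
    attention mask. Boolean-mask indexing yields tensors whose length depends on how many tokens are
    active. The `torch.where` variant keeps every tensor at a fixed shape and replaces padded labels
    with `CrossEntropyLoss.ignore_index` (-100 by default), so those positions are skipped by the loss.

    ```python
    import torch
    from torch import nn


    def masked_indexing_loss(logits, labels, attention_mask, num_labels):
        # Hypothetical helper: boolean-mask indexing. The number of selected rows
        # depends on the mask contents, i.e. a data-dependent (dynamic) shape.
        loss_fct = nn.CrossEntropyLoss()
        active = attention_mask.view(-1) == 1
        active_logits = logits.view(-1, num_labels)[active]
        active_labels = labels.view(-1)[active]
        return loss_fct(active_logits, active_labels)


    def where_loss(logits, labels, attention_mask, num_labels):
        # Hypothetical helper: torch.where keeps the full (batch * seq_len) shape.
        # Inactive positions get ignore_index, which CrossEntropyLoss skips, so
        # the mean is taken over the same active tokens as above.
        loss_fct = nn.CrossEntropyLoss()
        active = attention_mask.view(-1) == 1
        active_logits = logits.view(-1, num_labels)
        active_labels = torch.where(
            active, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
        )
        return loss_fct(active_logits, active_labels)


    # Toy inputs (assumed shapes, random values) to compare both variants.
    torch.manual_seed(0)
    batch, seq_len, num_labels = 2, 5, 3
    logits = torch.randn(batch, seq_len, num_labels)
    labels = torch.randint(0, num_labels, (batch, seq_len))
    attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]])

    loss_masked = masked_indexing_loss(logits, labels, attention_mask, num_labels)
    loss_where = where_loss(logits, labels, attention_mask, num_labels)
    print(torch.allclose(loss_masked, loss_where))
    ```

    Both variants compute the same mean loss over active tokens. Only the `torch.where` version
    keeps its intermediate tensors at a shape that is known before the mask values are seen.
    
    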
    
    * Same fix for ELECTRA
    
    * Same fix for LayoutLM
modeling_electra.py 60.9 KB