    ONNX export refactoring (#197) · 83911ddb
    Neta Zmora authored
    
    
    * ONNX export refactoring
    
    * Remove infer_ort (to enable more testing)
    * Add BF16 ORT tests for Q/DQ ops and GELU.
      * Use FP32 inputs/outputs instead of BF16 (ORT doesn't support BF16 i/o) and add Cast nodes from FP32 to BF16 (only at subgraph inputs and outputs).
      * We'll need to add more BF16 testing.
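The FP32-i/o-plus-Cast workaround above can be emulated in plain numpy. This is not part of the commit; it is a minimal sketch of what an ONNX `Cast` FP32→BF16→FP32 round-trip does numerically (truncate to 8 mantissa bits with round-to-nearest-even), which is why FP32 graph i/o with internal casts is a faithful stand-in when the runtime cannot bind BF16 tensors directly:

```python
import numpy as np

def cast_fp32_to_bf16_roundtrip(x: np.ndarray) -> np.ndarray:
    """Emulate Cast(FP32->BF16) followed by Cast(BF16->FP32).

    BF16 keeps FP32's exponent and the top 7 mantissa bits, so the
    round-trip is: round the low 16 bits to nearest-even, then zero them.
    (Sketch only; does not special-case NaN/Inf.)
    """
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    # Standard round-to-nearest-even trick for bf16 truncation.
    rounded = bits + np.uint32(0x7FFF) + ((bits >> np.uint32(16)) & np.uint32(1))
    return (rounded & np.uint32(0xFFFF0000)).view(np.float32)
```

Because BF16 retains only 8 significant bits, round-trip relative error is bounded by about 2^-8, which is what the loosened BF16 test thresholds have to absorb.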
    * GEMM:
      * Add a Cast after DQ so the MatMul runs at sub-FP32 precision, for better performance.
      * Fold bias into Gemm operation (=> smaller graphs)
      * Wrap GEMM-GELU with FP32 (TE implements GELU in FP32)
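The bias-folding item above relies on a simple identity: a `MatMul` followed by an `Add` of a bias is numerically equivalent to a single `Gemm` node with `C = bias` (ONNX Gemm computes `Y = alpha*A*B + beta*C`), so the exporter can emit one node instead of two. A numpy sketch of the equivalence (illustrative function names, not the exporter's API):

```python
import numpy as np

def matmul_then_add(a, w, bias):
    # Pre-folding graph shape: two ONNX nodes, MatMul followed by Add.
    return a @ w + bias

def gemm_folded(a, w, bias, alpha=1.0, beta=1.0):
    # Post-folding graph shape: one ONNX Gemm node, Y = alpha*A*B + beta*C.
    return alpha * (a @ w) + beta * bias
```

Both produce identical results; the folded form just yields a smaller graph, as the commit notes.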
    * Enable tests for cross attention (test_export_multihead_attention)
    * Reduce test thresholds for test_export_layernorm_mlp, test_export_layernorm_linear, test_export_layernorm
    Signed-off-by: Neta Zmora <nzmora@nvidia.com>
    
    * Loosen MHA export validation thresholds for FP16
    Signed-off-by: Neta Zmora <nzmora@nvidia.com>
    
    ---------
    Signed-off-by: Neta Zmora <nzmora@nvidia.com>