Unverified Commit 0da9683c authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

use legacy pytorch fuser

parent a6ed694a
...@@ -383,19 +383,8 @@ def fp32_to_bf16_sr(t, o): ...@@ -383,19 +383,8 @@ def fp32_to_bf16_sr(t, o):
def set_jit_fusion_options(): def set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options.""" """Set PyTorch JIT layer fusion options."""
# flags required to enable jit fusion kernels # flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split('.')[0]) # legacy pytorch fuser
TORCH_MINOR = int(torch.__version__.split('.')[1]) torch._C._jit_set_profiling_mode(False)
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): torch._C._jit_set_profiling_executor(False)
# nvfuser torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_set_profiling_executor(True) torch._C._jit_override_can_fuse_on_gpu(True)
torch._C._jit_set_profiling_mode(True)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._debug_set_autodiff_subgraph_inlining(False)
else:
# legacy pytorch fuser
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment