Unverified commit 0963b288, authored by Tim Moon, committed by GitHub
Browse files

Fix zombie process when querying TE install path (#121)



* Remove zombie process from querying TE install path
Co-authored-by: Naman Goyal <naman@fb.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix FA version checking
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix unused import error
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix lint warning
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Naman Goyal <naman@fb.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d90bf212
...@@ -3,25 +3,23 @@ ...@@ -3,25 +3,23 @@
# See LICENSE for license information. # See LICENSE for license information.
"""FW agnostic user-end APIs""" """FW agnostic user-end APIs"""
import ctypes
import os
import platform
import subprocess
def get_te_path(): def get_te_path():
"""Find TE path using pip""" """Find Transformer Engine install path using pip"""
import os command = ["pip", "show", "transformer_engine"]
result = subprocess.run(command, capture_output=True, check=True, text=True)
te_info = ( result = result.stdout.replace("\n", ":").split(":")
os.popen("pip show transformer_engine").read().replace("\n", ":").split(":") return result[result.index("Location")+1].strip()
)
return te_info[te_info.index("Location") + 1].strip()
def _load_library(): def _load_library():
"""Load TE .so""" """Load shared library with Transformer Engine C extensions"""
import os
import ctypes
import platform
system = platform.system() system = platform.system()
if system == "Linux": if system == "Linux":
...@@ -31,7 +29,7 @@ def _load_library(): ...@@ -31,7 +29,7 @@ def _load_library():
elif system == "Windows": elif system == "Windows":
extension = "dll" extension = "dll"
else: else:
raise "Unsupported operating system " + system + "." raise RuntimeError(f"Unsupported operating system ({system})")
lib_name = "libtransformer_engine." + extension lib_name = "libtransformer_engine." + extension
dll_path = get_te_path() dll_path = get_te_path()
dll_path = os.path.join(dll_path, lib_name) dll_path = os.path.join(dll_path, lib_name)
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
"""Transformer.""" """Transformer."""
import os import os
import re
import math import math
import warnings import warnings
from importlib.metadata import version
from contextlib import nullcontext from contextlib import nullcontext
from typing import Any, Callable, Optional, Tuple, Union from typing import Any, Callable, Optional, Tuple, Union
...@@ -42,7 +42,7 @@ from transformer_engine.pytorch.distributed import ( ...@@ -42,7 +42,7 @@ from transformer_engine.pytorch.distributed import (
checkpoint, checkpoint,
) )
_flash_attn_version = re.search("Version: (.*)", os.popen("pip show flash_attn").read()).group(1) _flash_attn_version = version("flash-attn")
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment