Unverified Commit 0963b288 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Fix zombie process when querying TE install path (#121)



* Remove zombie process from querying TE install path
Co-authored-by: default avatarNaman Goyal <naman@fb.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix FA version checking
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix unused import error
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix lint warning
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarNaman Goyal <naman@fb.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d90bf212
......@@ -3,25 +3,23 @@
# See LICENSE for license information.
"""FW agnostic user-end APIs"""
import ctypes
import os
import platform
import subprocess
def get_te_path():
"""Find TE path using pip"""
"""Find Transformer Engine install path using pip"""
import os
te_info = (
os.popen("pip show transformer_engine").read().replace("\n", ":").split(":")
)
return te_info[te_info.index("Location") + 1].strip()
command = ["pip", "show", "transformer_engine"]
result = subprocess.run(command, capture_output=True, check=True, text=True)
result = result.stdout.replace("\n", ":").split(":")
return result[result.index("Location")+1].strip()
def _load_library():
"""Load TE .so"""
import os
import ctypes
import platform
"""Load shared library with Transformer Engine C extensions"""
system = platform.system()
if system == "Linux":
......@@ -31,7 +29,7 @@ def _load_library():
elif system == "Windows":
extension = "dll"
else:
raise "Unsupported operating system " + system + "."
raise RuntimeError(f"Unsupported operating system ({system})")
lib_name = "libtransformer_engine." + extension
dll_path = get_te_path()
dll_path = os.path.join(dll_path, lib_name)
......
......@@ -4,9 +4,9 @@
"""Transformer."""
import os
import re
import math
import warnings
from importlib.metadata import version
from contextlib import nullcontext
from typing import Any, Callable, Optional, Tuple, Union
......@@ -42,7 +42,7 @@ from transformer_engine.pytorch.distributed import (
checkpoint,
)
_flash_attn_version = re.search("Version: (.*)", os.popen("pip show flash_attn").read()).group(1)
_flash_attn_version = version("flash-attn")
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment