Commit 08be824c authored by wenjh's avatar wenjh
Browse files

Rename package for tefl


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parent bdf3d931
...@@ -148,7 +148,7 @@ def get_build_ext( ...@@ -148,7 +148,7 @@ def get_build_ext(
if not bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))): if not bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))):
target_dir = install_dir / "transformer_engine" / lib_dir target_dir = install_dir / "transformer_engine" / lib_dir
else: else:
target_dir = install_dir / "transformer_engine_fl_hygon" / lib_dir target_dir = install_dir / "transformer_engine_hygon" / lib_dir
target_dir.mkdir(exist_ok=True, parents=True) target_dir.mkdir(exist_ok=True, parents=True)
for ext in Path(self.build_lib).glob("*.so"): for ext in Path(self.build_lib).glob("*.so"):
......
...@@ -155,7 +155,7 @@ def setup_pytorch_extension( ...@@ -155,7 +155,7 @@ def setup_pytorch_extension(
if rocm_build(): if rocm_build():
return CUDAExtension( return CUDAExtension(
name="transformer_engine_torch" if not bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))) else "transformer_engine_fl_torch_hygon", name="transformer_engine_torch" if not bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))) else "transformer_engine_torch_hygon",
sources=[str(src) for src in sources], sources=[str(src) for src in sources],
include_dirs=[str(inc) for inc in include_dirs], include_dirs=[str(inc) for inc in include_dirs],
extra_compile_args={ extra_compile_args={
......
...@@ -52,12 +52,12 @@ else: ...@@ -52,12 +52,12 @@ else:
if bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))): if bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))):
common_dir = current_file_path / "transformer_engine" / "common" common_dir = current_file_path / "transformer_engine" / "common"
common_copy = current_file_path / "transformer_engine_fl_hygon" / "common" common_copy = current_file_path / "transformer_engine_hygon" / "common"
if common_copy.exists(): if common_copy.exists():
shutil.rmtree(common_copy) shutil.rmtree(common_copy)
shutil.copytree(common_dir, common_copy) shutil.copytree(common_dir, common_copy)
csrc_dir = current_file_path / "transformer_engine" / "pytorch" / "csrc" csrc_dir = current_file_path / "transformer_engine" / "pytorch" / "csrc"
csrc_copy = current_file_path / "transformer_engine_fl_hygon" / "pytorch" / "csrc" csrc_copy = current_file_path / "transformer_engine_hygon" / "pytorch" / "csrc"
if csrc_copy.exists(): if csrc_copy.exists():
shutil.rmtree(csrc_copy) shutil.rmtree(csrc_copy)
shutil.copytree(csrc_dir, csrc_copy) shutil.copytree(csrc_dir, csrc_copy)
...@@ -128,9 +128,9 @@ def setup_common_extension() -> CMakeExtension: ...@@ -128,9 +128,9 @@ def setup_common_extension() -> CMakeExtension:
if not bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))): if not bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))):
cmake_path = root_path / Path("transformer_engine/common") cmake_path = root_path / Path("transformer_engine/common")
else: else:
cmake_path = root_path / Path("transformer_engine_fl_hygon/common") cmake_path = root_path / Path("transformer_engine_hygon/common")
return CMakeExtension( return CMakeExtension(
name="transformer_engine" if not bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))) else "transformer_engine_fl_hygon", name="transformer_engine" if not bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))) else "transformer_engine_hygon",
cmake_path=cmake_path, cmake_path=cmake_path,
cmake_flags=cmake_flags, cmake_flags=cmake_flags,
) )
...@@ -265,9 +265,9 @@ if __name__ == "__main__": ...@@ -265,9 +265,9 @@ if __name__ == "__main__":
else: else:
ext_modules.append( ext_modules.append(
setup_pytorch_extension( setup_pytorch_extension(
"transformer_engine_fl_hygon/pytorch/csrc", "transformer_engine_hygon/pytorch/csrc",
current_file_path / "transformer_engine_fl_hygon" / "pytorch" / "csrc", current_file_path / "transformer_engine_hygon" / "pytorch" / "csrc",
current_file_path / "transformer_engine_fl_hygon", current_file_path / "transformer_engine_hygon",
) )
) )
if "jax" in frameworks: if "jax" in frameworks:
...@@ -309,12 +309,12 @@ if __name__ == "__main__": ...@@ -309,12 +309,12 @@ if __name__ == "__main__":
else: else:
# Configure package of hygon backend for TransformerEngine-FL # Configure package of hygon backend for TransformerEngine-FL
setuptools.setup( setuptools.setup(
name="transformer_engine_fl_hygon", name="transformer_engine_hygon",
version=__version__, version=__version__,
packages=setuptools.find_packages( packages=setuptools.find_packages(
include=[ include=[
"transformer_engine_fl_hygon", "transformer_engine_hygon",
"transformer_engine_fl_hygon.*", "transformer_engine_hygon.*",
], ],
), ),
extras_require=extras_require, extras_require=extras_require,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment