Commit 3d5f8cb7 authored by wenjh's avatar wenjh
Browse files

Merge branch 'develop_v2.10' into release_v2.10

parents ad5fbfb5 bdf3d931
...@@ -145,7 +145,10 @@ def get_build_ext( ...@@ -145,7 +145,10 @@ def get_build_ext(
# For editable/inplace builds this is not a concern as # For editable/inplace builds this is not a concern as
# the SOs will be in a local directory anyway. # the SOs will be in a local directory anyway.
if not self.inplace: if not self.inplace:
target_dir = install_dir / "transformer_engine" / lib_dir if not bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))):
target_dir = install_dir / "transformer_engine" / lib_dir
else:
target_dir = install_dir / "transformer_engine_fl_hygon" / lib_dir
target_dir.mkdir(exist_ok=True, parents=True) target_dir.mkdir(exist_ok=True, parents=True)
for ext in Path(self.build_lib).glob("*.so"): for ext in Path(self.build_lib).glob("*.so"):
......
...@@ -155,7 +155,7 @@ def setup_pytorch_extension( ...@@ -155,7 +155,7 @@ def setup_pytorch_extension(
if rocm_build(): if rocm_build():
return CUDAExtension( return CUDAExtension(
name="transformer_engine_torch", name="transformer_engine_torch" if not bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))) else "transformer_engine_fl_torch_hygon",
sources=[str(src) for src in sources], sources=[str(src) for src in sources],
include_dirs=[str(inc) for inc in include_dirs], include_dirs=[str(inc) for inc in include_dirs],
extra_compile_args={ extra_compile_args={
......
...@@ -50,6 +50,17 @@ if rocm_build(): ...@@ -50,6 +50,17 @@ if rocm_build():
else: else:
archs = cuda_archs() archs = cuda_archs()
if bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))):
common_dir = current_file_path / "transformer_engine" / "common"
common_copy = current_file_path / "transformer_engine_fl_hygon" / "common"
if common_copy.exists():
shutil.rmtree(common_copy)
shutil.copytree(common_dir, common_copy)
csrc_dir = current_file_path / "transformer_engine" / "pytorch" / "csrc"
csrc_copy = current_file_path / "transformer_engine_fl_hygon" / "pytorch" / "csrc"
if csrc_copy.exists():
shutil.rmtree(csrc_copy)
shutil.copytree(csrc_dir, csrc_copy)
class TimedBdist(bdist_wheel): class TimedBdist(bdist_wheel):
"""Helper class to measure build time""" """Helper class to measure build time"""
...@@ -114,9 +125,13 @@ def setup_common_extension() -> CMakeExtension: ...@@ -114,9 +125,13 @@ def setup_common_extension() -> CMakeExtension:
if os.getenv("NVTE_USE_ROCBLAS") is not None: if os.getenv("NVTE_USE_ROCBLAS") is not None:
cmake_flags.append("-DUSE_ROCBLAS=ON") cmake_flags.append("-DUSE_ROCBLAS=ON")
if not bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))):
cmake_path = root_path / Path("transformer_engine/common")
else:
cmake_path = root_path / Path("transformer_engine_fl_hygon/common")
return CMakeExtension( return CMakeExtension(
name="transformer_engine", name="transformer_engine" if not bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))) else "transformer_engine_fl_hygon",
cmake_path=root_path / Path("transformer_engine/common"), cmake_path=cmake_path,
cmake_flags=cmake_flags, cmake_flags=cmake_flags,
) )
...@@ -239,13 +254,22 @@ if __name__ == "__main__": ...@@ -239,13 +254,22 @@ if __name__ == "__main__":
if "pytorch" in frameworks: if "pytorch" in frameworks:
from build_tools.pytorch import setup_pytorch_extension from build_tools.pytorch import setup_pytorch_extension
ext_modules.append( if not bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))):
setup_pytorch_extension( ext_modules.append(
"transformer_engine/pytorch/csrc", setup_pytorch_extension(
current_file_path / "transformer_engine" / "pytorch" / "csrc", "transformer_engine/pytorch/csrc",
current_file_path / "transformer_engine", current_file_path / "transformer_engine" / "pytorch" / "csrc",
current_file_path / "transformer_engine",
)
)
else:
ext_modules.append(
setup_pytorch_extension(
"transformer_engine_fl_hygon/pytorch/csrc",
current_file_path / "transformer_engine_fl_hygon" / "pytorch" / "csrc",
current_file_path / "transformer_engine_fl_hygon",
)
) )
)
if "jax" in frameworks: if "jax" in frameworks:
from build_tools.jax import setup_jax_extension from build_tools.jax import setup_jax_extension
...@@ -257,27 +281,52 @@ if __name__ == "__main__": ...@@ -257,27 +281,52 @@ if __name__ == "__main__":
) )
) )
# Configure package if not bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))):
setuptools.setup( # Configure package
name="transformer_engine", setuptools.setup(
version=__version__, name="transformer_engine",
packages=setuptools.find_packages( version=__version__,
include=[ packages=setuptools.find_packages(
"transformer_engine", include=[
"transformer_engine.*", "transformer_engine",
"transformer_engine/build_tools", "transformer_engine.*",
], "transformer_engine/build_tools",
), ],
extras_require=extras_require, ),
description="Transformer acceleration library", extras_require=extras_require,
long_description=long_description, description="Transformer acceleration library",
long_description_content_type="text/x-rst", long_description=long_description,
ext_modules=ext_modules, long_description_content_type="text/x-rst",
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, ext_modules=ext_modules,
python_requires=f">={min_python_version_str()}", cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
classifiers=["Programming Language :: Python :: 3"], python_requires=f">={min_python_version_str()}",
install_requires=install_requires, classifiers=["Programming Language :: Python :: 3"],
license_files=("LICENSE",), install_requires=install_requires,
include_package_data=include_package_data, license_files=("LICENSE",),
package_data=package_data, include_package_data=include_package_data,
) package_data=package_data,
)
else:
# Configure package of hygon backend for TransformerEngine-FL
setuptools.setup(
name="transformer_engine_fl_hygon",
version=__version__,
packages=setuptools.find_packages(
include=[
"transformer_engine_fl_hygon",
"transformer_engine_fl_hygon.*",
],
),
extras_require=extras_require,
description="Transformer acceleration library for TransformerEngine-FL",
long_description=long_description,
long_description_content_type="text/x-rst",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
python_requires=f">={min_python_version_str()}",
classifiers=["Programming Language :: Python :: 3"],
install_requires=install_requires,
license_files=("LICENSE",),
include_package_data=include_package_data,
package_data=package_data,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment