Commit 0a5016b1 authored by wenjh's avatar wenjh
Browse files

Merge nv release_v2.9


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parents 063ef88d 70f53666
...@@ -145,15 +145,25 @@ if __name__ == "__main__": ...@@ -145,15 +145,25 @@ if __name__ == "__main__":
) )
] ]
# Setup version and requirements.
# Having the framework extension depend on the core lib allows
# us to detect CUDA version dynamically during compilation and
# choose the correct wheel for te core lib.
__version__ = te_version()
cuda_major_version = parse(torch.version.cuda).major
assert cuda_major_version in (12, 13), f"Unsupported cuda version {torch.version.cuda}."
te_core = f"transformer_engine_cu{cuda_major_version}=={__version__}"
install_requires = install_requirements() + [te_core]
# Configure package # Configure package
setuptools.setup( setuptools.setup(
name=PACKAGE_NAME, name=PACKAGE_NAME,
version=te_version(), version=__version__,
description="Transformer acceleration library - Torch Lib", description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand}, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand},
python_requires=f">={min_python_version_str()}", python_requires=f">={min_python_version_str()}",
install_requires=install_requirements(), install_requires=install_requires,
tests_require=test_requirements(), tests_require=test_requirements(),
) )
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment