Commit b9856c3a authored by zhangyh15's avatar zhangyh15
Browse files

打包时依赖可以带上torch版本

parent f01246b6
......@@ -138,6 +138,10 @@ def _find_rocm_home() -> Optional[str]:
return rocm_home
ROCM_HOME = _find_rocm_home()
pytorch_dep = 'torch'
if os.getenv('PYTORCH_VERSION'):
pytorch_dep += "==" + os.getenv('PYTORCH_VERSION')
setup(
name="flash_mla",
version=get_version(ROCM_HOME),
......@@ -145,4 +149,6 @@ setup(
ext_modules=ext_modules,
package_data={"flash_mla":["asm/*.co"]},
cmdclass={"build_ext": BuildExtension},
install_requires=[pytorch_dep],
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment