Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
11e445c3
Commit
11e445c3
authored
Feb 06, 2026
by
zhanghj2
Browse files
加入版本信息
parent
b1ba831f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
92 additions
and
4 deletions
+92
-4
get_version.py
get_version.py
+69
-0
setup.py
setup.py
+23
-4
No files found.
get_version.py
0 → 100644
View file @
11e445c3
import
os
,
re
import
ast
import
subprocess
from
pathlib
import
Path
import
torch
ROOT_DIR
=
Path
(
__file__
).
parent
.
resolve
()
def
_run_cmd
(
cmd
,
shell
=
False
):
try
:
return
subprocess
.
check_output
(
cmd
,
cwd
=
ROOT_DIR
,
stderr
=
subprocess
.
DEVNULL
,
shell
=
shell
).
decode
(
"ascii"
).
strip
()
except
Exception
:
return
None
def
get_package_version
():
with
open
(
Path
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)))
/
"flash_mla"
/
"__init__.py"
,
"r"
)
as
f
:
version_match
=
re
.
search
(
r
"^__version__\s*=\s*(.*)$"
,
f
.
read
(),
re
.
MULTILINE
)
public_version
=
ast
.
literal_eval
(
version_match
.
group
(
1
))
local_version
=
os
.
environ
.
get
(
"FLASH_ATTN_LOCAL_VERSION"
)
if
local_version
:
return
f
"
{
public_version
}
+
{
local_version
}
"
else
:
return
str
(
public_version
)
def
_make_version_file
(
version
,
sha
,
abi
,
dtk
,
torch_version
,
branch
):
sha
=
"Unknown"
if
sha
is
None
else
sha
torch_version
=
'.'
.
join
(
torch_version
.
split
(
'.'
)[:
2
])
# hcu_version = f"{version}+das1.1git{sha}.abi{abi}.dtk{dtk}.torch{torch_version}"
hcu_version
=
f
"
{
version
}
+das.opt
{
os
.
environ
[
'FLASH_ATTN_OPT'
]
}
.dtk
{
dtk
}
"
version_path
=
ROOT_DIR
/
"flash_mla"
/
"version.py"
with
open
(
version_path
,
"w"
)
as
f
:
f
.
write
(
f
"version = '
{
version
}
'
\n
"
)
f
.
write
(
f
"git_hash = '
{
sha
}
'
\n
"
)
f
.
write
(
f
"git_branch = '
{
branch
}
'
\n
"
)
f
.
write
(
f
"abi = 'abi
{
abi
}
'
\n
"
)
f
.
write
(
f
"dtk = '
{
dtk
}
'
\n
"
)
f
.
write
(
f
"torch_version = '
{
torch_version
}
'
\n
"
)
f
.
write
(
f
"hcu_version = '
{
hcu_version
}
'
\n
"
)
return
hcu_version
def
_get_pytorch_version
():
if
"PYTORCH_VERSION"
in
os
.
environ
:
return
f
"
{
os
.
environ
[
'PYTORCH_VERSION'
]
}
"
return
torch
.
__version__
def
get_version
(
ROCM_HOME
):
sha
=
_run_cmd
([
"git"
,
"rev-parse"
,
"HEAD"
])
if
sha
is
not
None
:
sha
=
sha
[:
7
]
branch
=
_run_cmd
([
"git"
,
"rev-parse"
,
"--abbrev-ref"
,
"HEAD"
])
tag
=
_run_cmd
([
"git"
,
"describe"
,
"--tags"
,
"--exact-match"
,
"@"
])
print
(
"-- Git branch:"
,
branch
)
print
(
"-- Git SHA:"
,
sha
)
print
(
"-- Git tag:"
,
tag
)
torch_version
=
_get_pytorch_version
()
print
(
"-- PyTorch:"
,
torch_version
)
version
=
get_package_version
()
print
(
"-- Building version"
,
version
)
abi
=
_run_cmd
([
"echo '#include <string>' | gcc -x c++ -E -dM - | fgrep _GLIBCXX_USE_CXX11_ABI | awk '{print $3}'"
],
shell
=
True
)
print
(
"-- _GLIBCXX_USE_CXX11_ABI:"
,
abi
)
dtk
=
_run_cmd
([
"cat"
,
os
.
path
.
join
(
ROCM_HOME
,
'.info/rocm_version'
)])
dtk
=
''
.
join
(
dtk
.
replace
(
' '
,
''
).
replace
(
'-'
,
''
).
replace
(
'V'
,
''
).
split
(
'.'
))
print
(
"-- DTK:"
,
dtk
)
return
_make_version_file
(
version
,
sha
,
abi
,
dtk
,
torch_version
,
branch
)
setup.py
View file @
11e445c3
...
@@ -2,9 +2,11 @@ import os
...
@@ -2,9 +2,11 @@ import os
from
pathlib
import
Path
from
pathlib
import
Path
from
datetime
import
datetime
from
datetime
import
datetime
import
subprocess
import
subprocess
from
typing
import
Optional
from
get_version
import
get_version
from
setuptools
import
setup
,
find_packages
from
setuptools
import
setup
,
find_packages
import
torch
from
torch.utils.cpp_extension
import
(
from
torch.utils.cpp_extension
import
(
BuildExtension
,
BuildExtension
,
CUDAExtension
,
CUDAExtension
,
...
@@ -100,11 +102,28 @@ ext_modules.append(
...
@@ -100,11 +102,28 @@ ext_modules.append(
)
)
)
)
def
_find_rocm_home
()
->
Optional
[
str
]:
rocm_home
=
os
.
environ
.
get
(
'ROCM_HOME'
)
or
os
.
environ
.
get
(
'ROCM_PATH'
)
if
rocm_home
is
None
:
try
:
pipe_hipcc
=
subprocess
.
Popen
(
[
"which hipcc | xargs readlink -f"
],
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
shell
=
True
)
hipcc
,
_
=
pipe_hipcc
.
communicate
()
rocm_home
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
hipcc
.
decode
(
*
()).
rstrip
(
'
\r\n
'
)))
if
os
.
path
.
basename
(
rocm_home
)
==
'hip'
:
rocm_home
=
os
.
path
.
dirname
(
rocm_home
)
except
Exception
:
rocm_home
=
'/opt/rocm'
if
not
os
.
path
.
exists
(
rocm_home
):
rocm_home
=
None
if
rocm_home
and
torch
.
version
.
hip
is
None
:
print
(
f
"No ROCm runtime is found, using ROCM_HOME='
{
rocm_home
}
'"
)
return
rocm_home
ROCM_HOME
=
_find_rocm_home
()
setup
(
setup
(
name
=
"flash_mla"
,
name
=
"flash_mla"
,
version
=
"1.0.0"
,
version
=
get_version
(
ROCM_HOME
)
,
packages
=
find_packages
(
include
=
[
'flash_mla'
]),
packages
=
find_packages
(
include
=
[
'flash_mla'
]),
ext_modules
=
ext_modules
,
ext_modules
=
ext_modules
,
cmdclass
=
{
"build_ext"
:
BuildExtension
},
cmdclass
=
{
"build_ext"
:
BuildExtension
},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment