Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ox696c
ktransformers
Commits
d07683f3
Commit
d07683f3
authored
Apr 10, 2025
by
zhanggezhong
Browse files
Update setup.py
parent
f8f6f259
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
1 deletion
+34
-1
setup.py
setup.py
+34
-1
No files found.
setup.py
View file @
d07683f3
...
@@ -23,7 +23,9 @@ import shutil
...
@@ -23,7 +23,9 @@ import shutil
import
http.client
import
http.client
import
urllib.request
import
urllib.request
import
urllib.error
import
urllib.error
import
importlib
from
pathlib
import
Path
from
pathlib
import
Path
from
packaging
import
version
from
packaging.version
import
parse
from
packaging.version
import
parse
import
torch.version
import
torch.version
from
wheel.bdist_wheel
import
bdist_wheel
as
_bdist_wheel
from
wheel.bdist_wheel
import
bdist_wheel
as
_bdist_wheel
...
@@ -328,7 +330,38 @@ class CMakeBuild(BuildExtension):
...
@@ -328,7 +330,38 @@ class CMakeBuild(BuildExtension):
[
"cmake"
,
"--build"
,
"."
,
"--verbose"
,
*
build_args
],
cwd
=
build_temp
,
check
=
True
[
"cmake"
,
"--build"
,
"."
,
"--verbose"
,
*
build_args
],
cwd
=
build_temp
,
check
=
True
)
)
USE_FASTPT_CUDA
=
os
.
getenv
(
'USE_FASTPT_CUDA'
,
'False'
).
lower
()
==
'true'
def
check_fastpt_version
():
try
:
# Try to import the fastpt module
fastpt
=
importlib
.
import_module
(
'fastpt'
)
# Get version number
fastpt_version
=
getattr
(
fastpt
,
'__version__'
,
None
)
if
fastpt_version
is
None
:
raise
ImportError
(
"fastpt module doesn't have __version__ attribute, cannot determine version"
)
print
(
f
"Detected fastpt installation, version:
{
fastpt_version
}
"
)
# Compare version numbers
if
version
.
parse
(
fastpt_version
)
>=
version
.
parse
(
'2.0.2'
):
print
(
"fastpt version ≥ 2.0.2"
)
return
True
else
:
print
(
f
"fastpt version
{
fastpt_version
}
< 2.0.2"
)
return
False
except
ImportError
as
e
:
print
(
f
"Error: fastpt not installed or import failed -
{
str
(
e
)
}
"
)
raise
try
:
if
check_fastpt_version
():
USE_FASTPT_CUDA
=
os
.
getenv
(
'USE_FASTPT_CUDA'
,
'0'
)
==
'1'
else
:
USE_FASTPT_CUDA
=
os
.
getenv
(
'USE_FASTPT_CUDA'
,
'False'
).
lower
()
==
'true'
except
Exception
as
e
:
print
(
f
"Program terminated:
{
str
(
e
)
}
"
)
if
CUDA_HOME
is
not
None
:
if
CUDA_HOME
is
not
None
:
extra_nvcc_flags
=
[
extra_nvcc_flags
=
[
'-O3'
,
'-O3'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment