Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
2b1428ff
Commit
2b1428ff
authored
Jun 18, 2025
by
yuguo
Browse files
[DCU] fix 2.5 compile issues
parent
b4a2489f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
11 additions
and
28 deletions
+11
-28
build_tools/pytorch.py
build_tools/pytorch.py
+4
-4
pyproject.toml
pyproject.toml
+0
-10
transformer_engine/common/gemm/cublaslt_gemm.cu
transformer_engine/common/gemm/cublaslt_gemm.cu
+1
-0
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+3
-3
transformer_engine/common/include/transformer_engine/gemm.h
transformer_engine/common/include/transformer_engine/gemm.h
+1
-0
transformer_engine/common/util/multi_stream.cpp
transformer_engine/common/util/multi_stream.cpp
+2
-1
transformer_engine/pytorch/pyproject.toml
transformer_engine/pytorch/pyproject.toml
+0
-10
No files found.
build_tools/pytorch.py
View file @
2b1428ff
...
@@ -15,10 +15,10 @@ from typing import List
...
@@ -15,10 +15,10 @@ from typing import List
def
install_requirements
()
->
List
[
str
]:
def
install_requirements
()
->
List
[
str
]:
"""Install dependencies for TE/JAX extensions."""
"""Install dependencies for TE/JAX extensions."""
reqs
=
[
"torch>=2.1"
,
"einops"
]
reqs
=
[
"torch>=2.1"
,
"einops"
]
reqs
.
append
(
#
reqs.append(
"nvdlfw-inspect @"
#
"nvdlfw-inspect @"
" git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
#
" git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
)
#
)
return
reqs
return
reqs
...
...
pyproject.toml
deleted
100755 → 0
View file @
b4a2489f
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires
=
[
"setuptools>=61.0"
,
"cmake>=3.21"
,
"wheel"
,
"pybind11[global]"
,
"ninja"
,
"pip"
,
"torch>=2.1"
,
"jax[cuda12]"
,
"flax>=0.7.1"
]
# Use legacy backend to import local packages in setup.py
build-backend
=
"setuptools.build_meta:__legacy__"
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
2b1428ff
...
@@ -692,6 +692,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
...
@@ -692,6 +692,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
"Cuda version >=12.2 and <13.0 is required for atomic gemm."
);
"Cuda version >=12.2 and <13.0 is required for atomic gemm."
);
NVTE_CHECK
(
cublasLtGetVersion
()
>=
120205
&&
cublasLtGetVersion
()
<
130000
,
NVTE_CHECK
(
cublasLtGetVersion
()
>=
120205
&&
cublasLtGetVersion
()
<
130000
,
"Cublas version >=12.2.5 and <13.0 is required for atomic gemm."
);
"Cublas version >=12.2.5 and <13.0 is required for atomic gemm."
);
#endif
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
const
Tensor
*
inputA
=
convertNVTETensorCheck
(
A
);
const
Tensor
*
inputA
=
convertNVTETensorCheck
(
A
);
...
...
transformer_engine/common/gemm/rocm_gemm.cu
View file @
2b1428ff
...
@@ -1003,7 +1003,7 @@ static inline int getIntEnv(const char *name, int defval, int minval)
...
@@ -1003,7 +1003,7 @@ static inline int getIntEnv(const char *name, int defval, int minval)
*/
*/
static
void
init_hipblaslt_handles
(
hipblasLtHandle_t
*
hipblaslt_handles
)
{
static
void
init_hipblaslt_handles
(
hipblasLtHandle_t
*
hipblaslt_handles
)
{
NVTE_CHECK
(
hipblaslt_handles
!=
nullptr
);
NVTE_CHECK
(
hipblaslt_handles
!=
nullptr
);
for
(
int
i
=
0
;
i
<
num_streams
;
i
++
)
{
for
(
int
i
=
0
;
i
<
compute_
num_streams
;
i
++
)
{
NVTE_CHECK_HIPBLASLT
(
hipblasLtCreate
(
&
hipblaslt_handles
[
i
]));
NVTE_CHECK_HIPBLASLT
(
hipblasLtCreate
(
&
hipblaslt_handles
[
i
]));
}
}
}
}
...
@@ -1842,13 +1842,13 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
...
@@ -1842,13 +1842,13 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
if
(
use_hipblaslt
||
!
use_rocblas
)
if
(
use_hipblaslt
||
!
use_rocblas
)
{
{
// Check compute_stream_offset valid.
// Check compute_stream_offset valid.
NVTE_CHECK
(
compute_stream_offset
>=
-
1
&&
compute_stream_offset
<
num_streams
);
NVTE_CHECK
(
compute_stream_offset
>=
-
1
&&
compute_stream_offset
<
compute_
num_streams
);
hipblasLtHandle_t
handle
=
nullptr
;
hipblasLtHandle_t
handle
=
nullptr
;
if
(
compute_stream_offset
!=
-
1
)
{
if
(
compute_stream_offset
!=
-
1
)
{
// Init hipblaslt handles (once, globally)
// Init hipblaslt handles (once, globally)
static
std
::
once_flag
init_flag
;
static
std
::
once_flag
init_flag
;
static
hipblasLtHandle_t
hipblaslt_handles
[
num_streams
];
static
hipblasLtHandle_t
hipblaslt_handles
[
compute_
num_streams
];
std
::
call_once
(
init_flag
,
init_hipblaslt_handles
,
hipblaslt_handles
);
std
::
call_once
(
init_flag
,
init_hipblaslt_handles
,
hipblaslt_handles
);
handle
=
hipblaslt_handles
[
compute_stream_offset
];
handle
=
hipblaslt_handles
[
compute_stream_offset
];
...
...
transformer_engine/common/include/transformer_engine/gemm.h
View file @
2b1428ff
...
@@ -132,6 +132,7 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
...
@@ -132,6 +132,7 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
*/
*/
namespace
transformer_engine
{
namespace
transformer_engine
{
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
constexpr
int
compute_num_streams
=
2
;
// Add for batchgemm stream
// Add for batchgemm stream
constexpr
int
num_batchgemm_streams
=
1
;
constexpr
int
num_batchgemm_streams
=
1
;
#endif
#endif
...
...
transformer_engine/common/util/multi_stream.cpp
View file @
2b1428ff
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include "multi_stream.h"
#include "multi_stream.h"
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/gemm.h>
#include <mutex>
#include <mutex>
#include <vector>
#include <vector>
...
@@ -51,7 +52,7 @@ cudaEvent_t get_compute_stream_event(int idx) {
...
@@ -51,7 +52,7 @@ cudaEvent_t get_compute_stream_event(int idx) {
int
get_num_compute_streams
()
{
int
get_num_compute_streams
()
{
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
static
constexpr
int
num_compute_streams
=
2
;
static
constexpr
int
num_compute_streams
=
compute_num_streams
;
#else
#else
static
constexpr
int
num_compute_streams
=
4
;
static
constexpr
int
num_compute_streams
=
4
;
#endif
#endif
...
...
transformer_engine/pytorch/pyproject.toml
deleted
100755 → 0
View file @
b4a2489f
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires
=
[
"setuptools>=61.0"
,
"pip"
,
"torch>=2.1"
]
# Use legacy backend to import local packages in setup.py
build-backend
=
"setuptools.build_meta:__legacy__"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment