OpenDAS / TransformerEngine · Commits

Commit 5b6190b2
Authored Aug 27, 2025 by yuguo
Parent: 87e3e56e

[DCU] fix compile
Showing 6 changed files with 86 additions and 23 deletions:

  transformer_engine/common/__init__.py                         +14 −13
  transformer_engine/common/gemm/cublaslt_gemm.cu               +68 −1
  transformer_engine/common/include/transformer_engine/gemm.h   +1 −1
  transformer_engine/common/multi_tensor/compute_scale.cu       +1 −2
  transformer_engine/common/util/cast_gated_kernels.cuh         +1 −0
  transformer_engine/common/util/cast_kernels.cuh               +1 −6
transformer_engine/common/__init__.py
@@ -17,7 +17,7 @@ import subprocess
 import sys
 import sysconfig
 from typing import Optional
+from torch.utils.cpp_extension import IS_HIP_EXTENSION

 _logger = logging.getLogger(__name__)
@@ -245,18 +245,19 @@ def _load_cudnn():
     found, handle = _load_nvidia_cuda_library("cudnn")
     if found:
         return handle

     # Attempt to locate libcudnn via ldconfig
-    libs = subprocess.check_output(
-        f"ldconfig -p | grep 'libcudnn{_get_sys_extension()}'", shell=True
-    )
-    libs = libs.decode("utf-8").split("\n")
-    sos = []
-    for lib in libs:
-        if "libcudnn" in lib and "=>" in lib:
-            sos.append(lib.split(">")[1].strip())
-    if sos:
-        return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
+    if not IS_HIP_EXTENSION:
+        libs = subprocess.check_output(
+            f"ldconfig -p | grep 'libcudnn{_get_sys_extension()}'", shell=True
+        )
+        libs = libs.decode("utf-8").split("\n")
+        sos = []
+        for lib in libs:
+            if "libcudnn" in lib and "=>" in lib:
+                sos.append(lib.split(">")[1].strip())
+        if sos:
+            return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)

     # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
     return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
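The guard matters because the ldconfig probe only applies to NVIDIA's CUDA libraries; under a ROCm build (IS_HIP_EXTENSION) the loader now falls through directly to the plain CDLL lookup. A minimal C++ sketch of what that final ctypes.CDLL(..., mode=ctypes.RTLD_GLOBAL) fallback amounts to at the dlopen level (the soname "libcudnn.so.9" is an assumption for illustration; build with -ldl):

#include <dlfcn.h>
#include <cstdio>

int main() {
  // RTLD_GLOBAL mirrors ctypes.RTLD_GLOBAL: the library's symbols become
  // visible to shared objects loaded afterwards (e.g. TE's own .so).
  void *handle = dlopen("libcudnn.so.9", RTLD_NOW | RTLD_GLOBAL);
  if (!handle) {
    // The loader searched RPATH, LD_LIBRARY_PATH, and the ldconfig cache.
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  std::puts("libcudnn loaded");
  dlclose(handle);
  return 0;
}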
transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -757,7 +757,7 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
 void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
                              const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                              bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
-                             bool use_split_accumulator, int math_sm_count, cudaStream_t stream) {
+                             bool use_split_accumulator, int math_sm_count, cudaStream_t stream,
+                             bool nvte_use_hipblaslt, bool nvte_use_rocblas,
+                             int compute_stream_offset) {
   NVTE_API_CALL(nvte_cublas_gemm_scaled);
   using namespace transformer_engine;
   const Tensor *inputA = convertNVTETensorCheck(A);
@@ -767,9 +767,76 @@ void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor
   Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
   Tensor *wspace = convertNVTETensor(workspace);

+#ifdef __HIP_PLATFORM_AMD__
+  NVTE_CHECK(alpha == 1.0f, "alpha must be 1.0 for hip");
+  NVTE_CHECK(beta == 1.0f || beta == 0.0f, "beta must be 1.0 or 0.0 for hip");
+  bool accumulate = false;
+  if (alpha == 1.0f and beta == 1.0f) {
+    accumulate = true;
+  }
+  const size_t A0 = inputA->flat_first_dim();
+  const size_t A1 = inputA->flat_last_dim();
+  const size_t B0 = inputB->flat_first_dim();
+  const size_t B1 = inputB->flat_last_dim();
+  const int m = transa ? A0 : A1;
+  const int k = transa ? A1 : A0;
+  const int n = transb ? B1 : B0;
+  int lda, ldb, ldd;
+  if (transa && !transb) {  // TN
+    lda = k;
+    ldb = k;
+    ldd = m;
+  } else if (!transa && !transb) {  // NN
+    lda = m;
+    ldb = k;
+    ldd = m;
+  } else if (!transa && transb) {  // NT
+    lda = m;
+    ldb = n;
+    ldd = m;
+  } else {  // TT
+    NVTE_ERROR("TT layout not allowed.");
+  }
+  const bool use_int8 = is_int8_dtype(inputA->data.dtype) || is_int8_dtype(inputB->data.dtype);
+  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
+  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
+  const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
+  if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' &&
+      use_int8 && use_split_accumulator)
+    nvte_use_hipblaslt = 1;
+  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || (use_fp8) ||
+      (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') ||
+      (nvte_use_hipblaslt) || (nvte_use_rocblas)) {
+    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
+                transa, transb, grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
+                use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream,
+                nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
+  } else {
+    hipblas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
+                 (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
+                 grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
+                 use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
+  }
+#else
   cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu,
               (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
               grad, wspace->data.dptr, wspace->data.shape[0],
               alpha, beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
+#endif
 }

 void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
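The heart of the new HIP branch is deriving m/n/k and the column-major leading dimensions from the flattened tensor shapes before dispatching to cublas_gemm or hipblas_gemm. A self-contained C++ restatement of that selection logic, outside the diff for clarity (gemm_dims and the small main are illustrative, not TE API):

#include <cstdio>
#include <stdexcept>

// Mirrors the layout logic added in the HIP branch above: A is A0 x A1
// (row-major, flattened), B is B0 x B1; the TT layout is rejected as in the diff.
struct Dims { int m, n, k, lda, ldb, ldd; };

Dims gemm_dims(int A0, int A1, int B0, int B1, bool transa, bool transb) {
  Dims d{};
  d.m = transa ? A0 : A1;
  d.k = transa ? A1 : A0;
  d.n = transb ? B1 : B0;
  if (transa && !transb) {         // TN
    d.lda = d.k; d.ldb = d.k; d.ldd = d.m;
  } else if (!transa && !transb) { // NN
    d.lda = d.m; d.ldb = d.k; d.ldd = d.m;
  } else if (!transa && transb) {  // NT
    d.lda = d.m; d.ldb = d.n; d.ldd = d.m;
  } else {                         // TT
    throw std::runtime_error("TT layout not allowed.");
  }
  return d;
}

int main() {
  // Example: A is 128x64, B is 128x64, TN layout -> m=128, k=64, n=128.
  Dims d = gemm_dims(128, 64, 128, 64, /*transa=*/true, /*transb=*/false);
  std::printf("m=%d n=%d k=%d lda=%d ldb=%d ldd=%d\n", d.m, d.n, d.k, d.lda, d.ldb, d.ldd);
}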
transformer_engine/common/include/transformer_engine/gemm.h
@@ -72,7 +72,7 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
 void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
                              const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                              bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
-                             bool use_split_accumulator, int math_sm_count, cudaStream_t stream);
+                             bool use_split_accumulator, int math_sm_count, cudaStream_t stream,
+                             bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0,
+                             int compute_stream_offset = 0);

 /*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
  *
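Because the three new parameters are defaulted in the header, existing CUDA-side call sites of nvte_cublas_gemm_scaled compile unchanged; only DCU-aware code needs to pass them. A minimal sketch of that source-compatibility pattern (the free function gemm_scaled here is hypothetical, not the TE API):

#include <cstdio>

// Trailing parameters gain defaults, so pre-existing call sites keep
// compiling unchanged while new callers can opt into the extra arguments.
void gemm_scaled(float alpha, float beta,
                 bool use_hipblaslt = false, bool use_rocblas = false,
                 int compute_stream_offset = 0) {
  std::printf("alpha=%g beta=%g hipblaslt=%d rocblas=%d offset=%d\n",
              alpha, beta, use_hipblaslt, use_rocblas, compute_stream_offset);
}

int main() {
  gemm_scaled(1.0f, 0.0f);                  // old-style call: defaults apply
  gemm_scaled(1.0f, 1.0f, true, false, 2);  // new DCU-aware call
}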
transformer_engine/common/multi_tensor/compute_scale.cu
@@ -58,8 +58,7 @@ struct ComputeScaleAndScaleInvFunctor {
 void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_flag,
                                                    std::vector<std::vector<Tensor *>> tensor_lists,
                                                    float max_fp8, bool force_pow_2_scales,
-                                                   float epsilon, const int device_id,
-                                                   cudaStream_t stream) {
+                                                   float epsilon, cudaStream_t stream) {
   multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
                                     ComputeScaleAndScaleInvFunctor(), stream, max_fp8,
                                     force_pow_2_scales, epsilon);
transformer_engine/common/util/cast_gated_kernels.cuh
@@ -264,6 +264,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
   }
 }
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 }
+#endif

 namespace mxfp8_kernel {
transformer_engine/common/util/cast_kernels.cuh
@@ -44,7 +44,7 @@ constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1;  // 128
 // Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
 constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X;  // 4 = 128 / 32

+#ifndef __HIP_PLATFORM_AMD__
 template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
           float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING,
           bool COLWISE_SCALING, size_t CHUNK_DIM_Y, size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK>
@@ -205,11 +205,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
       // Wait for the data to have arrived
       ptx::mbarrier_wait_parity(&mbar[stage], parity);

-      // Trigger the next kernel, so its TMA load can be overlapped with the current kernel
-      if (stage == STAGES - 1) {
-        cudaTriggerProgrammaticLaunchCompletion();
-      }
-
       float thread_amax = 0.0f;
       if constexpr (COLWISE_SCALING) {
         const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise;
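cudaTriggerProgrammaticLaunchCompletion() belongs to CUDA's Programmatic Dependent Launch feature (sm_90+), which has no HIP equivalent; hence this removal and the #ifndef __HIP_PLATFORM_AMD__ fence around the kernel. A minimal CUDA sketch of the PDL pattern being excluded (assumes CUDA 12 on a Hopper-class GPU; not code from this file):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void producer(float *buf, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] = 2.0f * i;
  // Signal early completion so the dependent grid may begin its prologue
  // (e.g. TMA loads) overlapped with this kernel's remaining work.
  cudaTriggerProgrammaticLaunchCompletion();
}

__global__ void consumer(const float *buf, float *out, int n) {
  // Block until the producer has actually triggered completion.
  cudaGridDependencySynchronize();
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = buf[i] + 1.0f;
}

int main() {
  const int n = 1024;
  float *buf, *out;
  cudaMalloc(&buf, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));

  producer<<<4, 256>>>(buf, n);

  // The consumer opts into programmatic stream serialization so it can be
  // launched before the producer fully retires.
  cudaLaunchAttribute attr{};
  attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attr.val.programmaticStreamSerializationAllowed = 1;
  cudaLaunchConfig_t cfg{};
  cfg.gridDim = 4;
  cfg.blockDim = 256;
  cfg.attrs = &attr;
  cfg.numAttrs = 1;
  cudaLaunchKernelEx(&cfg, consumer, (const float *)buf, out, n);

  cudaDeviceSynchronize();
  cudaFree(buf);
  cudaFree(out);
  std::printf("done\n");
}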