OpenDAS / TransformerEngine

Commit ab122dac, authored Mar 27, 2025 by yuguo

    [DCU] compile pass

parent 4c6a5a27
Showing 20 changed files with 103 additions and 33 deletions (+103, -33).
build_tools/pytorch.py (+1, -0)
hipify_custom_map.json (+5, -0)
setup.py (+5, -1)
transformer_engine/common/CMakeLists.txt (+3, -4)
transformer_engine/common/amd_detail/hip_f8_impl.h (+4, -1)
transformer_engine/common/amd_detail/hip_float8.h (+3, -3)
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp (+5, -0)
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp (+7, -1)
transformer_engine/common/common.cu (+4, -0)
transformer_engine/common/common.h (+6, -2)
transformer_engine/common/gemm/cublaslt_gemm.cu (+6, -3)
transformer_engine/common/normalization/common.cpp (+17, -16)
transformer_engine/common/normalization/common.h (+4, -2)
transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu (+5, -0)
transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu (+5, -0)
transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu (+5, -0)
transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu (+5, -0)
transformer_engine/common/permutation/permutation.cu (+5, -0)
transformer_engine/common/recipe/current_scaling.cu (+4, -0)
transformer_engine/common/recipe/delayed_scaling.cu (+4, -0)
build_tools/pytorch.py

```diff
@@ -58,6 +58,7 @@ def setup_pytorch_extension(
             "-U__HIP_NO_BFLOAT16_CONVERSIONS__",
             "-U__HIP_NO_BFLOAT162_OPERATORS__",
             "-U__HIP_NO_BFLOAT162_CONVERSIONS__",
+            "-w",
         ]
     else:
         nvcc_flags = [
```
hipify_custom_map.json

```diff
 {
     "custom_map": {
+        "common/util/vectorized_pointwise.h": "common/util/vectorized_pointwise_hip.h",
+        "common/common.h": "common/common_hip.h",
+        "/userbuffers.h": "/userbuffers_hip.h",
+        "/logging.h": "/logging_hip.h",
+        "/system.h": "/system_hip.h",
         "<cuda_bf16.h>": "<hip/hip_bf16.h>",
         "<cuda_fp8.h>": "\"amd_detail/hip_float8.h\"",
         "CUfunc_cache": "hipFuncCache_t",
```
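(Not part of the commit: each custom-map entry is a literal find-and-replace pair applied while the build generates HIP sources from the CUDA ones. A minimal sketch of that substitution step, assuming plain substring replacement; `apply_entry` is an illustrative helper, not repo code.)

```cpp
// Hypothetical sketch of what a custom-map entry does during hipification:
// a literal substring swap applied to each source line.
#include <iostream>
#include <string>

std::string apply_entry(std::string line, const std::string &from, const std::string &to) {
    if (auto pos = line.find(from); pos != std::string::npos)
        line.replace(pos, from.size(), to);  // e.g. <cuda_bf16.h> -> <hip/hip_bf16.h>
    return line;
}

int main() {
    std::cout << apply_entry("#include <cuda_bf16.h>", "<cuda_bf16.h>", "<hip/hip_bf16.h>")
              << '\n';  // prints: #include <hip/hip_bf16.h>
}
```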
setup.py

```diff
@@ -3,6 +3,7 @@
 # See LICENSE for license information.

 """Installation script."""
+# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=0 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v

 import os
 import sys
@@ -43,7 +44,10 @@ elif "jax" in frameworks:
 CMakeBuildExtension = get_build_ext(BuildExtension)
-archs = cuda_archs()
+if rocm_build():
+    archs = None
+else:
+    archs = cuda_archs()

 class TimedBdist(bdist_wheel):
```
transformer_engine/common/CMakeLists.txt

```diff
@@ -226,11 +226,9 @@ else()
     add_library(transformer_engine SHARED ${te_hip_sources})
 endif()
+target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include")

 # Configure dependencies
 if (USE_CUDA)
-  target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include")
-
-  # Configure dependencies
   target_link_libraries(transformer_engine PUBLIC
                         CUDA::cublas
@@ -239,6 +237,7 @@ if (USE_CUDA)
                          ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
   target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
 else()
+  target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")
   # Aotriton is currently unsupported
   set(AotritonAndCk_fused_attn "unsupported")
@@ -343,7 +342,7 @@ else()
   set(HIP_HCC_FLAGS "${CMAKE_HIP_FLAGS} -mavx2 -mf16c -mfma -std=c++17")
   # Ask hcc to generate device code during compilation so we can use
   # host linker to link.
-  set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted")
+  set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -fno-gpu-rdc -w")
   foreach(rocm_arch ${CMAKE_HIP_ARCHITECTURES})
     # if CMAKE_CXX_FLAGS has --offload-arch set already, better to rm first
     set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} --offload-arch=${rocm_arch}")
```
transformer_engine/common/amd_detail/hip_f8_impl.h

```diff
@@ -4,6 +4,9 @@
  * License for AMD contributions = MIT. See LICENSE for more information
  ************************************************************************/
 #include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#include <hip/hip_bf16.h>
+
 namespace hip_f8_impl {

 HIP_HOST_DEVICE inline int clz(uint32_t x) {
@@ -190,7 +193,7 @@ HIP_HOST_DEVICE
 T cast_from_f8(uint8_t x) {
   constexpr bool is_half = std::is_same<T, __half>::value;
   constexpr bool is_float = std::is_same<T, float>::value;
-  constexpr bool is_bf16 = std::is_same<T, hip_bfloat16>::value;
+  constexpr bool is_bf16 = std::is_same<T, __hip_bfloat16>::value;
   static_assert(is_half || is_float, "only half and float are supported");

   constexpr int weo = is_half ? 5 : 8;
```
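(Reader's note, not part of the commit: the two hunks work together. The added <hip/hip_bf16.h> include supplies the `__hip_bfloat16` type that the `is_bf16` trait now names, replacing the older `hip_bfloat16`. A minimal sketch of the trait under that assumption:)

```cpp
// Sketch, assuming the DTK/ROCm toolchain defines __hip_bfloat16 in
// <hip/hip_bf16.h> (the include added above). Mirrors the is_bf16 check.
#include <hip/hip_bf16.h>
#include <type_traits>

template <typename T>
constexpr bool is_bf16 = std::is_same<T, __hip_bfloat16>::value;

static_assert(is_bf16<__hip_bfloat16>, "matches the bf16 type the new header provides");
static_assert(!is_bf16<float>, "and nothing else");
```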
transformer_engine/common/amd_detail/hip_float8.h

```diff
@@ -326,7 +326,7 @@ struct hip_f8 {
 #endif // #ifdef __gfx942__

   // constructor from hip_bfloat16
-  explicit HIP_HOST_DEVICE hip_f8(hip_bfloat16 v, hip_f8_rounding_mode r = hip_f8_rounding_mode::standard, uint32_t rng = 0);
+  explicit HIP_HOST_DEVICE hip_f8(__hip_bfloat16 v, hip_f8_rounding_mode r = hip_f8_rounding_mode::standard, uint32_t rng = 0);

   // convert to float
 #ifdef __gfx942__
@@ -430,7 +430,7 @@ struct hip_f8 {
 #endif // #ifdef __gfx942__

   // convert to hip_bfloat16
-  explicit inline HIP_HOST_DEVICE operator hip_bfloat16() const;
+  explicit inline HIP_HOST_DEVICE operator __hip_bfloat16() const;

   // check for zero
   inline HIP_HOST_DEVICE bool is_zero() const {
@@ -504,7 +504,7 @@ struct hip_f8x4 {
   HIP_HOST_DEVICE hip_f8x4(halfx4 v, hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard, uint32_t rng = 0);

   // constructor from hip_bfloat16
-  HIP_HOST_DEVICE hip_f8x4(hip_bfloat16 v0, hip_bfloat16 v1 = hip_bfloat16(0.0f), hip_bfloat16 v2 = hip_bfloat16(0.0f), hip_bfloat16 v3 = hip_bfloat16(0.0f), hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard, uint32_t rng = 0);
+  HIP_HOST_DEVICE hip_f8x4(__hip_bfloat16 v0, __hip_bfloat16 v1 = __hip_bfloat16(0.0f), __hip_bfloat16 v2 = __hip_bfloat16(0.0f), __hip_bfloat16 v3 = __hip_bfloat16(0.0f), hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard, uint32_t rng = 0);
   HIP_HOST_DEVICE hip_f8x4(hip_bfloat16x2 v, hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard, uint32_t rng = 0);
   HIP_HOST_DEVICE hip_f8x4(hip_bfloat16x4 v, hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard, uint32_t rng = 0);
```
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

```diff
@@ -12,8 +12,13 @@
 #include <numeric>

 #include "common/common.h"
+#ifdef USE_ROCM
+#include "common/util/hip_driver.h"
+#include "common/util/hip_runtime.h"
+#else
 #include "common/util/cuda_driver.h"
 #include "common/util/cuda_runtime.h"
+#endif
 #include "common/util/logging.h"
 #include "common/util/system.h"
 #include "userbuffers/userbuffers.h"
```
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp

```diff
@@ -19,9 +19,15 @@
 #include <map>
 #include <utility>

+#ifdef USE_ROCM
+#include "common/util/hip_driver.h"
+#include "common/util/hip_nvml.h"
+#include "common/util/hip_runtime.h"
+#else
 #include "common/util/cuda_driver.h"
 #include "common/util/cuda_nvml.h"
 #include "common/util/cuda_runtime.h"
+#endif
 #include "common/util/logging.h"
 #include "common/util/system.h"
 #include "ipcsocket.h"
@@ -362,7 +368,7 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
   NVTE_CHECK_CUDA(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
   (*comm)->flags =
 #ifdef USE_ROCM
-      reinterpret_cast<int *>((reinterpret_cast<uintptr_t>((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK
+      reinterpret_cast<int *>((reinterpret_cast<uintptr_t>((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
 #else
       reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
 #endif
```
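(Reader's note, not part of the commit: the second hunk restores the closing `);` that the ROCm branch was missing. For context, the expression it closes rounds the `flags` pointer up to the next GPU page boundary; a standalone sketch with an illustrative page size, assuming GPU_PAGE_MASK is the usual complement of GPU_PAGE_SIZE - 1:)

```cpp
// Sketch: round a pointer up to the next GPU page boundary.
// GPU_PAGE_SIZE here is an illustrative value, not the repo's definition.
#include <cstdint>

constexpr uintptr_t GPU_PAGE_SIZE = 1u << 16;              // power of two
constexpr uintptr_t GPU_PAGE_MASK = ~(GPU_PAGE_SIZE - 1);  // clears the low bits

int *align_up_to_page(int *p) {
    return reinterpret_cast<int *>(
        (reinterpret_cast<uintptr_t>(p) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
}
```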
transformer_engine/common/common.cu

```diff
@@ -10,7 +10,11 @@
 #include "./common.h"
 #include "./utils.cuh"
+#ifdef __HIP_PLATFORM_AMD__
+#include "common/util/hip_runtime.h"
+#else
 #include "common/util/cuda_runtime.h"
+#endif
 #include "common/util/logging.h"

 namespace transformer_engine {
```
transformer_engine/common/common.h

```diff
@@ -25,7 +25,11 @@
 #include <vector>

 #include "./nvtx.h"
+#ifdef __HIP_PLATFORM_AMD__
+#include "./util/hip_driver.h"
+#else
 #include "./util/cuda_driver.h"
+#endif
 #include "./util/logging.h"

 namespace transformer_engine {
@@ -223,7 +227,7 @@ using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
 #else
-using bf16 = hip_bfloat16;
+using bf16 = __hip_bfloat16;
 using fp8e4m3 = te_hip_fp8_e4m3;
 using fp8e5m2 = te_hip_fp8_e5m2;
 #endif
@@ -247,7 +251,7 @@ TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
 TRANSFORMER_ENGINE_TYPE_NAME(float)
 TRANSFORMER_ENGINE_TYPE_NAME(half)
 #ifdef __HIP_PLATFORM_AMD__
-TRANSFORMER_ENGINE_TYPE_NAME(hip_bfloat16)
+TRANSFORMER_ENGINE_TYPE_NAME(__hip_bfloat16)
 TRANSFORMER_ENGINE_TYPE_NAME(te_hip_fp8_e4m3)
 TRANSFORMER_ENGINE_TYPE_NAME(te_hip_fp8_e5m2)
 #else
```
transformer_engine/common/gemm/cublaslt_gemm.cu

```diff
@@ -22,7 +22,11 @@
 #include "../common.h"
 #include "../util/handle_manager.h"
 #include "../util/logging.h"
+#ifdef __HIP_PLATFORM_AMD__
+#include "common/util/hip_runtime.h"
+#else
 #include "common/util/cuda_runtime.h"
+#endif

 #ifndef __HIP_PLATFORM_AMD__
 namespace {
@@ -738,7 +742,7 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
   if (NVTE_BLAS_MULSTREAM == nullptr){
     NVTE_FORCE_BLASLT_MULSTREAM = true;
-  } elif ((NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1') && (NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1')){
+  } else if ((NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1') && (NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1')){
     NVTE_ERROR("NVTE_FORCE_BLAS_MULSTREAM and NVTE_FORCE_BLASLT can't be set at the same time.");
   } else {
     NVTE_FORCE_BLASLT_MULSTREAM = false;
@@ -776,8 +780,7 @@ void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B
                                         cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm);
   using namespace transformer_engine;
-  static_assert(num_gemms % num_batchgemm_streams == 0,
-                "Need num_gemms mod num_batchgemm_streams == 0.");
+  assert(num_gemms % num_batchgemm_streams == 0);
   static int batch_count = num_gemms / num_batchgemm_streams;

   // Inits streams and events (once, globally)
   std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm);
```
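(Reader's note, not part of the commit: both hunks here are straight compile fixes. `elif` is a Python keyword, not valid C++, and `static_assert` requires a constant expression, so it cannot test a runtime argument like `num_gemms`. A minimal sketch of the second point; `check` is an illustrative name:)

```cpp
// Sketch: static_assert on a runtime function argument is ill-formed;
// a runtime assert() is the compiling equivalent.
#include <cassert>

void check(int num_gemms, int num_batchgemm_streams) {
    // static_assert(num_gemms % num_batchgemm_streams == 0, "...");  // error: not constexpr
    assert(num_gemms % num_batchgemm_streams == 0);  // evaluated at run time
}
```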
transformer_engine/common/normalization/common.cpp

```diff
@@ -192,15 +192,15 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
                                                const size_t sm_count,
                                                const bool zero_centered_gamma,
                                                const NVTEScalingMode mode, bool training)
+#ifdef USE_ROCM
+{
+  assert(false);
+#else
     : _fp8_out(is_fp8_dtype(otype)),
       _zero_centered(zero_centered_gamma),
       _training(training),
       _norm_stage(NormStage),
       _norm_type(NormType) {
-#ifdef USE_ROCM
-  static_assert(false, "Cudnn backend is not surpported in rocm for normalization yet.");
-#else
   static_assert(CUDNN_FRONTEND_VERSION >= 10601,
                 "CUDNN_FRONTEND_VERSION should be at least 1.6.1!");
@@ -389,8 +389,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
 void CudnnNormalizationPlan::_build() {
 #ifdef USE_ROCM
-  static_assert(false,
-                "Cudnn backend is not surpported in rocm for normalization yet.");
+  assert(false);
 #else
   NVTE_CHECK(_graph.validate().is_good());
   NVTE_CHECK(_graph.build_operation_graph(_handle).is_good());
@@ -406,8 +405,8 @@ void CudnnNormalizationPlan::_build() {
 std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const {
 #ifdef USE_ROCM
-  static_assert(false,
-                "Cudnn backend is not surpported in rocm for normalization yet.");
+  assert(false);
+  return {0};
 #else
   return {static_cast<size_t>(_graph.get_workspace_size())};
 #endif
@@ -417,8 +416,7 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr,
                                      void* mean_dptr, void* eps_dptr, void* rsigma_dptr,
                                      void* workspace_dptr, cudaStream_t stream) {
 #ifdef USE_ROCM
-  static_assert(false,
-                "Cudnn backend is not surpported in rocm for normalization yet.");
+  assert(false);
 #else
   // Binding data pointers to graph tensors
   _variant_pack = {{_x, x_dptr}, {_eps, eps_dptr}};
@@ -462,8 +460,7 @@ void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_
                                      void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
                                      cudaStream_t stream) {
 #ifdef USE_ROCM
-  static_assert(false,
-                "Cudnn backend is not surpported in rocm for normalization yet.");
+  assert(false);
 #else
   // Binding data pointers to graph tensors
   _variant_pack = {
@@ -519,7 +516,8 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
 bool& _cudnn_norm_fwd_flag() {
 #ifdef USE_ROCM
-  return false;
+  static bool flag = false;
+  return flag;
 #else
   static bool flag = transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN");
   return flag;
@@ -528,7 +526,8 @@ bool& _cudnn_norm_fwd_flag() {
 bool& _cudnn_norm_bwd_flag() {
 #ifdef USE_ROCM
-  return false;
+  static bool flag = false;
+  return flag;
 #else
   static bool flag = transformer_engine::getenv<bool>("NVTE_NORM_BWD_USE_CUDNN");
   return flag;
@@ -544,7 +543,8 @@ bool use_cudnn_norm_bwd() { return _cudnn_norm_bwd_flag(); }
 void nvte_enable_cudnn_norm_fwd(bool enable) {
   NVTE_API_CALL(nvte_enable_cudnn_norm_fwd);
 #ifdef USE_ROCM
-  transformer_engine::normalization::_cudnn_norm_bwd_flag() = false;
+  bool flag = false;
+  transformer_engine::normalization::_cudnn_norm_bwd_flag() = flag;
 #else
   transformer_engine::normalization::_cudnn_norm_fwd_flag() = enable;
 #endif
@@ -553,7 +553,8 @@ void nvte_enable_cudnn_norm_fwd(bool enable) {
 void nvte_enable_cudnn_norm_bwd(bool enable) {
   NVTE_API_CALL(nvte_enable_cudnn_norm_bwd);
 #ifdef USE_ROCM
-  transformer_engine::normalization::_cudnn_norm_bwd_flag() = false;
+  bool flag = false;
+  transformer_engine::normalization::_cudnn_norm_bwd_flag() = flag;
 #else
   transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable;
 #endif
```
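(Reader's note, not part of the commit: these hunks fix two recurring compile errors in the ROCm stubs. `static_assert(false, ...)` fires unconditionally even in branches never taken on ROCm, so it becomes a runtime `assert(false)`; and a function returning `bool&` cannot return the literal `false`. A minimal sketch of the reference fix; `rocm_flag` is an illustrative name:)

```cpp
// Sketch: binding a bool& to the rvalue `false` is ill-formed; returning a
// function-local static gives the caller a real, mutable object instead.
bool &rocm_flag() {
    // return false;           // error: cannot bind a non-const reference to an rvalue
    static bool flag = false;  // has storage and persists across calls
    return flag;
}
```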
transformer_engine/common/normalization/common.h

```diff
@@ -30,7 +30,9 @@ namespace transformer_engine {
 namespace normalization {
+#ifndef __HIP_PLATFORM_AMD__
 namespace fe = cudnn_frontend;
+#endif

 template <typename KernelParamsType>
 struct LaunchParams {
@@ -277,14 +279,14 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
  private:
  void _build() override;

+#ifndef __HIP_PLATFORM_AMD__
  const bool _zero_centered, _fp8_out;
  int _ndim_scale_block;
  const NVTE_Norm_Stage _norm_stage;
  const NVTE_Norm_Type _norm_type;
  std::unique_ptr<char[]> _scalar_dptr;
  std::unique_ptr<float> _one_dptr = std::make_unique<float>(1.0f);
-#ifndef __HIP_PLATFORM_AMD__
  // FWD
  std::shared_ptr<fe::graph::Tensor_attributes> _x, _gamma_zero, _scalar_offset, _gamma, _beta,
      _eps, _mean, _rsigma, _z, _z_scale, _one_for_div, _z_scale_inv, _amax, _z_fp8;
```
transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu

```diff
@@ -43,6 +43,11 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
     NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                          Kernel_traits::SMEM_BYTES));
   }
+#else
+  if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
+    NVTE_CHECK_CUDA(cudaFuncSetAttribute((const void *)kernel,
+        cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
+  }
 #endif

   auto stream = launch_params.stream;
```
transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu

```diff
@@ -39,6 +39,11 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
     NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                          Kernel_traits::SMEM_BYTES_FWD));
   }
+#else
+  if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) {
+    NVTE_CHECK_CUDA(cudaFuncSetAttribute((const void *)kernel,
+        cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
+  }
 #endif

   auto stream = launch_params.stream;
```
transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu

```diff
@@ -42,6 +42,11 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
     NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                          Kernel_traits::SMEM_BYTES));
   }
+#else
+  if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
+    NVTE_CHECK_CUDA(cudaFuncSetAttribute((const void *)kernel,
+        cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
+  }
 #endif

   auto stream = launch_params.stream;
```
transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu

```diff
@@ -40,6 +40,11 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
     NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                          Kernel_traits::SMEM_BYTES_FWD));
   }
+#else
+  if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) {
+    NVTE_CHECK_CUDA(cudaFuncSetAttribute((const void *)kernel,
+        cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
+  }
 #endif

   auto stream = launch_params.stream;
```
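(Reader's note, not part of the commit: the same `#else` branch is added to all four normalization kernels. Kernels that request 48 KB or more of dynamic shared memory must opt in per kernel before launch, and on this toolchain the attribute call takes a `const void *` rather than a typed kernel pointer. A standalone sketch of the pattern; `kernel` and `configure` are illustrative names:)

```cpp
// Sketch: opt a kernel in to dynamic shared memory above the default 48 KB.
#include <cuda_runtime.h>

__global__ void kernel(float *out) {
    extern __shared__ float smem[];
    out[threadIdx.x] = smem[threadIdx.x];
}

void configure(int smem_bytes) {
    if (smem_bytes >= 48 * 1024) {
        // Return value ignored in this sketch; the real code wraps it in NVTE_CHECK_CUDA.
        cudaFuncSetAttribute((const void *)kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    }
}
```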
transformer_engine/common/permutation/permutation.cu

```diff
@@ -11,6 +11,7 @@
 #ifdef __HIP_PLATFORM_AMD__
 using __nv_fp8_e4m3 = hip_f8<hip_f8_type::fp8>;
 using __nv_fp8_e5m2 = hip_f8<hip_f8_type::bf8>;
+#define __ldlu(x) __ldg(x)
 #endif

 static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map,
@@ -214,7 +215,11 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac
       if (k == topK) break;

       // Warp-level reduction
       for (int mask = 16; mask > 0; mask /= 2) {
+#ifdef __HIP_PLATFORM_AMD__
+        accum[k] = accum[k] + __shfl_xor(accum[k], mask, 32);
+#else
         accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32);
+#endif
       }
     }
```
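(Reader's note, not part of the commit: HIP exposes the legacy `__shfl_xor(var, lane_mask, width)` without a member mask, while CUDA 9+ requires the `_sync` form, hence the guard; the `__ldlu` define similarly maps a CUDA-only load intrinsic onto `__ldg`. A self-contained sketch of the butterfly reduction the loop performs; `butterfly_sum` is an illustrative name:)

```cpp
// Sketch: XOR-shuffle butterfly reduction across a 32-lane warp; after the
// loop, every lane holds the sum of all 32 lanes' values.
__device__ float butterfly_sum(float v) {
    for (int mask = 16; mask > 0; mask /= 2) {
#ifdef __HIP_PLATFORM_AMD__
        v += __shfl_xor(v, mask, 32);                   // legacy HIP intrinsic
#else
        v += __shfl_xor_sync(0xffffffff, v, mask, 32);  // CUDA 9+ masked form
#endif
    }
    return v;
}
```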
transformer_engine/common/recipe/current_scaling.cu

```diff
@@ -15,6 +15,10 @@
 #include "../util/vectorized_pointwise.h"
 #include "recipe_common.cuh"

+#ifdef __HIP_PLATFORM_AMD__
+using __nv_bfloat16 = __hip_bfloat16;
+#endif
+
 namespace transformer_engine {
 namespace {
```
transformer_engine/common/recipe/delayed_scaling.cu

```diff
@@ -11,7 +11,11 @@
 #include <string>

 #include "../common.h"
+#ifdef __HIP_PLATFORM_AMD__
+#include "../util/hip_runtime.h"
+#else
 #include "../util/cuda_runtime.h"
+#endif
 #include "../util/logging.h"

 namespace transformer_engine {
```