OpenDAS / TransformerEngine · Commits

Commit c520cba3, authored Mar 20, 2025 by yuguo

    [DCU] Preliminary adaptation

Parent: 5b6ef054
Changes: 79
Showing 20 changed files with 3808 additions and 109 deletions (+3808 / -109)
tests/pytorch/test_recipe.py                                                   +2    -1
tests/pytorch/test_sanity.py                                                   +11   -0
transformer_engine/common/CMakeLists.txt                                       +285  -97
transformer_engine/common/__init__.py                                          +5    -2
transformer_engine/common/amd_detail/hip_f8_impl.h                             +276  -0
transformer_engine/common/amd_detail/hip_float8.h                              +458  -0
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp              +8    -0
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp   +9    -0
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu         +142  -0
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h          +8    -0
transformer_engine/common/common.cu                                            +10   -0
transformer_engine/common/common.h                                             +16   -1
transformer_engine/common/cudnn_utils.cpp                                      +2    -0
transformer_engine/common/cudnn_utils.h                                        +5    -1
transformer_engine/common/gemm/cublaslt_gemm.cu                                +301  -6
transformer_engine/common/gemm/hipblas_gemm.cu                                 +193  -0
transformer_engine/common/gemm/hipblas_gemm.h                                  +149  -0
transformer_engine/common/gemm/rocm_gemm.cu                                    +1863 -0
transformer_engine/common/include/transformer_engine/gemm.h                    +22   -1
transformer_engine/common/normalization/common.cpp                             +43   -0
tests/pytorch/test_recipe.py

@@ -6,6 +6,7 @@ from typing import Iterable, Optional

 import pytest
 import torch
 from torch.utils.cpp_extension import IS_HIP_EXTENSION

 import transformer_engine.common.recipe
 import transformer_engine.pytorch as te

@@ -258,7 +259,7 @@ class TestFP8Recipe:

             # Compute scale
             max_val = {
                 "forward": 448.0 if not IS_HIP_EXTENSION else 240.0,
                 "backward": 57344.0,
             }[stage]
             ref_scale = (max_val / ref_amax) / (2**margin)
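The HIP branch of the recipe test expects an E4M3 amax of 240.0 where the CUDA build uses 448.0, because the AMD path quantizes to the FNUZ E4M3 encoding rather than the OCP E4M3 encoding. A minimal standalone C++ sketch (illustrative only, not part of the commit) that derives both maxima from the two formats' parameters:

// Illustrative only: largest finite E4M3 value for the OCP and FNUZ encodings,
// matching the 448.0 / 240.0 expectations in the test above.
#include <cmath>
#include <cstdio>

int main() {
  // OCP E4M3 (CUDA __nv_fp8_e4m3): bias 7; exponent 1111 with mantissa 111 is NaN,
  // so the largest finite value is 1.75 * 2^8.
  double ocp_max = (1.0 + 6.0 / 8.0) * std::pow(2.0, 15 - 7);
  // FNUZ E4M3 (the hip_f8 fp8 type in "optimal" bias mode): bias 8; only 0x80 is NaN,
  // so the full exponent/mantissa range is usable: 1.875 * 2^7.
  double fnuz_max = (1.0 + 7.0 / 8.0) * std::pow(2.0, 15 - 8);
  std::printf("OCP E4M3 max = %.1f, FNUZ E4M3 max = %.1f\n", ocp_max, fnuz_max);  // 448.0, 240.0
  return 0;
}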
tests/pytorch/test_sanity.py

@@ -9,6 +9,7 @@ from contextlib import nullcontext

 import torch
 import pytest
 import os
 from torch.utils.cpp_extension import IS_HIP_EXTENSION

 from transformer_engine.pytorch.fp8 import (
     fp8_autocast,

@@ -51,6 +52,12 @@ def create_meta(scale_factor: float, size: int = 1):

     meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
     return meta


 if IS_HIP_EXTENSION:
     from functools import cache

     @cache
     def use_hipblaslt() -> bool:
         return (
             os.getenv("NVTE_USE_HIPBLASLT") is not None
             or os.getenv("NVTE_USE_ROCBLAS") is None
         )


 def custom_amax_to_scale(
     amax: torch.Tensor,

@@ -982,6 +989,10 @@ def test_sanity_gradient_accumulation_fusion(

 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 @pytest.mark.parametrize("normalization", all_normalizations)
 def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
     if IS_HIP_EXTENSION:
         if not use_hipblaslt():
             pytest.skip("CUDA graph capture not supported with rocBLAS path")
     config = model_configs[model]

     if fp8_recipe is not None:
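The use_hipblaslt() helper gates tests on the ROCm GEMM backend the library was configured for: hipBLASLt is assumed unless NVTE_USE_ROCBLAS is set and NVTE_USE_HIPBLASLT is not. A rough C++ rendering of the same precedence rule, shown only to spell it out (illustrative sketch, not part of the commit; the environment variable names come from the test above):

// Illustrative only: hipBLASLt is selected unless NVTE_USE_ROCBLAS is set
// and NVTE_USE_HIPBLASLT is not.
#include <cstdlib>

bool use_hipblaslt() {
  return std::getenv("NVTE_USE_HIPBLASLT") != nullptr ||
         std::getenv("NVTE_USE_ROCBLAS") == nullptr;
}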
transformer_engine/common/CMakeLists.txt

@@ -4,44 +4,100 @@

cmake_minimum_required(VERSION 3.21)

option(USE_ROCM "Use ROCm" OFF)
option(USE_HIPBLASLT "Use HIPBLASLT" ON)
# Temp unsupport aottriton\ck backend and Use ROCBLAS
option(USE_ROCBLAS "Use ROCBLAS" OFF)

if(NOT USE_ROCM)
  if(((EXISTS "/opt/dtk/") OR (EXISTS $ENV{ROCM_PATH})) AND NOT (EXISTS "/bin/nvcc"))
    message("hcu detected.")
    set(USE_ROCM ON)
  endif()
endif()

if(USE_ROCM)
  add_compile_definitions(__HIP_CLANG_ONLY__=1)
  if(NOT USE_HIPBLASLT AND NOT USE_ROCBLAS)
    message(FATAL_ERROR "Need specify at least one GEMM library to use: HIPBLASLT or ROCBLAS")
  endif()
  unset(USE_CUDA)
else()
  set(USE_CUDA TRUE)
endif()

if(USE_CUDA)
  # Language options
  if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
    if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
      set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
    else()
      set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
    endif()
  endif()
  set(CMAKE_CXX_STANDARD 17)
  set(CMAKE_CUDA_STANDARD 17)
  set(CMAKE_CUDA_STANDARD_REQUIRED ON)
  if(CMAKE_BUILD_TYPE STREQUAL "Debug")
    set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G")
  endif()

  # Hide non-necessary symbols in shared object.
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version")
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version")

  # Transformer Engine library
  project(transformer_engine LANGUAGES CUDA CXX)

  # CUDA Toolkit
  find_package(CUDAToolkit REQUIRED)
  if(CUDAToolkit_VERSION VERSION_LESS 12.0)
    message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}")
  endif()

  # cuDNN frontend API
  set(CUDNN_FRONTEND_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include")
  if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}")
    message(FATAL_ERROR
            "Could not find cuDNN frontend API at ${CUDNN_FRONTEND_INCLUDE_DIR}. "
            "Try running 'git submodule update --init --recursive' "
            "within the Transformer Engine source.")
  endif()
  include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
else()
  set(CMAKE_CXX_STANDARD 17)
  project(transformer_engine LANGUAGES HIP CXX)
  # Disable Asserts In Code (Can't use asserts on HIP stack.)
  add_definitions(-DNDEBUG)
  add_definitions(-DUSE_ROCM)
  # Change clang++ to hipcc
  SET(CMAKE_CXX_COMPILER "${ROCM_PATH}/bin/hipcc")
  if(NOT DEFINED ENV{NVTE_ROCM_ARCH})
    SET(CMAKE_HIP_ARCHITECTURES gfx906;gfx926;gfx928;gfx936)
  else()
    SET(CMAKE_HIP_ARCHITECTURES $ENV{NVTE_ROCM_ARCH})
  endif()
  # build error will be dup-ed parallel-jobs times
  # set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -parallel-jobs=4")
  if(CMAKE_BUILD_TYPE STREQUAL "Debug")
    set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -g")
  endif()
  list(APPEND CMAKE_MODULE_PATH "/opt/dtk")
endif()

set(message_line "-------------------------------------------------------------")
message("${message_line}")
message(STATUS "USE_ROCM ${USE_ROCM}")
if(USE_ROCM)
  message(STATUS "CMAKE_HIP_ARCHITECTURES: ${CMAKE_HIP_ARCHITECTURES}")
  message(STATUS "USE_HIPBLASLT ${USE_HIPBLASLT} USE_ROCBLAS ${USE_ROCBLAS}")
endif()

# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)

@@ -49,7 +105,9 @@ find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)

# Configure Transformer Engine library
include_directories(${PROJECT_SOURCE_DIR}/..)

set(transformer_engine_SOURCES)
if(USE_CUDA)
  list(APPEND transformer_engine_SOURCES
       cudnn_utils.cpp
       transformer_engine.cpp
       common.cu

@@ -92,17 +150,114 @@ list(APPEND transformer_engine_SOURCES

       comm_gemm_overlap/userbuffers/userbuffers-host.cpp
       comm_gemm_overlap/userbuffers/userbuffers.cu
       comm_gemm_overlap/comm_gemm_overlap.cpp)
  add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
else()
  list(APPEND transformer_engine_SOURCES
       cudnn_utils.cpp
       transformer_engine.cpp
       common.cu
       transpose/cast_transpose.cu
       transpose/transpose.cu
       transpose/cast_transpose_fusion.cu
       transpose/transpose_fusion.cu
       transpose/multi_cast_transpose.cu
       activation/gelu.cu
       activation/relu.cu
       activation/swiglu.cu
       gemm/cublaslt_gemm.cu
       normalization/common.cpp
       normalization/layernorm/ln_api.cpp
       normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
       normalization/layernorm/ln_fwd_cuda_kernel.cu
       normalization/rmsnorm/rmsnorm_api.cpp
       normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
       normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
       permutation/permutation.cu
       util/cast.cu
       util/padding.cu
       util/cuda_driver.cpp
       util/cuda_nvml.cpp
       util/cuda_runtime.cpp
       util/rtc.cpp
       swizzle/swizzle.cu
       fused_softmax/scaled_masked_softmax.cu
       fused_softmax/scaled_upper_triang_masked_softmax.cu
       fused_softmax/scaled_aligned_causal_masked_softmax.cu
       fused_rope/fused_rope.cu
       recipe/current_scaling.cu
       recipe/delayed_scaling.cu
       comm_gemm_overlap/userbuffers/ipcsocket.cc
       comm_gemm_overlap/userbuffers/userbuffers-host.cpp
       comm_gemm_overlap/userbuffers/userbuffers.cu
       comm_gemm_overlap/comm_gemm_overlap.cpp)

  # process source code files
  message("${message_line}")
  message(STATUS "CMAKE_CURRENT_SOURCE_DIR: ${CMAKE_CURRENT_SOURCE_DIR}")
  message(STATUS "PROJECT_SOURCE_DIR: ${PROJECT_SOURCE_DIR}")
  set(TE ${CMAKE_CURRENT_SOURCE_DIR}/../..)
  set(THIRDPARTY ${TE}/3rdparty)
  list(APPEND CMAKE_MODULE_PATH "${THIRDPARTY}/hipify_torch/cmake")
  include(Hipify)
  message(STATUS "CMAKE_MODULE_PATH: ${CMAKE_MODULE_PATH}")

  set(header_include_dir
      ${CMAKE_CURRENT_SOURCE_DIR}/comm_gemm_overlap/userbuffers
      ${CMAKE_CURRENT_SOURCE_DIR}/activation
      ${CMAKE_CURRENT_SOURCE_DIR}/include
      ${CMAKE_CURRENT_SOURCE_DIR}/transpose
      ${CMAKE_CURRENT_SOURCE_DIR}/util
      ${CMAKE_CURRENT_SOURCE_DIR}/normalization
      ${CMAKE_CURRENT_SOURCE_DIR}/normalization/rmsnorm
      ${CMAKE_CURRENT_SOURCE_DIR}/normalization/layernorm
      ${CMAKE_CURRENT_SOURCE_DIR})
  message(STATUS "HIPIFY CUDA_SOURCE_DIR: ${CMAKE_CURRENT_SOURCE_DIR}")
  message(STATUS "HIPIFY HEADER_INCLUDE_DIR: ${header_include_dir}")
  hipify(CUDA_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}
         HEADER_INCLUDE_DIR ${header_include_dir}
         IGNORES "*/amd_detail/*"
         IGNORES "*/fused_attn/*"
         CUSTOM_MAP_FILE "${TE}/hipify_custom_map.json")
  get_hipified_list("${transformer_engine_SOURCES}" te_hip_sources)
  message("${message_line}")
  message(STATUS "nvte hipified sources: ${te_hip_sources}")

  add_library(transformer_engine SHARED ${te_hip_sources})
endif()

if(USE_CUDA)
  target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include")

  # Configure dependencies
  target_link_libraries(transformer_engine PUBLIC
                        CUDA::cublas
                        CUDA::cudart)
  target_include_directories(transformer_engine PRIVATE
                             ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
  target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
else()
  # Aotriton is currently unsupported
  set(AotritonAndCk_fused_attn "unsupported")
  find_package(hip)
  list(APPEND transformer_engine_LINKER_LIBS hip::host hip::device roctx64)
  if(USE_HIPBLASLT)
    find_package(hipblaslt)
    find_package(hipblas REQUIRED PATHS ${ROCM_PATH})
    target_compile_definitions(transformer_engine PUBLIC USE_HIPBLASLT)
    list(APPEND transformer_engine_LINKER_LIBS roc::hipblaslt hipblas)
  endif()
  if(USE_ROCBLAS)
    find_package(rocblas)
    target_compile_definitions(transformer_engine PUBLIC USE_ROCBLAS)
    list(APPEND transformer_engine_LINKER_LIBS roc::rocblas)
  endif()
  target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
endif()

# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)

@@ -113,8 +268,10 @@ if (NVTE_UB_WITH_MPI)

  target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI)
endif()

if(USE_CUDA)
  # Hack to enable dynamic loading in cuDNN frontend
  target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
endif()

# Helper functions to make header files with C++ strings
function(make_string_header STRING STRING_NAME)

@@ -130,17 +287,34 @@ function(make_string_header_from_file file_ STRING_NAME)

endfunction()

# Header files with C++ strings
if(USE_CUDA)
  list(GET CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES 0 cuda_include_path)
  make_string_header("${cuda_include_path}"
                     string_path_cuda_include)
  make_string_header_from_file(transpose/rtc/cast_transpose_fusion.cu
                               string_code_transpose_rtc_cast_transpose_fusion_cu)
  make_string_header_from_file(transpose/rtc/cast_transpose.cu
                               string_code_transpose_rtc_cast_transpose_cu)
  make_string_header_from_file(transpose/rtc/transpose.cu
                               string_code_transpose_rtc_transpose_cu)
  make_string_header_from_file(utils.cuh
                               string_code_utils_cuh)
else()
  make_string_header_from_file(utils_hip.cuh
                               string_code_utils_cuh)
  make_string_header_from_file(transpose/rtc/cast_transpose_fusion.hip
                               string_code_transpose_rtc_cast_transpose_fusion_cu)
  make_string_header_from_file(transpose/rtc/cast_transpose.hip
                               string_code_transpose_rtc_cast_transpose_cu)
  make_string_header_from_file(transpose/rtc/transpose.hip
                               string_code_transpose_rtc_transpose_cu)
  make_string_header_from_file(amd_detail/hip_float8.h
                               string_code_amd_detail_hip_float8_h)
  make_string_header_from_file(amd_detail/hip_f8_impl.h
                               string_code_amd_detail_hip_f8_impl_h)
endif()
make_string_header_from_file(util/math.h
                             string_code_util_math_h)

target_include_directories(transformer_engine PRIVATE

@@ -160,8 +334,22 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)

                        PROPERTIES
                        COMPILE_OPTIONS "--use_fast_math")
endif()

if(USE_CUDA)
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
else()
  set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -O3")
  set(HIP_HCC_FLAGS "${CMAKE_HIP_FLAGS} -mavx2 -mf16c -mfma -std=c++17")
  # Ask hcc to generate device code during compilation so we can use
  # host linker to link.
  set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted")
  foreach(rocm_arch ${CMAKE_HIP_ARCHITECTURES})
    # if CMAKE_CXX_FLAGS has --offload-arch set already, better to rm first
    set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} --offload-arch=${rocm_arch}")
  endforeach()
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${HIP_HCC_FLAGS}")
endif()

# Number of parallel build jobs
if(ENV{MAX_JOBS})
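On the ROCm path the build adds USE_HIPBLASLT and/or USE_ROCBLAS as public compile definitions, so downstream translation units can branch at compile time. A hedged sketch of how such definitions are typically consumed; the function name here is a placeholder, not the commit's actual API, and the real dispatch lives in gemm/hipblas_gemm.cu and gemm/rocm_gemm.cu, which are not shown in this excerpt:

// Illustrative only: compile-time selection keyed on the definitions added by
// target_compile_definitions() above. gemm_dispatch_example is a hypothetical name.
void gemm_dispatch_example() {
#if defined(USE_HIPBLASLT)
  // hipBLASLt-backed GEMM path
#elif defined(USE_ROCBLAS)
  // rocBLAS-backed GEMM path
#else
#error "Need at least one GEMM library: HIPBLASLT or ROCBLAS"
#endif
}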
transformer_engine/common/__init__.py

@@ -124,6 +124,9 @@ def _load_nvrtc():

 if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
     try:
         _CUDNN_LIB_CTYPES = _load_cudnn()
         _NVRTC_LIB_CTYPES = _load_nvrtc()
     except OSError:
         pass
     _TE_LIB_CTYPES = _load_library()
transformer_engine/common/amd_detail/hip_f8_impl.h  (new file, 0 → 100644)

/*************************************************************************
 * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
 *
 * License for AMD contributions = MIT. See LICENSE for more information
 ************************************************************************/

#include <hip/hip_runtime.h>

namespace hip_f8_impl {

HIP_HOST_DEVICE inline int clz(uint32_t x) {
#ifdef __HIP_DEVICE_COMPILE__
  return __clz(x);
#else
  return __builtin_clz(x);
#endif
}

template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
HIP_HOST_DEVICE uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) {
  constexpr bool is_half = std::is_same<T, __half>::value;
  constexpr bool is_float = std::is_same<T, float>::value;
  static_assert(wm + we == 7, "wm+we==7");
  static_assert(is_half || is_float, "Only half and float can be cast to f8");

  const int mfmt = (sizeof(T) == 4) ? 23 : 10;
  uint32_t x;
  if (sizeof(T) == 4)
    x = reinterpret_cast<uint32_t &>(_x);
  else
    x = reinterpret_cast<uint16_t &>(_x);

  uint32_t y, head, mantissa;
  int exponent, bias;
  uint32_t sign;

  if (sizeof(T) == 4) {
    head = x & 0xFF800000;
    mantissa = x & 0x7FFFFF;
    exponent = (head >> 23) & 0xFF;
    sign = head >> 31;
    bias = 127;
  } else {
    head = x & 0xFC00;
    mantissa = x & 0x3FF;
    exponent = (head >> 10) & 0x1F;
    sign = head >> 15;
    bias = 15;
  }

  uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);

  // Deal with inf and NaNs
  if (negative_zero_nan) {
    if (sizeof(T) == 4) {
      if ((x & 0x7F800000) == 0x7F800000) return 0x80;
    } else {
      //if(__hisinf(x) || __hisnan(x))
      if ((x & 0x7C00) == 0x7C00) return 0x80;
    }
  } else {
    if (sizeof(T) == 4) {
      if ((x & 0x7F800000) == 0x7F800000) return signed_inf + (mantissa != 0 ? 1 : 0);
    } else {
      if ((x & 0x7C00) == 0x7C00) return signed_inf + (mantissa != 0 ? 1 : 0);
    }
  }
  if (x == 0) return 0;

  // First need to check if it is normal or denorm as there is a difference of implict 1
  // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
  // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
  // RNE, no need to add rng. Then probably need to check whether there is carry and adjust
  // exponent and mantissa again

  // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
  const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
  const int f8_denormal_act_exponent = 1 - f8_bias;  //actual exponent of f8 denormal
  // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
  // f8_exponent is the converted f8 exponent with bias encoding
  // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
  // the difference needs to be adjusted and mantissa shifted
  int act_exponent, f8_exponent, exponent_diff;

  if (exponent == 0) {  // fp32/fp16 is in denormal.
    /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 here.
       In this case, f8 is usually in denormal. But there could be exceptions.
       fp16 denormal has exponent bias 15 while bf8 with NANOO has exponent bias 16.
       It means that there are some numbers in fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15.
       fp16 numbers where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
       In this case, the fp16 mantissa should be shift left by 1 */
    act_exponent = exponent - bias + 1;
    exponent_diff = f8_denormal_act_exponent - act_exponent;  // actual exponent is exponent-bias+1 as it is denormal
  } else {  // fp32/fp16 is normal with implicit 1
    act_exponent = exponent - bias;
    if (act_exponent <= f8_denormal_act_exponent) {
      /* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
         For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
         actual exponent is -7, it is actually larger due to the implict 1,
         Therefore it needs to be adjust to -6 and mantissa shift right by 1.
         So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
      exponent_diff = f8_denormal_act_exponent - act_exponent;
    } else {  //both fp32/fp16 and f8 are in normal range
      exponent_diff = 0;  // exponent_diff=0 does not mean there is no difference for this case,
                          //act_exponent could be larger. Just that it does not need shift mantissa
    }
    mantissa += (1 << mfmt);  //Add the implicit 1 into mantissa
  }

  bool midpoint;
  if (exponent_diff <= wm + 1)
    // The determinination of midpoint only makes sense when wm+1 could compensate the difference in exponent.
    // Why wm+1 instead of wm? It is because in additional to the wm bits to be left as f8 mantissa, there is
    // also the implicit 1 (There is not always implicit 1 but it does not matter).
    midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
               (1 << (mfmt - wm + exponent_diff - 1));
  else
    midpoint = false;
  /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we shift right
     as shift right could rip off some residual part and make something not midpoint look like midpoint.
     For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than midpoint,
     but after shift right by 4 bits, it would look like midpoint.
  */

  if (exponent_diff > 0)
    // Clip the exponent_diff as if the right shift when exponent_diff > 31 is undefined behavior
    mantissa >>= (exponent_diff > 31 ? 31 : exponent_diff);
  else if (exponent_diff == -1)
    mantissa <<= -exponent_diff;
  bool implicit_one = mantissa & (1 << mfmt);
  //if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
  f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);

  //Now we have the exponent and mantissa adjusted
  uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
  //bool midpoint = (mantissa & drop_mask) == ( 1 << (mfmt-wm-1) );
  bool odd = mantissa & (1 << (mfmt - wm));  // if the least significant bit that is not truncated is 1
  mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;

  //Now we deal with overflow
  if (f8_exponent == 0) {
    if ((1 << mfmt) & mantissa) {
      f8_exponent = 1;  //denormal overflow to become normal, promote exponent
      //mantissa &= (1<<mfmt) -1 ; //No need to make 1 implicit now as it will be addressed later
    }
  } else {
    if ((1 << (mfmt + 1)) & mantissa) {
      mantissa >>= 1;
      f8_exponent++;
      //mantissa &= (1<<mfmt) -1 ; // No need to make 1 implicit now as it will be addressed later
    }
  }

  mantissa >>= (mfmt - wm);

  // above range: quantize to maximum possible float of the same sign
  const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
  if (f8_exponent > max_exp) {
    if (clip) {
      mantissa = (1 << wm) - 1;
      f8_exponent = max_exp;
    } else {
      return signed_inf;
    }
  }

  if (f8_exponent == 0 && mantissa == 0) return negative_zero_nan ? 0 : (sign << 7);
  mantissa &= (1 << wm) - 1;
  return (sign << 7) | (f8_exponent << wm) | mantissa;
}

/* RTC does not have std::conditional so implement it here*/
template <bool B, class T, class F>
struct conditional {
  typedef T type;
};

template <class T, class F>
struct conditional<false, T, F> {
  typedef F type;
};

template <int wm, int we, typename T, bool negative_zero_nan>
HIP_HOST_DEVICE T cast_from_f8(uint8_t x) {
  constexpr bool is_half = std::is_same<T, __half>::value;
  constexpr bool is_float = std::is_same<T, float>::value;
  constexpr bool is_bf16 = std::is_same<T, hip_bfloat16>::value;
  static_assert(is_half || is_float, "only half and float are supported");

  constexpr int weo = is_half ? 5 : 8;
  constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);

  T fInf, fNegInf, fNaN, fNeg0;
  if (is_half) {
    const uint16_t ihInf = 0x7C00;
    const uint16_t ihNegInf = 0xFC00;
    const uint16_t ihNaN = 0x7C01;
    const uint16_t ihNeg0 = 0x8000;
    fInf = reinterpret_cast<const __half &>(ihInf);
    fNegInf = reinterpret_cast<const __half &>(ihNegInf);
    fNaN = reinterpret_cast<const __half &>(ihNaN);
    fNeg0 = reinterpret_cast<const __half &>(ihNeg0);
  } else if (is_float) {
    const uint32_t ifInf = 0x7F800000;
    const uint32_t ifNegInf = 0xFF800000;
    const uint32_t ifNaN = 0x7F800001;
    const uint32_t ifNeg0 = 0x80000000;
    fInf = reinterpret_cast<const float &>(ifInf);
    fNegInf = reinterpret_cast<const float &>(ifNegInf);
    fNaN = reinterpret_cast<const float &>(ifNaN);
    fNeg0 = reinterpret_cast<const float &>(ifNeg0);
  }

  if (x == 0) return 0;

  uint32_t sign = x >> 7;
  uint32_t mantissa = x & ((1 << wm) - 1);
  int exponent = (x & 0x7F) >> wm;
  if (negative_zero_nan) {
    if (x == 0x80) return fNaN;
  } else {
    if (x == 0x80) return fNeg0;
    if (exponent == ((1 << we) - 1)) return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
  }
  typename conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
  if (we == 5 && is_half && !negative_zero_nan) {
    retval = x << 8;
    return reinterpret_cast<const T &>(retval);
  }

  const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);

  //subnormal input
  if (exponent == 0) {
    //guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
    int sh = 1 + clz(mantissa) - (32 - wm);
    mantissa <<= sh;
    exponent += 1 - sh;
    /*
    exponent++;
    while(mantissa<(1<<wm)) {
      mantissa <<= 1;
      exponent--;
    }
    */
    mantissa &= ((1 << wm) - 1);
  }
  exponent += exp_low_cutoff - 1;
  mantissa <<= wmo - wm;

  // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
  if (exponent <= 0) {
    mantissa |= 1 << wmo;
    mantissa >>= 1 - exponent;
    exponent = 0;
  }

  if (sizeof(T) == 2)
    retval = (sign << 15) | (exponent << 10) | mantissa;
  else
    retval = (sign << 31) | (exponent << 23) | mantissa;
  return reinterpret_cast<const T &>(retval);
}

}  // namespace hip_f8_impl
\ No newline at end of file
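cast_to_f8/cast_from_f8 above implement the generic software conversion path for the FNUZ ("NANOO") FP8 encodings. As a sanity check of what that encoding looks like, here is a small self-contained decoder for E4M3 with negative_zero_nan semantics, written in plain C++ so it runs without a HIP toolchain (illustrative only, not part of the commit):

// Illustrative only: decode an FNUZ E4M3 byte (bias 8, 0x80 = NaN, no infinities).
// The largest finite encoding, 0x7F, decodes to 240 -- the E4M3_AMAX defined in
// hip_float8.h below.
#include <cmath>
#include <cstdint>
#include <cstdio>

double decode_e4m3_fnuz(uint8_t v) {
  if (v == 0x00) return 0.0;
  if (v == 0x80) return NAN;  // the single NaN pattern
  int sign = v >> 7;
  int exp = (v >> 3) & 0xF;
  int man = v & 0x7;
  double mag = (exp == 0)
                   ? (man / 8.0) * std::pow(2.0, 1 - 8)              // denormal
                   : (1.0 + man / 8.0) * std::pow(2.0, exp - 8);     // normal
  return sign ? -mag : mag;
}

int main() {
  std::printf("0x7F -> %g\n", decode_e4m3_fnuz(0x7F));  // 240
  std::printf("0x01 -> %g\n", decode_e4m3_fnuz(0x01));  // smallest denormal, 2^-10
  return 0;
}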
transformer_engine/common/amd_detail/hip_float8.h  (new file, 0 → 100644)

/*************************************************************************
 * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
 *
 * License for AMD contributions = MIT. See LICENSE for more information
 ************************************************************************/

#pragma once

// FP8 header version 0.3, 2021/05/11

#include <hip/hip_runtime.h>

#define HIP_HOST_DEVICE __host__ __device__
#define HIP_DEVICE __device__
#define HIP_HOST __host__

#define E5M2_AMAX 57344.0
#define E4M3_AMAX 240.0

namespace hip_f8_impl {

template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
HIP_HOST_DEVICE uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0);

template <int wm, int we, typename T, bool negative_zero_nan>
HIP_HOST_DEVICE T cast_from_f8(uint8_t x);

}  // namespace hip_f8_impl

#include "hip_f8_impl.h"

enum class hip_f8_type {
  bf8 = 0,  // 1:5:2
  fp8 = 1   // 1:4:3
};

enum class hip_f8_rounding_mode { standard, stochastic };

// bias mode bit implementation
//
// For MI100 simulation purpose, we keep a copy of it on the host and device
// (MI300 HW implementation will be different)
//
// The bias mode should only be accessed via its get/set routines.
// The set routine sets both copies to the same value, keeping them in sync
// The get routine will return the device copy for device functions and
// the host copy for host functions
//
// "bias mode optimial"
//   => "bias mode bit" = 1
//   => bias = 16 for 152, 8 for 143
//   => NAN/INF are represented as negative_zero
//
// "bias mode ieee"
//   => "bias mode bit" = 0
//   => bias = 15 for 152, 7 for 143
//   => NAN/INF are represented as per IEEE conventions

#ifndef __HIPCC_RTC__
static bool hip_f8_bias_mode_bit_host = true;

static inline __host__ bool get_hip_f8_bias_mode() { return hip_f8_bias_mode_bit_host; }
#endif  // __HIPCC_RTC__

#ifdef __HIPCC__
static __device__ bool hip_f8_bias_mode_bit_device = true;

static inline __device__ bool get_hip_f8_bias_mode() { return hip_f8_bias_mode_bit_device; }

#ifndef __HIPCC_RTC__
static __global__ void set_hip_f8_bias_mode_bit(bool v) { hip_f8_bias_mode_bit_device = v; }

static void set_hip_f8_bias_mode_ieee() {
  hipLaunchKernelGGL(set_hip_f8_bias_mode_bit, dim3(1), dim3(1), 0, 0, false);
  hip_f8_bias_mode_bit_host = false;
}

static void set_hip_f8_bias_mode_optimal() {
  hipLaunchKernelGGL(set_hip_f8_bias_mode_bit, dim3(1), dim3(1), 0, 0, true);
  hip_f8_bias_mode_bit_host = true;
}
#endif  // __HIPCC_RTC__
#endif  // __HIPCC__

template <hip_f8_type T>
struct hip_f8 {
  uint8_t data;

  // default constructor
  HIP_HOST_DEVICE hip_f8() = default;

  // constructor from bits
  explicit HIP_HOST_DEVICE hip_f8(uint8_t v) { data = v; }

  // constructor from float
#ifdef __gfx942__
  explicit HIP_DEVICE hip_f8(float v, hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard,
                             uint32_t rng = 0) {
    union {
      float fval;
      uint32_t i32val;
      uint8_t i8val[4];
    } val;
    uint32_t ival = 0;
    val.fval = v;

    if (T == hip_f8_type::bf8) {  // bf8
      if ((val.i32val & 0x7F800000) != 0x7F800000)  // propagate NAN/INF, no clipping
        val.fval = __builtin_amdgcn_fmed3f(val.fval, E5M2_AMAX, -E5M2_AMAX);
      if (rm == hip_f8_rounding_mode::standard) {
        // RNE rounding
        ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false);  // false -> WORD0
        val.i32val = ival;
        data = val.i8val[0];
      } else {
        //stochastic rounding
        ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0);  // 0 pos
        val.i32val = ival;
        data = val.i8val[0];  // little endian
      }
    } else {  // fp8
      if ((val.i32val & 0x7F800000) != 0x7F800000)  /// propagate NAN/INF, no clipping
        val.fval = __builtin_amdgcn_fmed3f(val.fval, E4M3_AMAX, -E4M3_AMAX);
      if (rm == hip_f8_rounding_mode::standard) {
        // RNE rounding
        ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false);  // false -> WORD0
        val.i32val = ival;
        data = val.i8val[0];
      } else {
        //stochastic rounding
        ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0);  // 0 pos
        val.i32val = ival;
        data = val.i8val[0];  // little endian
      }
    }
  }

#ifndef __HIPCC_RTC__
  explicit HIP_HOST  //Code host still uses SW simulated conversion on gfx942
  hip_f8(float v, hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard, uint32_t rng = 0) {
    if (T == hip_f8_type::bf8) {
      if (get_hip_f8_bias_mode()) {
        data = hip_f8_impl::cast_to_f8<2, 5, float, true /*negative_zero_nan*/, true /*clip*/>(
            v, (rm == hip_f8_rounding_mode::stochastic), rng);
      } else {
        data = hip_f8_impl::cast_to_f8<2, 5, float, false /*negative_zero_nan*/, true /*clip*/>(
            v, (rm == hip_f8_rounding_mode::stochastic), rng);
      }
    } else /* fp8*/ {
      if (get_hip_f8_bias_mode()) {
        data = hip_f8_impl::cast_to_f8<3, 4, float, true /*negative_zero_nan*/, true /*clip*/>(
            v, (rm == hip_f8_rounding_mode::stochastic), rng);
      } else {
        data = hip_f8_impl::cast_to_f8<3, 4, float, false /*negative_zero_nan*/, true /*clip*/>(
            v, (rm == hip_f8_rounding_mode::stochastic), rng);
      }
    }
  }
#endif
#else  // #ifndef __gfx942__
  explicit HIP_HOST_DEVICE
  // On architectures other than gfx942, both host and device still use SW simulated conversion
  hip_f8(float v, hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard, uint32_t rng = 0) {
    if (T == hip_f8_type::bf8) {
      if (get_hip_f8_bias_mode()) {
        data = hip_f8_impl::cast_to_f8<2, 5, float, true /*negative_zero_nan*/, true /*clip*/>(
            v, (rm == hip_f8_rounding_mode::stochastic), rng);
      } else {
        data = hip_f8_impl::cast_to_f8<2, 5, float, false /*negative_zero_nan*/, true /*clip*/>(
            v, (rm == hip_f8_rounding_mode::stochastic), rng);
      }
    } else /* fp8*/ {
      if (get_hip_f8_bias_mode()) {
        data = hip_f8_impl::cast_to_f8<3, 4, float, true /*negative_zero_nan*/, true /*clip*/>(
            v, (rm == hip_f8_rounding_mode::stochastic), rng);
      } else {
        data = hip_f8_impl::cast_to_f8<3, 4, float, false /*negative_zero_nan*/, true /*clip*/>(
            v, (rm == hip_f8_rounding_mode::stochastic), rng);
      }
    }
  }
#endif  // #ifdef __gfx942__

  // constructor from half
#ifdef __gfx942__
  explicit HIP_DEVICE hip_f8(half v, hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard,
                             uint32_t rng = 0)
      : hip_f8((float)v, rm, rng) {}

#ifndef __HIPCC_RTC__
  explicit HIP_HOST  //Code host still uses SW simulated conversion on gfx942
  hip_f8(half v, hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard, uint32_t rng = 0) {
    if (T == hip_f8_type::bf8) {
      if (get_hip_f8_bias_mode()) {
        data = hip_f8_impl::cast_to_f8<2, 5, half, true /*negative_zero_nan*/, true /*clip*/>(
            v, (rm == hip_f8_rounding_mode::stochastic), rng);
      } else {
        data = hip_f8_impl::cast_to_f8<2, 5, half, false /*negative_zero_nan*/, true /*clip*/>(
            v, (rm == hip_f8_rounding_mode::stochastic), rng);
      }
    } else /* fp8*/ {
      if (get_hip_f8_bias_mode()) {
        data = hip_f8_impl::cast_to_f8<3, 4, half, true /*negative_zero_nan*/, true /*clip*/>(
            v, (rm == hip_f8_rounding_mode::stochastic), rng);
      } else {
        data = hip_f8_impl::cast_to_f8<3, 4, half, false /*negative_zero_nan*/, true /*clip*/>(
            v, (rm == hip_f8_rounding_mode::stochastic), rng);
      }
    }
  }
#endif
#else  // #ifndef __gfx942__
  explicit HIP_HOST_DEVICE
  // On architectures other than gfx942, both host and device still use SW simulated conversion
  hip_f8(half v, hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard, uint32_t rng = 0) {
    if (T == hip_f8_type::bf8) {
      if (get_hip_f8_bias_mode()) {
        data = hip_f8_impl::cast_to_f8<2, 5, half, true /*negative_zero_nan*/, true /*clip*/>(
            v, (rm == hip_f8_rounding_mode::stochastic), rng);
      } else {
        data = hip_f8_impl::cast_to_f8<2, 5, half, false /*negative_zero_nan*/, true /*clip*/>(
            v, (rm == hip_f8_rounding_mode::stochastic), rng);
      }
    } else /* fp8*/ {
      if (get_hip_f8_bias_mode()) {
        data = hip_f8_impl::cast_to_f8<3, 4, half, true /*negative_zero_nan*/, true /*clip*/>(
            v, (rm == hip_f8_rounding_mode::stochastic), rng);
      } else {
        data = hip_f8_impl::cast_to_f8<3, 4, half, false /*negative_zero_nan*/, true /*clip*/>(
            v, (rm == hip_f8_rounding_mode::stochastic), rng);
      }
    }
  }
#endif  // #ifdef __gfx942__

  // constructor from hip_bfloat16
  explicit HIP_HOST_DEVICE hip_f8(hip_bfloat16 v,
                                  hip_f8_rounding_mode r = hip_f8_rounding_mode::standard,
                                  uint32_t rng = 0);

  // convert to float
#ifdef __gfx942__
  HIP_DEVICE operator float() const {
    union {
      float fval;
      uint32_t i32val;
      uint8_t i8val[4];  // dependent of endian
    } val;

    // assign 8bit data in position [7:0]
    val.i32val = 0;
    val.i8val[3] = data;  // little endian

    // upcast
    if (T == hip_f8_type::bf8)
      val.fval = __builtin_amdgcn_cvt_f32_bf8(val.i32val, 3);  // 0 pos
    else  // fp8
      val.fval = __builtin_amdgcn_cvt_f32_fp8(val.i32val, 3);  // 0 pos

    return val.fval;
  }

#ifndef __HIPCC_RTC__
  explicit inline HIP_HOST  //Code host still uses SW simulated conversion on gfx942
  operator float() const {
    if (T == hip_f8_type::bf8) {
      if (get_hip_f8_bias_mode()) {
        return hip_f8_impl::cast_from_f8<2, 5, float, true /*negative_zero_nan*/>(data);
      } else {
        return hip_f8_impl::cast_from_f8<2, 5, float, false /*negative_zero_nan*/>(data);
      }
    } else /* fp8*/ {
      if (get_hip_f8_bias_mode()) {
        return hip_f8_impl::cast_from_f8<3, 4, float, true /*negative_zero_nan*/>(data);
      } else {
        return hip_f8_impl::cast_from_f8<3, 4, float, false /*negative_zero_nan*/>(data);
      }
    }
  }
#endif
#else  // #ifdef __gfx942__
  explicit inline HIP_HOST_DEVICE
  // On architectures other than gfx942, both host and device still use SW simulated conversion
  operator float() const {
    if (T == hip_f8_type::bf8) {
      if (get_hip_f8_bias_mode()) {
        return hip_f8_impl::cast_from_f8<2, 5, float, true /*negative_zero_nan*/>(data);
      } else {
        return hip_f8_impl::cast_from_f8<2, 5, float, false /*negative_zero_nan*/>(data);
      }
    } else /* fp8*/ {
      if (get_hip_f8_bias_mode()) {
        return hip_f8_impl::cast_from_f8<3, 4, float, true /*negative_zero_nan*/>(data);
      } else {
        return hip_f8_impl::cast_from_f8<3, 4, float, false /*negative_zero_nan*/>(data);
      }
    }
  }
#endif  // #ifdef __gfx942__

  // convert to half
#ifdef __gfx942__
  explicit HIP_DEVICE inline operator half() const { return __half(float(*this)); }

#ifndef __HIPCC_RTC__
  explicit inline HIP_HOST  //Code host still uses SW simulated conversion on gfx942
  operator half() const {
    if (T == hip_f8_type::bf8) {
      if (get_hip_f8_bias_mode()) {
        return hip_f8_impl::cast_from_f8<2, 5, half, true /*negative_zero_nan*/>(data);
      } else {
        return hip_f8_impl::cast_from_f8<2, 5, half, false /*negative_zero_nan*/>(data);
      }
    } else /* fp8*/ {
      if (get_hip_f8_bias_mode()) {
        return hip_f8_impl::cast_from_f8<3, 4, half, true /*negative_zero_nan*/>(data);
      } else {
        return hip_f8_impl::cast_from_f8<3, 4, half, false /*negative_zero_nan*/>(data);
      }
    }
  }
#endif
#else  // #ifndef __gfx942__
  explicit inline HIP_HOST_DEVICE
  // On architectures other than gfx942, both host and device still use SW simulated conversion
  operator half() const {
    if (T == hip_f8_type::bf8) {
      if (get_hip_f8_bias_mode()) {
        return hip_f8_impl::cast_from_f8<2, 5, half, true /*negative_zero_nan*/>(data);
      } else {
        return hip_f8_impl::cast_from_f8<2, 5, half, false /*negative_zero_nan*/>(data);
      }
    } else /* fp8*/ {
      if (get_hip_f8_bias_mode()) {
        return hip_f8_impl::cast_from_f8<3, 4, half, true /*negative_zero_nan*/>(data);
      } else {
        return hip_f8_impl::cast_from_f8<3, 4, half, false /*negative_zero_nan*/>(data);
      }
    }
  }
#endif  // #ifdef __gfx942__

  // convert to hip_bfloat16
  explicit inline HIP_HOST_DEVICE operator hip_bfloat16() const;

  // check for zero
  inline HIP_HOST_DEVICE bool is_zero() const {
    if (get_hip_f8_bias_mode()) {
      return data == 0x00;
    } else {
      return (data == 0x00) || (data == 0x80);
    }
  }

  // check for nan
  inline HIP_HOST_DEVICE bool is_nan() const {
    if (get_hip_f8_bias_mode()) {
      return data == 0x80;
    } else {
      if (T == hip_f8_type::bf8) {
        return (data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xfd) ||
               (data == 0xfe) || (data == 0xff);
      } else {
        return (data == 0x79) || (data == 0x7a) || (data == 0x7b) || (data == 0x7c) ||
               (data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xf9) ||
               (data == 0xfa) || (data == 0xfb) || (data == 0xfc) || (data == 0xfd) ||
               (data == 0xfe) || (data == 0xff);
      }
    }
  }

  // check for inf
  inline HIP_HOST_DEVICE bool is_inf() const {
    if (get_hip_f8_bias_mode()) {
      return data == 0x80;
    } else {
      if (T == hip_f8_type::bf8) {
        return (data == 0x7c) || (data == 0xfc);
      } else {
        return (data == 0x78) || (data == 0xf8);
      }
    }
  }
};

#ifdef __HIPCC__
template <hip_f8_type T>
struct hip_f8x4 {
  // define some convenience types
  typedef float float32x2 __attribute__((ext_vector_type(2)));
  typedef float float32x4 __attribute__((ext_vector_type(4)));

  typedef _Float16 halfx2 __attribute__((ext_vector_type(2)));
  typedef _Float16 halfx4 __attribute__((ext_vector_type(4)));

  typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2)));
  typedef uint16_t hip_bfloat16x4 __attribute__((ext_vector_type(4)));

  uint32_t data;

  // default constructor
  HIP_HOST_DEVICE hip_f8x4() = default;

  // constructor from bits
  HIP_HOST_DEVICE hip_f8x4(uint32_t v);

  // constructor from float
  HIP_HOST_DEVICE hip_f8x4(float v0, float v1 = 0, float v2 = 0, float v3 = 0,
                           hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard,
                           uint32_t rng = 0);
  HIP_HOST_DEVICE hip_f8x4(float32x2 v, hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard,
                           uint32_t rng = 0);
  HIP_HOST_DEVICE hip_f8x4(float32x4 v, hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard,
                           uint32_t rng = 0);

  // constructor from half
  HIP_HOST_DEVICE hip_f8x4(half v0, half v1 = 0, half v2 = 0, half v3 = 0,
                           hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard,
                           uint32_t rng = 0);
  HIP_HOST_DEVICE hip_f8x4(halfx2 v, hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard,
                           uint32_t rng = 0);
  HIP_HOST_DEVICE hip_f8x4(halfx4 v, hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard,
                           uint32_t rng = 0);

  // constructor from hip_bfloat16
  HIP_HOST_DEVICE hip_f8x4(hip_bfloat16 v0, hip_bfloat16 v1 = hip_bfloat16(0.0f),
                           hip_bfloat16 v2 = hip_bfloat16(0.0f), hip_bfloat16 v3 = hip_bfloat16(0.0f),
                           hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard,
                           uint32_t rng = 0);
  HIP_HOST_DEVICE hip_f8x4(hip_bfloat16x2 v,
                           hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard,
                           uint32_t rng = 0);
  HIP_HOST_DEVICE hip_f8x4(hip_bfloat16x4 v,
                           hip_f8_rounding_mode rm = hip_f8_rounding_mode::standard,
                           uint32_t rng = 0);

  // convert to float32x4
  inline HIP_HOST_DEVICE operator float32x4() const;

  // convert to halfx4
  inline HIP_HOST_DEVICE operator halfx4() const;

  // convert to hip_bfloat16x4
  inline HIP_HOST_DEVICE operator hip_bfloat16x4() const;
};

template <hip_f8_type T>
struct hip_f8x8 {
  // define some convenience types
  typedef hip_f8x4<T> f8x8 __attribute__((ext_vector_type(2)));

  f8x8 data;

  // default constructor
  HIP_HOST_DEVICE hip_f8x8() = default;

  // do we need to define other constructors or any conversion routines here?
};

// If we do not end up needing either any constructors or conversion routines for the above type, then
// we can simplify the above type to the following
#if USE_SIMPLER_HIP_F8x8
template <hip_f8_type T>
using hip_f8x8 = hip_f8x4<T> __attribute__((ext_vector_type(2)));
#endif

typedef float hip_float32x4 __attribute__((ext_vector_type(4)));
typedef float hip_float32x16 __attribute__((ext_vector_type(16)));

// these are device-specific and we don't expect them to exist unless we're compiling with hip-clang for gfx942.
template <hip_f8_type T_A, hip_f8_type T_B>
__device__ hip_float32x4 mfma_f32_16x16x32(hip_f8x8<T_A> a, hip_f8x8<T_B> b, hip_float32x4 c);

template <hip_f8_type T_A, hip_f8_type T_B>
__device__ hip_float32x16 mfma_f32_32x32x16(hip_f8x8<T_A> a, hip_f8x8<T_B> b, hip_float32x16 c);
#endif  //__HIPCC__
\ No newline at end of file
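On targets other than gfx942 (and on the host even on gfx942), hip_f8 falls back to the software conversion path, so the header can be exercised from plain host code. A minimal round-trip sketch, assuming a ROCm/HIP toolchain (hipcc) so that hip/hip_runtime.h and the half/hip_bfloat16 types are available; the include path is shown relative to transformer_engine/common and may need adjusting. Illustrative only, not part of the commit:

// Illustrative host-side round trip through the software conversion path.
// Compile with hipcc; adjust the include path to where hip_float8.h lives.
#include <cstdio>
#include "amd_detail/hip_float8.h"  // assumption: compiled from transformer_engine/common

int main() {
  hip_f8<hip_f8_type::fp8> a(1.3f);       // quantize fp32 -> E4M3 (FNUZ), RNE rounding
  float back = static_cast<float>(a);     // dequantize back to fp32
  std::printf("1.3f -> 0x%02x -> %f\n", a.data, back);

  hip_f8<hip_f8_type::fp8> big(1000.0f);  // clip=true: saturates at E4M3_AMAX (240), no inf
  std::printf("1000.0f -> %f\n", static_cast<float>(big));
  return 0;
}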
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

@@ -107,8 +107,12 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl

     This is needed only for Hopper, which uses persistent CTA execution.
  */
  int max_connection = transformer_engine::getenv<int>("CUDA_DEVICE_MAX_CONNECTIONS", 8);
#ifdef USE_ROCM
  int runtime_version = 6;
#else
  int runtime_version = 0;
  cudaRuntimeGetVersion(&runtime_version);
#endif
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, 0);
  if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) {

@@ -277,7 +281,11 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te

      assert(rs_output.size(0) == _ubuf.size(0) / _tp_size);
      assert(rs_output.element_size() == 2);
      char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
#ifdef USE_ROCM
      reducescatter2_userbuff_fp8<hip_f8<hip_f8_type::bf8>>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
#else
      reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
#endif
                                                 comm_elements, _ub_comm, _stream_comm,
                                                 (cudaEvent_t)_comm_launch_event);
    } else {
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp

@@ -361,7 +361,11 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,

  NVTE_CHECK_CUDA(cudaMalloc(&((*comm)->flags), 2 * GPU_PAGE_SIZE));
  NVTE_CHECK_CUDA(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
  (*comm)->flags =
#ifdef USE_ROCM
      reinterpret_cast<int *>((reinterpret_cast<uintptr_t>((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
#else
      reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
#endif

  using namespace std;

@@ -624,7 +628,12 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *

  } else {
#endif
    if (alloc) {
#ifdef USE_ROCM
      // Ref to RCCL
      NVTE_CHECK_CUDA(hipExtMallocWithFlags(gpubuff, bytes, hipDeviceMallocFinegrained));
#else
      NVTE_CHECK_CUDA(cudaMalloc(gpubuff, bytes));
#endif
      NVTE_CHECK_CUDA(cudaMemset(*gpubuff, 0, bytes));
    }
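The ROCm branch above rounds the flags pointer up to a GPU page boundary with plain uintptr_t arithmetic instead of a CUdeviceptr cast; the idiom itself is the usual add-(page size - 1)-and-mask trick. A standalone sketch of that arithmetic (illustrative only; GPU_PAGE_SIZE is assumed here to be a 64 KiB power of two and GPU_PAGE_MASK its complement mask):

// Illustrative only: the (ptr + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK round-up used above.
#include <cstdint>
#include <cstdio>

int main() {
  constexpr uintptr_t GPU_PAGE_SIZE = 1u << 16;             // assumption: 64 KiB pages
  constexpr uintptr_t GPU_PAGE_MASK = ~(GPU_PAGE_SIZE - 1);

  uintptr_t raw = 0x7f0000012345;  // arbitrary unaligned address
  uintptr_t aligned = (raw + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK;
  std::printf("raw=0x%llx aligned=0x%llx\n",
              (unsigned long long)raw, (unsigned long long)aligned);  // ...012345 -> ...020000
  return 0;
}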
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu

@@ -9,8 +9,10 @@

#include <cuda_runtime.h>

#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#define half_dtype nv_bfloat16
#else
#include <cuda_fp16.h>
#define half_dtype half
#endif

@@ -24,6 +26,18 @@

#define MAX_THREADS 1024

#ifdef __HIP_PLATFORM_AMD__
#define ATOMIC_CONSUMER(chunk)                                             \
  if (counters) {                                                          \
    if (threadIdx.x == 0 && blockIdx.x == 0) {                             \
      while (0 != (atomicCAS(((unsigned int *)counters) + chunk, 0, 0))) { \
      }                                                                    \
      ((unsigned int *)counters)[chunk] = 1;                               \
      __threadfence_system();                                              \
    }                                                                      \
    if (blockIdx.x == 0) __syncthreads();                                  \
  }
#else
#define ATOMIC_CONSUMER(chunk)                                             \
  if (counters) {                                                          \
    if (threadIdx.x == 0 && blockIdx.x == 0) {                             \

@@ -34,6 +48,7 @@

    }                                                                      \
    if (blockIdx.x == 0) __syncthreads();                                  \
  }
#endif

#define ATOMIC_PRODUCER(chunk) \
  if (counters) {              \

@@ -1025,7 +1040,13 @@ __global__ void __launch_bounds__(MAX_THREADS)

        // reset counter for next producer.
        ((unsigned int *)counters)[0] = 1;
#ifdef __HIP_PLATFORM_AMD__
        __threadfence_system();
        // __threadfence();
        // __syncthreads()
#else
        asm volatile("fence.sc.gpu;\n");
#endif
      }
    }
    __syncthreads();

@@ -1116,7 +1137,13 @@ __global__ void __launch_bounds__(MAX_THREADS)

        // reset counter for next producer.
        ((unsigned int *)counters)[chunk_i] = 1;
#ifdef __HIP_PLATFORM_AMD__
        __threadfence_system();
        // __threadfence();
        // __syncthreads()
#else
        asm volatile("fence.sc.gpu;\n");
#endif
      }
    }
    __syncthreads();

@@ -1357,6 +1384,33 @@ __global__ void __launch_bounds__(MAX_THREADS)

  }
}

// fp16 inplace allgather kernel (Volta,Hopper)
#ifdef __HIP_PLATFORM_AMD__
#define SETUP_LAUNCH_CONFIG(sms, threads, stream)                                     \
  cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0};                        \
  cudaLaunchAttribute attribute_ub[2];                                                \
  attribute_ub[1].id = cudaLaunchAttributeClusterDimension;                           \
  attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1;  \
  attribute_ub[1].val.clusterDim.y = 1;                                               \
  attribute_ub[1].val.clusterDim.z = 1;                                               \
  attribute_ub[0].id = cudaLaunchAttributeCooperative;                                \
  cfg.attrs = attribute_ub;                                                           \
  cfg.numAttrs = 1;  // dtk unsupport hipLaunchAttributeClusterDimension

#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event)
#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 2

#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \
  cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0};                             \
  cudaLaunchAttribute attribute_ub[NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH] = {};              \
  ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event)                             \
  attribute_ub[1].id = cudaLaunchAttributeClusterDimension;                                \
  attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1;       \
  attribute_ub[1].val.clusterDim.y = 1;                                                    \
  attribute_ub[1].val.clusterDim.z = 1;                                                    \
  attribute_ub[0].id = cudaLaunchAttributeCooperative;                                     \
  cfg.attrs = attribute_ub;                                                                \
  cfg.numAttrs = 1;  // dtk unsupport hipLaunchAttributeClusterDimension
#else
#define SETUP_LAUNCH_CONFIG(sms, threads, stream)              \
  cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \
  cudaLaunchAttribute attribute_ub[2];                         \

@@ -1389,6 +1443,7 @@ __global__ void __launch_bounds__(MAX_THREADS)

  attribute_ub[0].id = cudaLaunchAttributeCooperative; \
  cfg.attrs = attribute_ub;                            \
  cfg.numAttrs = NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH;
#endif

#define callranks_ag(x)   \
  if (ar_nvsize == x) {   \

@@ -1932,6 +1987,53 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const

  }
}

#ifdef __HIP_PLATFORM_AMD__
template void reducescatter2_userbuff_stridedoutput_fp8<hip_f8<hip_f8_type::bf8>>(
    void *output, float *scale, const int handler, const int offset, const int rowelements,
    const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
    cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_stridedoutput_fp8<hip_f8<hip_f8_type::fp8>>(
    void *output, float *scale, const int handler, const int offset, const int rowelements,
    const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
    cudaEvent_t comm_launch_event);
template <typename fp8type>
void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
                                 const int elements, communicator *comm, cudaStream_t stream,
                                 cudaEvent_t comm_launch_event) {
  reducescatter2_userbuff_stridedoutput_fp8<fp8type>(output, scale, handler, offset, elements, 1,
                                                     0, comm, stream, comm_launch_event);
}
template void reducescatter2_userbuff_fp8<hip_f8<hip_f8_type::bf8>>(
    void *output, float *scale, const int handler, const int offset, const int elements,
    communicator *comm, cudaStream_t stream, cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_fp8<hip_f8<hip_f8_type::fp8>>(
    void *output, float *scale, const int handler, const int offset, const int elements,
    communicator *comm, cudaStream_t stream, cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_strided_atomic_fp8<hip_f8<hip_f8_type::fp8>>(
    void *output, float *scale, const int handler, const int offset, const int rowelements,
    const int colelements, const int strideelements_out, const int strideelements_in,
    const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_atomic_fp8<hip_f8<hip_f8_type::bf8>>(
    void *output, float *scale, const int handler, const int offset, const int rowelements,
    const int colelements, const int strideelements_out, const int strideelements_in,
    const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<hip_f8<hip_f8_type::fp8>>(
    void *output, float *scale, const int handler, const int offset, const int rowelements,
    const int colelements, const int strideelements_out, const int strideelements_in,
    const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<hip_f8<hip_f8_type::bf8>>(
    void *output, float *scale, const int handler, const int offset, const int rowelements,
    const int colelements, const int strideelements_out, const int strideelements_in,
    const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
#else
template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>(
    void *output, float *scale, const int handler, const int offset, const int rowelements,
    const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,

@@ -1977,6 +2079,7 @@ template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e5m2>(

    void *output, float *scale, const int handler, const int offset, const int rowelements,
    const int colelements, const int strideelements_out, const int strideelements_in,
    const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
#endif

__global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) {
  atomicAdd_system(flagptr, 1);

@@ -2196,7 +2299,13 @@ __global__ void __launch_bounds__(MAX_THREADS)

      // Decrement atomic val to signal current output tile finish
      if (counters) {
        ((unsigned int *)counters)[0] = 0;
#ifdef __HIP_PLATFORM_AMD__
        __threadfence_system();
        // __threadfence();
        // __syncthreads()
#else
        asm volatile("fence.sc.gpu;\n");
#endif
      }
    }
  }

@@ -2267,7 +2376,13 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat

      // Decrement atomic val to signal current output tile finish
      if (counters) {
        ((unsigned int *)counters)[recv_chunk_id /*chunk_i+1*/] = 0;
#ifdef __HIP_PLATFORM_AMD__
        __threadfence_system();
        // __threadfence();
        // __syncthreads()
#else
        asm volatile("fence.sc.gpu;\n");
#endif
      }
    }

@@ -2545,7 +2660,13 @@ static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) {

  // COMM kernel need to explicitely flash gmem.
  // GEMM kernel already executed, and can not see gmem
  // change without COMM kernel explicitely make change
#ifdef __HIP_PLATFORM_AMD__
  __threadfence_system();
  // __threadfence();
  // __syncthreads()
#else
  asm volatile("fence.sc.gpu;\n");
#endif
}

// consumer

@@ -2555,7 +2676,13 @@ static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) {

    while (0 != (atomicCAS((unsigned int *)atomic_ptr + chunk_i, 0, 0))) {
    }
    ((unsigned int *)atomic_ptr)[chunk_i] = 1;
#ifdef __HIP_PLATFORM_AMD__
    __threadfence_system();
    // __threadfence();
    // __syncthreads()
#else
    asm volatile("fence.sc.gpu;\n");
#endif
  }
}

@@ -2567,7 +2694,13 @@ static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i

      while (0 != (atomicCAS((unsigned int *)atomic_ptr + i, 0, 0))) {
      }
      ((unsigned int *)atomic_ptr)[i] = 1;
#ifdef __HIP_PLATFORM_AMD__
      __threadfence_system();
      // __threadfence();
      // __syncthreads()
#else
      asm volatile("fence.sc.gpu;\n");
#endif
    }
  }
}

@@ -2661,12 +2794,21 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in

      num_aligned_elements_per_input, tot_input_size);
}

#ifdef __HIP_PLATFORM_AMD__
template void reduce_fp8_in_bf16_out<hip_f8<hip_f8_type::fp8>>(
    void *inputs, void *output, float *scale, int num_inputs, int input_size,
    cudaStream_t stream);
template void reduce_fp8_in_bf16_out<hip_f8<hip_f8_type::bf8>>(
    void *inputs, void *output, float *scale, int num_inputs, int input_size,
    cudaStream_t stream);
#else
template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale,
                                                     int num_inputs, int input_size,
                                                     cudaStream_t stream);
template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(void *inputs, void *output, float *scale,
                                                     int num_inputs, int input_size,
                                                     cudaStream_t stream);
#endif

template <int nvec>
__global__ void __launch_bounds__(MAX_THREADS / 4)
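The HIP variants of ATOMIC_CONSUMER and consumer_kernel above swap the PTX fence.sc.gpu for __threadfence_system() but keep the same handshake: the consumer spins until the counter drops to 0, re-arms it to 1, and publishes with a fence. A host-side analogue of that protocol using std::atomic, shown only to make the ordering explicit (illustrative sketch, not part of the commit):

// Illustrative host-side analogue of the counter handshake: the producer stores 0
// to signal "chunk ready"; the consumer spins until it observes 0, re-arms the
// counter to 1, then consumes the payload.
#include <atomic>
#include <cstdio>
#include <thread>

int main() {
  std::atomic<unsigned int> counter{1};  // 1 = not ready (the value the consumer re-arms to)
  int payload = 0;

  std::thread producer([&] {
    payload = 42;                                  // produce data
    counter.store(0, std::memory_order_release);   // signal: chunk ready
  });

  std::thread consumer([&] {
    while (counter.load(std::memory_order_acquire) != 0) {
    }                                              // spin until the producer signals
    counter.store(1, std::memory_order_release);   // re-arm for the next round
    std::printf("consumed %d\n", payload);         // safe: release/acquire pairing
  });

  producer.join();
  consumer.join();
  return 0;
}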
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h
View file @
c520cba3
...
...
@@ -105,14 +105,22 @@ struct communicator {
  int memflags[NVTE_MAX_REGIONS];  // UC,MC, user/lib allocated
#ifdef __HIP_PLATFORM_AMD__
  hipMemGenericAllocationHandle_t *uchandles[NVTE_MAX_REGIONS];
#else
  CUmemGenericAllocationHandle *uchandles[NVTE_MAX_REGIONS];
#endif
  void *ucbase_ptr[NVTE_MAX_REGIONS];  // only for cuMem allocated memory
  size_t mem_size[NVTE_MAX_REGIONS];
  bool mem_dealloc[NVTE_MAX_REGIONS];
  void *mc_ptr[NVTE_MAX_REGIONS];
  void *mc_baseptr;
#ifdef __HIP_PLATFORM_AMD__
  hipMemGenericAllocationHandle_t mc_handle;
#else
  CUmemGenericAllocationHandle mc_handle;
#endif
  size_t mc_offset, mc_maxsize;
  int use_mc;  // 1: use MC if available, 0: override not to use MC
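The allocation-handle type is the only platform-dependent piece of this struct, and it is switched twice. A single alias would express the same layout with one #ifdef; this is only a sketch of an alternative, not what the commit does, and the alias name MemAllocHandle is invented here.

// Hypothetical alias (not in this commit) collapsing the two #ifdef blocks above:
#ifdef __HIP_PLATFORM_AMD__
using MemAllocHandle = hipMemGenericAllocationHandle_t;
#else
using MemAllocHandle = CUmemGenericAllocationHandle;
#endif
// struct communicator would then declare:
//   MemAllocHandle *uchandles[NVTE_MAX_REGIONS];
//   MemAllocHandle mc_handle;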
...
...
transformer_engine/common/common.cu
View file @ c520cba3
...
...
@@ -36,6 +36,9 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
}

void checkCuDriverContext(CUstream stream) {
#ifdef __HIP_PLATFORM_AMD__
  return;
#else
  CUcontext ctx;
  const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx);
  switch (driver_status) {
...
...
@@ -54,8 +57,10 @@ void checkCuDriverContext(CUstream stream) {
      cuda_driver::call("cuGetErrorString", driver_status, &desc_NVTE_CHECK_CUDA_DRIVER);
      NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER);
  }
#endif
}

#ifndef __HIP_PLATFORM_AMD__
CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
  static const std::unordered_map<DType, CUtensorMapDataType> dtypeMapping = {
      {DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
...
...
@@ -127,11 +132,16 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
      // Any element that is outside of bounds will be set to zero by the TMA transfer.
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
}
#endif

bool is_supported_by_CC_100() {
#ifdef __HIP_PLATFORM_AMD__
  return false;
#else
  int deviceComputeCapability = cuda::sm_arch(cuda::current_device());
  return deviceComputeCapability >= 100;
#endif
}
}  // namespace transformer_engine
transformer_engine/common/common.h
View file @ c520cba3
...
...
@@ -6,8 +6,9 @@
#ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_COMMON_H_
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
...
...
@@ -217,9 +218,15 @@ using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
using fp16 = half;
#ifndef __HIP_PLATFORM_AMD__
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
#else
using bf16 = hip_bfloat16;
using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
#endif
#if CUDA_VERSION >= 12080
using fp8e8m0 = __nv_fp8_e8m0;
#endif
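With these aliases the rest of the library can be written against bf16, fp8e4m3 and fp8e5m2 without knowing which backend provides them. The kernel below is a minimal sketch of that idea, not code from the repository; it assumes both fp8 types are constructible from float, which is how they are used elsewhere in the diff.

// Illustrative only: the same source compiles whether fp8e4m3 is __nv_fp8_e4m3 or
// hip_f8<hip_f8_type::fp8>.
template <typename OType>
__global__ void cast_to_fp8_sketch(const float *in, OType *out, float scale, size_t n) {
  const size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = static_cast<OType>(in[i] * scale);
  }
}
// e.g. cast_to_fp8_sketch<fp8e4m3><<<grid, block, 0, stream>>>(in, out, scale, n);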
...
...
@@ -239,9 +246,15 @@ TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
TRANSFORMER_ENGINE_TYPE_NAME(float)
TRANSFORMER_ENGINE_TYPE_NAME(half)
#ifdef __HIP_PLATFORM_AMD__
TRANSFORMER_ENGINE_TYPE_NAME(hip_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(hip_f8<hip_f8_type::fp8>)
TRANSFORMER_ENGINE_TYPE_NAME(hip_f8<hip_f8_type::bf8>)
#else
TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
#endif
#if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
#endif
...
...
@@ -510,6 +523,7 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream);
void checkCuDriverContext(CUstream stream);
#ifndef __HIP_PLATFORM_AMD__
CUtensorMapDataType get_CUtensorMapDataType(DType dtype);

// Set up parameters to create TMA descriptor.
...
...
@@ -517,6 +531,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
                          const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
                          const uint32_t shmemX, const uint32_t stride_elems,
                          const uint32_t offset_elems, const size_t type_size);
#endif

bool is_supported_by_CC_100();
...
...
transformer_engine/common/cudnn_utils.cpp
View file @ c520cba3
...
...
@@ -11,6 +11,7 @@
namespace transformer_engine {
#ifndef USE_ROCM
// get cuDNN data type
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
...
...
@@ -60,6 +61,7 @@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t)
void nvte_cudnn_handle_init() {
  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
}
#endif
}  // namespace transformer_engine
...
...
transformer_engine/common/cudnn_utils.h
View file @ c520cba3
...
...
@@ -7,9 +7,11 @@
#ifndef TRANSFORMER_ENGINE_CUDNN_UTILS_H_
#define TRANSFORMER_ENGINE_CUDNN_UTILS_H_
#ifndef __HIP_PLATFORM_AMD__
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#endif
#include <cstdint>
#include <mutex>
...
...
@@ -18,6 +20,7 @@
namespace transformer_engine {
#ifndef __HIP_PLATFORM_AMD__
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
...
...
@@ -40,6 +43,7 @@ class cudnnExecutionPlanManager {
 private:
  cudnnHandle_t handle_ = nullptr;
};
#endif
}  // namespace transformer_engine
...
...
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @ c520cba3
...
...
@@ -4,9 +4,15 @@
* See LICENSE for license information.
************************************************************************/
#ifndef __HIP_PLATFORM_AMD__
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#else
#include <iostream>
#include "hipblas_gemm.h"
#include "rocm_gemm.hip"
#endif // #ifndef __HIP_PLATFORM_AMD__
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
...
...
@@ -17,6 +23,7 @@
#include "../util/logging.h"
#include "common/util/cuda_runtime.h"
#ifndef __HIP_PLATFORM_AMD__
namespace {

cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
...
...
@@ -137,9 +144,19 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
}
}  // namespace
#endif  // __HIP_PLATFORM_AMD__

namespace transformer_engine {

#ifdef __HIP_PLATFORM_AMD__
// Forward declaration. The implementation is in rocm_gemm.cu
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
                 int ldb, int ldd, bool transa, bool transb, bool grad, void *workspace,
                 size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                 int math_sm_count, int m_split, int n_split, bool gemm_producer,
                 const Tensor *inputCounter, hipStream_t stream);
#else  // Use cublasLt
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
                 int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad,
...
...
@@ -424,6 +441,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
  NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
}
#endif  // __HIP_PLATFORM_AMD__

static std::once_flag init_flag;
static cudaStream_t compute_streams[num_streams];
...
...
@@ -437,6 +455,19 @@ static void init_streams_and_events() {
  }
}

// Add for batchgemm
static std::once_flag init_flag_batchgemm;
static cudaStream_t compute_streams_batchgemm[num_batchgemm_streams];
static cudaEvent_t cublas_event_batchgemm[num_batchgemm_streams];

// Warning: only call once per device!
static void init_streams_and_events_batchgemm() {
  for (int i = 0; i < num_batchgemm_streams; i++) {
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams_batchgemm[i],
                                                 cudaStreamNonBlocking, -1));
    NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event_batchgemm[i]));
  }
}
}  // namespace transformer_engine

void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
...
...
@@ -477,10 +508,57 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
    NVTE_ERROR("TT layout not allowed.");
  }

#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
  const char *NVTE_BLASLT_BLAS = std::getenv("NVTE_FORCE_BLASLT");
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || (use_fp8) ||
      (NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1')) {
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#else
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#endif  // USE_HIPBLASLT
#else
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad,
#endif  // __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
              transa, transb,
#else
              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
#endif  // __HIP_PLATFORM_AMD__
              grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
              math_sm_count, 0, 0, false, nullptr, stream);
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
  } else {
    hipblas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                 (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
                 math_sm_count, 0, 0, false, nullptr, stream);
  }
#endif  // USE_HIPBLASLT
#endif  // __HIP_PLATFORM_AMD__
}
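The same bias/GeLU/FP8/NVTE_FORCE_BLASLT condition reappears in nvte_cublas_atomic_gemm below. A small predicate would keep the two call sites in sync; the helper below is a sketch only (use_hipblaslt_path is a hypothetical name, not part of this commit).

// Hypothetical helper (not in this commit) capturing the dispatch rule used above:
// take the hipBLASLt path whenever bias or GeLU fusion or FP8 inputs are requested,
// or when the environment forces it with NVTE_FORCE_BLASLT=1.
static bool use_hipblaslt_path(const Tensor *inputA, const Tensor *inputB,
                               const Tensor *biasTensor, const Tensor *outputGelu) {
  const char *force_blaslt = std::getenv("NVTE_FORCE_BLASLT");
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
  return (biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || use_fp8 ||
         (force_blaslt != nullptr && force_blaslt[0] == '1');
}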
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
...
...
@@ -491,10 +569,12 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
                             cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_atomic_gemm);
#ifndef __HIP_PLATFORM_AMD__
  int cudart_version;
  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
  NVTE_CHECK(cudart_version >= 12020, "Cuda version 12.2 is required for atomic gemm.");
  NVTE_CHECK(cublasLtGetVersion() >= 120205, "Cublas version 12.2.5 is required for atomic gemm.");
#endif

  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
...
...
@@ -529,10 +609,103 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
    NVTE_ERROR("TT layout not allowed.");
  }

#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
  const char *NVTE_BLASLT_BLAS = std::getenv("NVTE_FORCE_BLASLT");
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || (use_fp8) ||
      (NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1')) {
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#else
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#endif  // USE_HIPBLASLT
#else
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad,
#endif  // __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
              transa, transb,
#else
              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
#endif  // __HIP_PLATFORM_AMD__
              grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
              math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
  } else {
    hipblas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                 (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                 grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
                 math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);
  }
#endif  // USE_HIPBLASLT
#endif  // __HIP_PLATFORM_AMD__
}

void nvte_cublaslt_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                        const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb,
                        bool grad, NVTETensor workspace, bool accumulate,
                        bool use_split_accumulator, int math_sm_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublaslt_gemm);
  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

  const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
  const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
  const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#ifdef __HIP_PLATFORM_AMD__
              transa, transb,
#else
              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
#endif  // __HIP_PLATFORM_AMD__
              grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
              math_sm_count, 0, 0, false, nullptr, stream);
}

void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
...
...
@@ -552,12 +725,30 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
  }

  const char *NVTE_BLAS_MULSTREAM = std::getenv("NVTE_FORCE_BLAS_MULSTREAM");
  const char *NVTE_BLASLT_BLAS = std::getenv("NVTE_FORCE_BLASLT");
  bool NVTE_FORCE_BLASLT_MULSTREAM;
  if (NVTE_BLAS_MULSTREAM == nullptr) {
    NVTE_FORCE_BLASLT_MULSTREAM = true;
  } else if ((NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1') &&
             (NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1')) {
    NVTE_ERROR("NVTE_FORCE_BLAS_MULSTREAM and NVTE_FORCE_BLASLT can't be set at the same time.");
  } else {
    NVTE_FORCE_BLASLT_MULSTREAM = false;
  }

  if (NVTE_FORCE_BLASLT_MULSTREAM) {
    for (int i = 0; i < num_gemms; i++) {
      nvte_cublaslt_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                         workspace[i % num_streams], accumulate, use_split_accumulator,
                         math_sm_count, compute_streams[i % num_streams]);
    }
  } else {
    for (int i = 0; i < num_gemms; i++) {
      nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                       workspace[i % num_streams], accumulate, use_split_accumulator,
                       math_sm_count, compute_streams[i % num_streams]);
    }
  }

  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
...
...
@@ -568,3 +759,107 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
  }
}

#ifdef __HIP_PLATFORM_AMD__
void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                        const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                        const int num_gemms, bool transa, bool transb, bool grad,
                                        NVTETensor *workspace, bool accumulate,
                                        bool use_split_accumulator, int math_sm_count,
                                        cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm);
  using namespace transformer_engine;
  NVTE_CHECK(num_gemms % num_batchgemm_streams == 0,
             "Need num_gemms mod num_batchgemm_streams == 0.");
  static int batch_count = num_gemms / num_batchgemm_streams;

  // Inits streams and events (once, globally)
  std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm);
  int num_stream_used = num_batchgemm_streams;

  // wait for current stream to finish
  NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[0], stream));
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams_batchgemm[s], cublas_event_batchgemm[0]));
  }

  for (int i = 0; i < num_stream_used; i++) {
    nvte_cublas_batchgemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                          workspace[i % num_batchgemm_streams], accumulate, use_split_accumulator,
                          math_sm_count, batch_count,
                          compute_streams_batchgemm[i % num_batchgemm_streams]);
  }

  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[s], compute_streams_batchgemm[s]));
  }
  // wait for all compute streams to finish
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event_batchgemm[s]));
  }
}

// add for batchgemm
void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                           const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                           bool transb, bool grad, NVTETensor workspace, bool accumulate,
                           bool use_split_accumulator, int math_sm_count, int batch_count,
                           cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_batchgemm);
  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

  int m, n, k;
  if (!transa && transb) {  // for NT
    m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
    k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
    n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
  } else if (transa && !transb) {  // for TN
    m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
    k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
    n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
  } else if (!transa && !transb) {  // for NN
    m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
    k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
    n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
  }

  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

  hipblas_batchgemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                    (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                    use_split_accumulator, math_sm_count, 0, 0, false, nullptr, batch_count,
                    stream);
}
#endif
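Because the batched operands are stacked along the leading shape dimension, the per-GEMM sizes are obtained by dividing that dimension by batch_count, and hipblas_batchgemm (below) derives the strides from m, n and k. A worked example with hypothetical sizes, not taken from the tests:

// Hypothetical NN case with batch_count = 4:
//   inputA->data.shape = {4 * 64, 128}  ->  k = 256 / 4 = 64,  m = 128
//   inputB->data.shape = {4 * 32, 64}   ->  n = 128 / 4 = 32
//   The strided-batched call then uses strideA = m * k = 8192, strideB = k * n = 2048,
//   strideD = m * n = 4096 for each of the 4 sub-problems.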
transformer_engine/common/gemm/hipblas_gemm.cu
0 → 100644
View file @ c520cba3
/*************************************************************************
 * Copyright (c) 2022-2024, S3000 qianyj. All rights reserved.
 ************************************************************************/

#include <hip/hip_runtime.h>

#include "hipblas_gemm.h"
#include "../common_hip.h"
#include "../util/logging.h"

namespace {

hipblasDatatype_t get_hip_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
  switch (t) {
    case DType::kFloat16:
      return HIPBLAS_R_16F;
    case DType::kFloat32:
      return HIPBLAS_R_32F;
    case DType::kBFloat16:
      return HIPBLAS_R_16B;
    default:
      NVTE_ERROR("Invalid type");
  }
}

}  // namespace

// Define a static handle manager
static HipblasHandleManager handleManager;

namespace transformer_engine {

void hipblas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                  const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
                  int ldb, int ldd, hipblasOperation_t transa, hipblasOperation_t transb,
                  bool grad, void *workspace, size_t workspaceSize, bool accumulate,
                  bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                  bool gemm_producer, const Tensor *inputCounter, hipStream_t stream) {
  // Use static handles
  int device_id;
  hipGetDevice(&device_id);
  hipblasHandle_t handle = handleManager.get(device_id);

  void *A = inputA->data.dptr;
  // void *A_scale_inverse = inputA->scale_inv.dptr;
  void *B = inputB->data.dptr;
  // void *B_scale_inverse = inputB->scale_inv.dptr;
  void *C = outputD->data.dptr;
  void *D = outputD->data.dptr;

  // Select the calculation accuracy
  hipblasDatatype_t A_type = get_hip_dtype(inputA->data.dtype);
  hipblasDatatype_t B_type = get_hip_dtype(inputB->data.dtype);
  hipblasDatatype_t D_type = get_hip_dtype(outputD->data.dtype);
  hipblasDatatype_t computeType = HIPBLAS_R_32F;  // default acc is float32
  // setting computetype
  // if (/* condition for mixed precision */) {
  //   computeType = HIPBLAS_R_16F;
  // }
  // hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F;
  // const char *env_tf32 = std::getenv("NVTE_BLASLT_TF32");
  // if (env_tf32 != nullptr && env_tf32[0] == '1') {
  //   if (A_type == HIPBLAS_R_32F && B_type == HIPBLAS_R_32F && D_type == HIPBLAS_R_32F) {
  //     gemm_compute_type = HIPBLAS_COMPUTE_32F_FAST_TF32;
  //   }

  float one = 1.0f;
  float zero = 0.0f;
  float beta = accumulate ? one : zero;

  hipblasSetStream(handle, stream);

  // execute multiply
  hipblasStatus_t status = hipblasGemmEx(handle,
                                         transa,  // transa
                                         transb,  // transb
                                         m, n, k,
                                         static_cast<const void *>(&one),
                                         A, A_type, lda,
                                         B, B_type, ldb,
                                         static_cast<const void *>(&beta),
                                         D, D_type, ldd,
                                         computeType, HIPBLAS_GEMM_DEFAULT);
  if (status != HIPBLAS_STATUS_SUCCESS) {
    NVTE_ERROR("hipblasGemmEx execution failed");
  }
}
void hipblas_batchgemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                       const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k,
                       int lda, int ldb, int ldd, hipblasOperation_t transa,
                       hipblasOperation_t transb, bool grad, void *workspace,
                       size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                       int math_sm_count, int m_split, int n_split, bool gemm_producer,
                       const Tensor *inputCounter, int batch_count, hipStream_t stream) {
  // Use static handles
  int device_id;
  hipGetDevice(&device_id);
  hipblasHandle_t handle = handleManager.get(device_id);

  void *A = inputA->data.dptr;
  // void *A_scale_inverse = inputA->scale_inv.dptr;
  void *B = inputB->data.dptr;
  // void *B_scale_inverse = inputB->scale_inv.dptr;
  void *C = outputD->data.dptr;
  void *D = outputD->data.dptr;

  // Select the calculation accuracy
  hipblasDatatype_t A_type = get_hip_dtype(inputA->data.dtype);
  hipblasDatatype_t B_type = get_hip_dtype(inputB->data.dtype);
  hipblasDatatype_t D_type = get_hip_dtype(outputD->data.dtype);
  hipblasDatatype_t computeType = HIPBLAS_R_32F;  // default acc is float32

  float one = 1.0f;
  float zero = 0.0f;
  float beta = accumulate ? one : zero;

  hipblasSetStream(handle, stream);

  // execute multiply
  // calculate strides
  const long long int strideA = m * k;
  const long long int strideB = k * n;
  const long long int strideD = m * n;

  hipblasStatus_t status = hipblasGemmStridedBatchedEx(handle,
                                                       transa,  // transa
                                                       transb,  // transb
                                                       m, n, k,
                                                       static_cast<const void *>(&one),
                                                       A, A_type, lda, strideA,
                                                       B, B_type, ldb, strideB,
                                                       static_cast<const void *>(&beta),
                                                       D, D_type, ldd, strideD,
                                                       batch_count, computeType,
                                                       HIPBLAS_GEMM_DEFAULT);
  if (status != HIPBLAS_STATUS_SUCCESS) {
    NVTE_ERROR("hipblasGemmStridedBatchedEx execution failed");
  }
}

}  // namespace transformer_engine
transformer_engine/common/gemm/hipblas_gemm.h
0 → 100644
View file @ c520cba3
/*************************************************************************
 * Copyright (c) 2022-2024, S3000 qianyj. All rights reserved.
 ************************************************************************/

/*! \file hipblas_gemm.h
 *  \brief Functions that use hipBLAS/rocBLAS instead of hipBLASLt for plain GEMMs.
 */

#ifndef TRANSFORMER_ENGINE_COMMON_HIPBLAS_GEMM_H_
#define TRANSFORMER_ENGINE_COMMON_HIPBLAS_GEMM_H_

#include <hip/hip_runtime.h>
#ifdef USE_HIPBLASLT
#include <hipblas/hipblas.h>
#include <mutex>
#include <unordered_map>
#else
#include <rocblas/rocblas.h>
#include <cassert>
#endif
#include <stdexcept>
#include "../common_hip.h"
#include <iostream>
#ifdef USE_HIPBLASLT
class HipblasHandleManager {
 public:
  HipblasHandleManager() {}

  ~HipblasHandleManager() {
    // Release all handles when the manager is destroyed
    for (auto &device_pair : handles_map_) {
      hipblasDestroy(device_pair.second);  // Only one handle per device
    }
  }

  // Get a handle for the given device (creates if necessary)
  hipblasHandle_t get(int device_id) {
    std::lock_guard<std::mutex> lock(mutex_);
    // Check if the handle for this device exists
    auto device_it = handles_map_.find(device_id);
    if (device_it != handles_map_.end()) {
      return device_it->second;
    }
    // Create a new handle for this device if it doesn't exist
    hipblasHandle_t handle;
    hipblasStatus_t status = hipblasCreate(&handle);
    if (status != HIPBLAS_STATUS_SUCCESS) {
      throw std::runtime_error("Failed to create HIPBLAS handle");
    }
    // Store the handle in the map for this device
    handles_map_[device_id] = handle;
    return handle;
  }

 private:
  std::unordered_map<int, hipblasHandle_t> handles_map_;  // Map from device_id to hipblasHandle
  std::mutex mutex_;
};
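A minimal usage sketch of the manager above, mirroring how hipblas_gemm.cu uses it (the wrapper name acquire_handle_for_stream is invented for illustration):

// Illustrative only: fetch (or lazily create) the handle for the current device and
// bind it to the caller's stream before issuing GEMM calls.
static HipblasHandleManager g_handle_manager;

inline hipblasHandle_t acquire_handle_for_stream(hipStream_t stream) {
  int device_id = 0;
  hipGetDevice(&device_id);                          // per-device cache key
  hipblasHandle_t handle = g_handle_manager.get(device_id);
  hipblasSetStream(handle, stream);                  // subsequent calls run on this stream
  return handle;
}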
namespace transformer_engine {

void hipblas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                  const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
                  int ldb, int ldd, hipblasOperation_t transa, hipblasOperation_t transb,
                  bool grad, void *workspace, size_t workspaceSize, bool accumulate,
                  bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                  bool gemm_producer, const Tensor *inputCounter, hipStream_t stream);

void hipblas_batchgemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                       const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k,
                       int lda, int ldb, int ldd, hipblasOperation_t transa,
                       hipblasOperation_t transb, bool grad, void *workspace,
                       size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                       int math_sm_count, int m_split, int n_split, bool gemm_producer,
                       const Tensor *inputCounter, int batch_count, hipStream_t stream);

}  // namespace transformer_engine
#else
class HipblasHandleManager {
 public:
  HipblasHandleManager() : handle_(nullptr) {}

  ~HipblasHandleManager() {
    // Release the handle in the destructor to ensure cleanup when it's no longer needed
    if (handle_ != nullptr) {
      rocblas_destroy_handle(handle_);
    }
  }

  // Get a handle, making sure it is valid every time
  rocblas_handle get() {
    if (handle_ == nullptr) {
      createHandle();
    }
    // Check whether the handle was created successfully
    assert(handle_ != nullptr && "hipblasHandle should not be null after creation");
    return handle_;
  }

 private:
  rocblas_handle handle_;

  // A private method that creates the handle
  void createHandle() {
    rocblas_status status = rocblas_create_handle(&handle_);
    if (status != rocblas_status_success) {
      // If initialization fails, an exception is thrown
      throw std::runtime_error("Failed to create HIPBLAS handle");
    }
  }

  // Copy construction and assignment are prohibited
  HipblasHandleManager(const HipblasHandleManager &) = delete;
  HipblasHandleManager &operator=(const HipblasHandleManager &) = delete;
};
#endif  // #ifdef USE_HIPBLASLT

#endif  // TRANSFORMER_ENGINE_COMMON_HIPBLAS_GEMM_H_
transformer_engine/common/gemm/rocm_gemm.cu
0 → 100644
View file @ c520cba3
This diff is collapsed.
transformer_engine/common/include/transformer_engine/gemm.h
View file @ c520cba3
...
...
@@ -109,6 +109,21 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVT
                                   NVTETensor *workspace, bool accumulate,
                                   bool use_split_accumulator, int math_sm_count,
                                   cudaStream_t stream);

#ifdef __HIP_PLATFORM_AMD__
void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                        const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                        const int num_gemms, bool transa, bool transb, bool grad,
                                        NVTETensor *workspace, bool accumulate,
                                        bool use_split_accumulator, int math_sm_count,
                                        cudaStream_t stream);

void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                           const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                           bool transb, bool grad, NVTETensor workspace, bool accumulate,
                           bool use_split_accumulator, int math_sm_count, int batch_count,
                           cudaStream_t stream);
#endif

#ifdef __cplusplus
}  // extern "C"
#endif
...
...
@@ -116,8 +131,14 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVT
/*! \namespace transformer_engine
 */
namespace transformer_engine {
#ifdef __HIP_PLATFORM_AMD__
// On DCU, 2 streams work better
constexpr int num_streams = 2;
// Add for the batchgemm stream
constexpr int num_batchgemm_streams = 1;
#else
constexpr int num_streams = 4;
#endif
}  // namespace transformer_engine
...
...
transformer_engine/common/normalization/common.cpp
View file @ c520cba3
...
...
@@ -39,10 +39,12 @@ Compute always in FP32
namespace transformer_engine {
namespace normalization {

#ifndef __HIP_PLATFORM_AMD__
cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) {
  return training ? cudnn_frontend::NormFwdPhase_t::TRAINING
                  : cudnn_frontend::NormFwdPhase_t::INFERENCE;
}
#endif

TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType,
                     NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype,
...
...
@@ -195,6 +197,10 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
      _training(training),
      _norm_stage(NormStage),
      _norm_type(NormType) {
#ifdef USE_ROCM
  static_assert(false, "The cuDNN backend is not supported for normalization in ROCm yet.");
#else
  static_assert(CUDNN_FRONTEND_VERSION >= 10601,
                "CUDNN_FRONTEND_VERSION should be at least 1.6.1!");
...
...
@@ -378,9 +384,14 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
  }

  // Build the graph
  this->_build();
#endif
}

void CudnnNormalizationPlan::_build() {
#ifdef USE_ROCM
  static_assert(false, "The cuDNN backend is not supported for normalization in ROCm yet.");
#else
  NVTE_CHECK(_graph.validate().is_good());
  NVTE_CHECK(_graph.build_operation_graph(_handle).is_good());
  NVTE_CHECK(_graph
...
...
@@ -390,15 +401,25 @@ void CudnnNormalizationPlan::_build() {
  NVTE_CHECK(_graph.check_support(_handle).is_good());
  NVTE_CHECK(
      _graph.build_plans(_handle, cudnn_frontend::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good());
#endif
}

std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const {
#ifdef USE_ROCM
  static_assert(false, "The cuDNN backend is not supported for normalization in ROCm yet.");
#else
  return {static_cast<size_t>(_graph.get_workspace_size())};
#endif
}

void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr,
                                     void* mean_dptr, void* eps_dptr, void* rsigma_dptr,
                                     void* workspace_dptr, cudaStream_t stream) {
#ifdef USE_ROCM
  static_assert(false, "The cuDNN backend is not supported for normalization in ROCm yet.");
#else
  // Binding data pointers to graph tensors
  _variant_pack = {{_x, x_dptr}, {_eps, eps_dptr}};
...
...
@@ -433,12 +454,17 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr,
  // Execute the computation
  NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream));
  NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good());
#endif
}

void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr,
                                     void* rsigma_dptr, void* dx_dptr, void* dz_dptr,
                                     void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
                                     cudaStream_t stream) {
#ifdef USE_ROCM
  static_assert(false, "The cuDNN backend is not supported for normalization in ROCm yet.");
#else
  // Binding data pointers to graph tensors
  _variant_pack = {
      {_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}};
...
...
@@ -455,6 +481,7 @@ void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_
  // Execute the computation
  NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream));
  NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good());
#endif
}

NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
...
...
@@ -491,13 +518,21 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
}

bool &_cudnn_norm_fwd_flag() {
#ifdef USE_ROCM
  static bool flag = false;  // cuDNN normalization is never enabled on ROCm
  return flag;
#else
  static bool flag = transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN");
  return flag;
#endif
}

bool &_cudnn_norm_bwd_flag() {
#ifdef USE_ROCM
  static bool flag = false;  // cuDNN normalization is never enabled on ROCm
  return flag;
#else
  static bool flag = transformer_engine::getenv<bool>("NVTE_NORM_BWD_USE_CUDNN");
  return flag;
#endif
}

bool use_cudnn_norm_fwd() { return _cudnn_norm_fwd_flag(); }
...
...
@@ -508,10 +543,18 @@ bool use_cudnn_norm_bwd() { return _cudnn_norm_bwd_flag(); }
void nvte_enable_cudnn_norm_fwd(bool enable) {
  NVTE_API_CALL(nvte_enable_cudnn_norm_fwd);
#ifdef USE_ROCM
  transformer_engine::normalization::_cudnn_norm_fwd_flag() = false;
#else
  transformer_engine::normalization::_cudnn_norm_fwd_flag() = enable;
#endif
}

void nvte_enable_cudnn_norm_bwd(bool enable) {
  NVTE_API_CALL(nvte_enable_cudnn_norm_bwd);
#ifdef USE_ROCM
  transformer_engine::normalization::_cudnn_norm_bwd_flag() = false;
#else
  transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable;
#endif
}
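On the ROCm build these setters keep both flags false, so the native normalization kernels are always used; on CUDA they simply latch the value that use_cudnn_norm_fwd()/use_cudnn_norm_bwd() read back. A short, illustrative usage sketch (not taken from the test suite):

// Illustrative only: opt in to the cuDNN normalization path where it is supported.
nvte_enable_cudnn_norm_fwd(true);
nvte_enable_cudnn_norm_bwd(true);
// Later, the library queries the flags; on ROCm both stay false and the native path is taken.
bool fwd_uses_cudnn = transformer_engine::normalization::use_cudnn_norm_fwd();
bool bwd_uses_cudnn = transformer_engine::normalization::use_cudnn_norm_bwd();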