Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
c520cba3
Commit
c520cba3
authored
Mar 20, 2025
by
yuguo
Browse files
[DCU] Preliminary adaptation
parent
5b6ef054
Changes
79
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
283 additions
and
3 deletions
+283
-3
transformer_engine/common/normalization/common.h
transformer_engine/common/normalization/common.h
+10
-0
transformer_engine/common/normalization/layernorm/ln_api.cpp
transformer_engine/common/normalization/layernorm/ln_api.cpp
+7
-0
transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
...common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
+3
-0
transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu
...gine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu
+3
-0
transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
...ormer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
+7
-0
transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
...mon/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
+3
-0
transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
...e/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
+3
-0
transformer_engine/common/permutation/permutation.cu
transformer_engine/common/permutation/permutation.cu
+5
-0
transformer_engine/common/util/cast.cu
transformer_engine/common/util/cast.cu
+2
-0
transformer_engine/common/util/cast_gated_kernels.cuh
transformer_engine/common/util/cast_gated_kernels.cuh
+12
-0
transformer_engine/common/util/cast_kernels.cuh
transformer_engine/common/util/cast_kernels.cuh
+12
-0
transformer_engine/common/util/cuda_driver.cpp
transformer_engine/common/util/cuda_driver.cpp
+7
-0
transformer_engine/common/util/cuda_nvml.cpp
transformer_engine/common/util/cuda_nvml.cpp
+2
-1
transformer_engine/common/util/cuda_nvml.h
transformer_engine/common/util/cuda_nvml.h
+6
-0
transformer_engine/common/util/cuda_runtime.cpp
transformer_engine/common/util/cuda_runtime.cpp
+22
-0
transformer_engine/common/util/cuda_runtime.h
transformer_engine/common/util/cuda_runtime.h
+12
-0
transformer_engine/common/util/dequantize_kernels.cuh
transformer_engine/common/util/dequantize_kernels.cuh
+7
-0
transformer_engine/common/util/logging.h
transformer_engine/common/util/logging.h
+34
-1
transformer_engine/common/util/rtc.cpp
transformer_engine/common/util/rtc.cpp
+35
-0
transformer_engine/common/utils.cuh
transformer_engine/common/utils.cuh
+91
-1
No files found.
transformer_engine/common/normalization/common.h
View file @
c520cba3
...
...
@@ -7,9 +7,11 @@
#ifndef TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_
#ifndef __HIP_PLATFORM_AMD__
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#endif
#include <transformer_engine/transformer_engine.h>
#include <functional>
...
...
@@ -282,6 +284,7 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
const
NVTE_Norm_Type
_norm_type
;
std
::
unique_ptr
<
char
[]
>
_scalar_dptr
;
std
::
unique_ptr
<
float
>
_one_dptr
=
std
::
make_unique
<
float
>
(
1.0
f
);
#ifndef __HIP_PLATFORM_AMD__
// FWD
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
_x
,
_gamma_zero
,
_scalar_offset
,
_gamma
,
_beta
,
_eps
,
_mean
,
_rsigma
,
_z
,
_z_scale
,
_one_for_div
,
_z_scale_inv
,
_amax
,
_z_fp8
;
...
...
@@ -294,6 +297,7 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
fe
::
graph
::
Graph
_graph
;
std
::
unordered_map
<
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
void
*>
_variant_pack
;
cudnnHandle_t
_handle
;
#endif
};
class
NormalizationPlanRegistry
{
...
...
@@ -322,9 +326,15 @@ using byte = uint8_t;
using
int32
=
int32_t
;
using
fp32
=
float
;
using
fp16
=
half
;
#ifndef __HIP_PLATFORM_AMD__
using
bf16
=
nv_bfloat16
;
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
#else
using
bf16
=
__hip_bfloat16
;
using
fp8e4m3
=
hip_f8
<
hip_f8_type
::
fp8
>
;
using
fp8e5m2
=
hip_f8
<
hip_f8_type
::
bf8
>
;
#endif
template
<
typename
T
>
struct
TypeToDType
;
...
...
transformer_engine/common/normalization/layernorm/ln_api.cpp
View file @
c520cba3
...
...
@@ -57,7 +57,14 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
NVTE_Norm_Backend
norm_backend
;
bool
is_aligned
=
true
;
#ifdef USE_ROCM
NVTE_CHECK
(
!
is_block_scaling
(
z
->
scaling_mode
),
"Cudnn backend is need by block scaling mode for normalization! Not surpported in rocm yet."
);
bool
cudnn_backend
=
use_cudnn_norm_fwd
()
||
is_block_scaling
(
z
->
scaling_mode
);
#else
bool
cudnn_backend
=
use_cudnn_norm_fwd
()
||
is_block_scaling
(
z
->
scaling_mode
);
#endif
if
(
cudnn_backend
)
{
// TODO: add check for GPU ARCH
...
...
transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
View file @
c520cba3
...
...
@@ -38,10 +38,13 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
return
;
}
#ifndef __HIP_PLATFORM_AMD__
if
(
Kernel_traits
::
SMEM_BYTES
>=
48
*
1024
)
{
NVTE_CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
Kernel_traits
::
SMEM_BYTES
));
}
#endif
auto
stream
=
launch_params
.
stream
;
auto
ctas_per_col
=
launch_params
.
params
.
ctas_per_col
;
auto
ctas_per_row
=
launch_params
.
params
.
ctas_per_row
;
...
...
transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu
View file @
c520cba3
...
...
@@ -34,10 +34,13 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
return
;
}
#ifndef __HIP_PLATFORM_AMD__
if
(
Kernel_traits
::
SMEM_BYTES_FWD
>=
48
*
1024
)
{
NVTE_CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
Kernel_traits
::
SMEM_BYTES_FWD
));
}
#endif
auto
stream
=
launch_params
.
stream
;
auto
ctas_per_col
=
launch_params
.
params
.
ctas_per_col
;
auto
ctas_per_row
=
launch_params
.
params
.
ctas_per_row
;
...
...
transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
View file @
c520cba3
...
...
@@ -47,7 +47,14 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
NVTE_Norm_Backend
norm_backend
;
bool
is_aligned
=
true
;
#ifdef USE_ROCM
NVTE_CHECK
(
!
is_block_scaling
(
z
->
scaling_mode
),
"Cudnn backend is need by block scaling mode for normalization! Not surpported in rocm yet."
);
bool
cudnn_backend
=
use_cudnn_norm_fwd
()
||
is_block_scaling
(
z
->
scaling_mode
);
#else
bool
cudnn_backend
=
use_cudnn_norm_fwd
()
||
is_block_scaling
(
z
->
scaling_mode
);
#endif
bool
training
=
is_delayed_tensor_scaling
(
z
->
scaling_mode
)
||
(
z
->
columnwise_data
).
dptr
!=
nullptr
;
...
...
transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
View file @
c520cba3
...
...
@@ -37,10 +37,13 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
return
;
}
#ifndef __HIP_PLATFORM_AMD__
if
(
Kernel_traits
::
SMEM_BYTES
>=
48
*
1024
)
{
NVTE_CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
Kernel_traits
::
SMEM_BYTES
));
}
#endif
auto
stream
=
launch_params
.
stream
;
auto
ctas_per_col
=
launch_params
.
params
.
ctas_per_col
;
auto
ctas_per_row
=
launch_params
.
params
.
ctas_per_row
;
...
...
transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
View file @
c520cba3
...
...
@@ -35,10 +35,13 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
return
;
}
#ifndef __HIP_PLATFORM_AMD__
if
(
Kernel_traits
::
SMEM_BYTES_FWD
>=
48
*
1024
)
{
NVTE_CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
Kernel_traits
::
SMEM_BYTES_FWD
));
}
#endif
auto
stream
=
launch_params
.
stream
;
auto
ctas_per_col
=
launch_params
.
params
.
ctas_per_col
;
auto
ctas_per_row
=
launch_params
.
params
.
ctas_per_row
;
...
...
transformer_engine/common/permutation/permutation.cu
View file @
c520cba3
...
...
@@ -8,6 +8,11 @@
#include "../common.h"
#ifdef __HIP_PLATFORM_AMD__
using
__nv_fp8_e4m3
=
hip_f8
<
hip_f8_type
::
fp8
>
;
using
__nv_fp8_e5m2
=
hip_f8
<
hip_f8_type
::
bf8
>
;
#endif
static
__global__
void
moe_permute_row_map
(
const
int
*
sorted_row_id
,
int
*
row_id_map
,
const
int
num_rows
,
const
int
topK
,
const
int
num_out_tokens
)
{
...
...
transformer_engine/common/util/cast.cu
View file @
c520cba3
...
...
@@ -5,7 +5,9 @@
************************************************************************/
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>
...
...
transformer_engine/common/util/cast_gated_kernels.cuh
View file @
c520cba3
...
...
@@ -12,7 +12,9 @@
#define TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
...
...
@@ -723,6 +725,10 @@ template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP
float
(
*
DActOP
)(
float
,
const
ParamOP
&
)>
void
cast_fp8_gated
(
const
Tensor
&
grad
,
const
Tensor
&
gated_input
,
Tensor
*
output
,
cudaStream_t
stream
)
{
#ifdef __HIP_PLATFORM_AMD__
static_assert
(
false
,
"Cast_fp8_gated is not surpported in rocm yet."
);
#else
if
(
output
->
has_data
())
{
NVTE_CHECK
(
output
->
scale_inv
.
dptr
!=
nullptr
,
"Scaling tensor must be allocated."
);
}
...
...
@@ -796,12 +802,17 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
tensor_map_output_gate
,
amax_ptr
,
scale_inv_ptr
,
scale_ptr
,
rows
,
cols
););
// NOLINT(*)
);
// NOLINT(*)
#endif
}
template
<
bool
IS_DGATED
,
typename
ParamOP
,
float
(
*
ActOP
)(
float
,
const
ParamOP
&
),
float
(
*
DActOP
)(
float
,
const
ParamOP
&
)>
void
cast_mxfp8_gated
(
const
Tensor
&
grad
,
const
Tensor
&
gated_input
,
Tensor
*
output
,
cudaStream_t
stream
)
{
#ifdef __HIP_PLATFORM_AMD__
static_assert
(
false
,
"Cast_mxfp8_gated is not surpported in rocm yet."
);
#else
const
bool
USE_ROWWISE_SCALING
=
output
->
has_data
();
const
bool
USE_COLWISE_SCALING
=
output
->
has_columnwise_data
();
...
...
@@ -919,6 +930,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
);
// NOLINT(*)
);
// NOLINT(*)
);
// NOLINT(*)
#endif
}
template
<
typename
ParamOP
,
float
(
*
ActOP
)(
float
,
const
ParamOP
&
)>
...
...
transformer_engine/common/util/cast_kernels.cuh
View file @
c520cba3
...
...
@@ -12,7 +12,9 @@
#define TRANSFORMER_ENGINE_CAST_KERNELS_CUH_
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>
...
...
@@ -853,6 +855,10 @@ static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream
template
<
bool
IS_DBIAS
,
bool
IS_DACT
,
typename
ParamOP
,
float
(
*
OP
)(
float
,
const
ParamOP
&
)>
void
cast_fp8_2D
(
const
Tensor
&
input
,
const
Tensor
*
act_input
,
Tensor
*
output
,
Tensor
*
dbias
,
Tensor
*
workspace
,
cudaStream_t
stream
)
{
#ifdef __HIP_PLATFORM_AMD__
static_assert
(
false
,
"Cast_fp8_2D is not surpported in rocm yet."
);
#else
checkCuDriverContext
(
stream
);
const
size_t
rows
=
input
.
flat_first_dim
();
...
...
@@ -916,6 +922,7 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T
reduce_dbias
<
IType
>
(
workspace_ptr
,
dbias
,
dbias_rows
,
dbias_cols
,
stream
);
});
// NOLINT(*)
);
// NOLINT(*)
#endif
}
template
<
bool
IS_DBIAS
,
bool
IS_DACT
,
bool
IS_ACT
,
typename
ParamOP
,
...
...
@@ -923,6 +930,10 @@ template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
void
mxfp8_quantize
(
const
Tensor
&
input
,
const
Tensor
*
act_input
,
const
Tensor
*
noop
,
// TODO (ksivamani)
Tensor
*
output
,
Tensor
*
dbias
,
Tensor
*
workspace
,
cudaStream_t
stream
)
{
#ifdef __HIP_PLATFORM_AMD__
static_assert
(
false
,
"Mxfp8_quantize is not surpported in rocm yet."
);
#else
bool
use_rowwise_scaling
=
output
->
has_data
();
bool
use_colwise_scaling
=
output
->
has_columnwise_data
();
checkCuDriverContext
(
stream
);
...
...
@@ -1027,6 +1038,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
);
// NOLINT(*)
);
// NOLINT(*)
);
// NOLINT(*)
#endif
}
namespace
detail
{
...
...
transformer_engine/common/util/cuda_driver.cpp
View file @
c520cba3
...
...
@@ -15,10 +15,17 @@ namespace cuda_driver {
void
*
get_symbol
(
const
char
*
symbol
)
{
void
*
entry_point
;
#ifdef USE_ROCM
hipDriverProcAddressQueryResult
driver_result
;
NVTE_CHECK_CUDA
(
hipGetProcAddress
(
symbol
,
&
entry_point
,
HIP_VERSION_MAJOR
*
100
+
HIP_VERSION_MINOR
,
0
,
&
driver_result
));
NVTE_CHECK
(
driver_result
==
HIP_GET_PROC_ADDRESS_SUCCESS
,
"Could not find CUDA driver entry point for "
,
symbol
);
#else
cudaDriverEntryPointQueryResult
driver_result
;
NVTE_CHECK_CUDA
(
cudaGetDriverEntryPoint
(
symbol
,
&
entry_point
,
cudaEnableDefault
,
&
driver_result
));
NVTE_CHECK
(
driver_result
==
cudaDriverEntryPointSuccess
,
"Could not find CUDA driver entry point for "
,
symbol
);
#endif
return
entry_point
;
}
...
...
transformer_engine/common/util/cuda_nvml.cpp
View file @
c520cba3
...
...
@@ -11,6 +11,7 @@
namespace
transformer_engine
{
namespace
cuda_nvml
{
#ifndef __HIP_PLATFORM_AMD__
/*! \brief Lazily-initialized shared library for CUDA NVML */
Library
&
cuda_nvml_lib
()
{
...
...
@@ -20,7 +21,7 @@ Library &cuda_nvml_lib() {
}
void
*
get_symbol
(
const
char
*
symbol
)
{
return
cuda_nvml_lib
().
get_symbol
(
symbol
);
}
#endif
}
// namespace cuda_nvml
}
// namespace transformer_engine
transformer_engine/common/util/cuda_nvml.h
View file @
c520cba3
...
...
@@ -7,7 +7,9 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_
#ifndef __HIP_PLATFORM_AMD__
#include <nvml.h>
#endif
#include <string>
...
...
@@ -17,6 +19,7 @@
namespace
transformer_engine
{
namespace
cuda_nvml
{
#ifndef __HIP_PLATFORM_AMD__
/*! \brief Get pointer corresponding to symbol in CUDA NVML library */
void
*
get_symbol
(
const
char
*
symbol
);
...
...
@@ -45,11 +48,13 @@ inline const char *get_nvml_error_string(nvmlReturn_t rc) {
FuncT
*
func
=
reinterpret_cast
<
FuncT
*>
(
get_symbol
(
"nvmlErrorString"
));
return
(
*
func
)(
rc
);
}
#endif
}
// namespace cuda_nvml
}
// namespace transformer_engine
#ifndef __HIP_PLATFORM_AMD__
#define NVTE_CHECK_CUDA_NVML(expr) \
do { \
const nvmlReturn_t status_NVTE_CHECK_CUDA_NVML = (expr); \
...
...
@@ -65,5 +70,6 @@ inline const char *get_nvml_error_string(nvmlReturn_t rc) {
do { \
NVTE_CHECK_CUDA_NVML(::transformer_engine::cuda_nvml::call(#symbol VA_ARGS(__VA_ARGS__))); \
} while (false)
#endif
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_
transformer_engine/common/util/cuda_runtime.cpp
View file @
c520cba3
...
...
@@ -18,12 +18,14 @@ namespace transformer_engine {
namespace
cuda
{
#ifndef __HIP_PLATFORM_AMD__
namespace
{
// String with build-time CUDA include path
#include "string_path_cuda_include.h"
}
// namespace
#endif // __HIP_PLATFORM_AMD__
int
num_devices
()
{
auto
query_num_devices
=
[]()
->
int
{
...
...
@@ -81,6 +83,24 @@ int sm_count(int device_id) {
return
cache
[
device_id
];
}
#ifdef __HIP_PLATFORM_AMD__
const
std
::
string
&
sm_arch_name
(
int
device_id
)
{
static
std
::
vector
<
std
::
string
>
cache
(
num_devices
(),
""
);
static
std
::
vector
<
std
::
once_flag
>
flags
(
num_devices
());
if
(
device_id
<
0
)
{
device_id
=
current_device
();
}
NVTE_CHECK
(
0
<=
device_id
&&
device_id
<
num_devices
(),
"invalid HIP device ID"
);
auto
init
=
[
&
]
()
{
cudaDeviceProp
prop
;
NVTE_CHECK_CUDA
(
cudaGetDeviceProperties
(
&
prop
,
device_id
));
cache
[
device_id
]
=
prop
.
gcnArchName
;
};
std
::
call_once
(
flags
[
device_id
],
init
);
return
cache
[
device_id
];
}
#endif // __HIP_PLATFORM_AMD__
void
stream_priority_range
(
int
*
low_priority
,
int
*
high_priority
,
int
device_id
)
{
static
std
::
vector
<
std
::
pair
<
int
,
int
>>
cache
(
num_devices
());
static
std
::
vector
<
std
::
once_flag
>
flags
(
num_devices
());
...
...
@@ -126,6 +146,7 @@ bool supports_multicast(int device_id) {
#endif
}
#ifndef __HIP_PLATFORM_AMD__
const
std
::
string
&
include_directory
(
bool
required
)
{
static
std
::
string
path
;
...
...
@@ -190,6 +211,7 @@ const std::string &include_directory(bool required) {
// Return cached path
return
path
;
}
#endif // __HIP_PLATFORM_AMD__
}
// namespace cuda
...
...
transformer_engine/common/util/cuda_runtime.h
View file @
c520cba3
...
...
@@ -30,6 +30,16 @@ int current_device();
*/
int
sm_arch
(
int
device_id
=
-
1
);
#ifdef __HIP_PLATFORM_AMD__
/* \brief Compute capability of device
*
* \param[in] device_id HIP device (default is current device)
*
* \return GPU arch name and compute capabilities string.
*/
const
std
::
string
&
sm_arch_name
(
int
device_id
=
-
1
);
#endif
/* \brief Number of multiprocessors on a device
*
* \param[in] device_id CUDA device (default is current device)
...
...
@@ -56,6 +66,7 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id
*/
bool
supports_multicast
(
int
device_id
=
-
1
);
#ifndef __HIP_PLATFORM_AMD__
/* \brief Path to CUDA Toolkit headers
*
* The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the
...
...
@@ -66,6 +77,7 @@ bool supports_multicast(int device_id = -1);
* \return Path to include directory, or an empty string if not found
*/
const
std
::
string
&
include_directory
(
bool
required
=
false
);
#endif
}
// namespace cuda
...
...
transformer_engine/common/util/dequantize_kernels.cuh
View file @
c520cba3
...
...
@@ -12,7 +12,9 @@
#define TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>
...
...
@@ -250,6 +252,10 @@ static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t str
}
static
void
mxfp8_dequantize
(
const
Tensor
&
input
,
Tensor
*
output
,
cudaStream_t
stream
)
{
#ifdef __HIP_PLATFORM_AMD__
static_assert
(
false
,
"Mxfp8_dequantize is not surpported in rocm yet."
);
#else
bool
use_rowwise_scaling
=
input
.
has_data
();
bool
use_colwise_scaling
=
input
.
has_columnwise_data
();
checkCuDriverContext
(
stream
);
...
...
@@ -332,6 +338,7 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
);
// NOLINT(*)
);
// NOLINT(*)
}
#endif
}
// namespace dequantization
namespace
detail
{
...
...
transformer_engine/common/util/logging.h
View file @
c520cba3
...
...
@@ -7,9 +7,19 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
#include <hipblaslt/hipblaslt.h>
#endif
#ifdef USE_ROCBLAS
#define ROCBLAS_BETA_FEATURES_API
#include <rocblas/rocblas.h>
#endif
#else
#include <cublas_v2.h>
#include <cudnn.h>
#endif // __HIP_PLATFORM_AMD__
#include <nvrtc.h>
#include <stdexcept>
...
...
@@ -39,6 +49,28 @@
} \
} while (false)
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT //hipblaslt
#define NVTE_CHECK_HIPBLASLT(expr) \
do { \
const hipblasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \
if (status_NVTE_CHECK_CUBLAS != CUBLAS_STATUS_SUCCESS) { \
NVTE_ERROR("HIPBLASLT Error: ", \
std::to_string((int)status_NVTE_CHECK_CUBLAS)); \
} \
} while (false)
#endif
#ifdef USE_ROCBLAS //rocblas
#define NVTE_CHECK_ROCBLAS(expr) \
do { \
const rocblas_status status_NVTE_CHECK_CUBLAS = (expr); \
if (status_NVTE_CHECK_CUBLAS != rocblas_status_success) { \
NVTE_ERROR("ROCBLAS Error: " + \
std::string(rocblas_status_to_string(status_NVTE_CHECK_CUBLAS))); \
} \
} while (false)
#endif
#else //cublas
#define NVTE_CHECK_CUBLAS(expr) \
do { \
const cublasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \
...
...
@@ -46,6 +78,7 @@
NVTE_ERROR("cuBLAS Error: ", cublasGetStatusString(status_NVTE_CHECK_CUBLAS)); \
} \
} while (false)
#endif
#define NVTE_CHECK_CUDNN(expr) \
do { \
...
...
transformer_engine/common/util/rtc.cpp
View file @
c520cba3
...
...
@@ -25,6 +25,12 @@ namespace {
#include "string_code_util_math_h.h"
#include "string_code_utils_cuh.h"
#ifdef USE_ROCM
#include "string_code_amd_detail_hip_float8_h.h"
#include "string_code_amd_detail_hip_f8_impl_h.h"
#endif // USE_ROCM
#ifndef USE_ROCM
/*! \brief Latest compute capability that NVRTC supports
*
* \return Compute capability as int. Last digit is minor revision,
...
...
@@ -42,6 +48,7 @@ inline int max_supported_sm_arch() {
}
return
arch_
;
}
#endif // USE_ROCM
}
// namespace
...
...
@@ -66,6 +73,9 @@ Kernel::~Kernel() {
for
(
int
device_id
=
0
;
device_id
<
static_cast
<
int
>
(
modules_
.
size
());
++
device_id
)
{
// Unload CUDA modules if needed
if
(
modules_
[
device_id
]
!=
null_module
)
{
#ifdef USE_ROCM
(
void
)
cuda_driver
::
call
(
"hipModuleUnload"
,
modules_
[
device_id
]);
#else
CUdevice
device
;
CUcontext
context
;
if
(
cuda_driver
::
call
(
"cuDeviceGet"
,
&
device
,
device_id
)
!=
CUDA_SUCCESS
)
{
...
...
@@ -79,6 +89,7 @@ Kernel::~Kernel() {
}
cuda_driver
::
call
(
"cuModuleUnload"
,
modules_
[
device_id
]);
cuda_driver
::
call
(
"cuDevicePrimaryCtxRelease"
,
device
);
#endif // USE_ROCM
}
}
}
...
...
@@ -143,9 +154,11 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
// Choose whether to compile to PTX or cubin
const
int
device_id
=
cuda
::
current_device
();
#ifndef USE_ROCM
const
int
sm_arch_
=
cuda
::
sm_arch
(
device_id
);
const
int
compile_sm_arch
=
std
::
min
(
sm_arch_
,
max_supported_sm_arch
());
const
bool
compile_ptx
=
(
CUDA_VERSION
<=
11000
)
||
(
sm_arch_
!=
compile_sm_arch
);
#endif // USE_ROCM
// Compilation flags
std
::
vector
<
std
::
string
>
opts
=
{
...
...
@@ -153,12 +166,15 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
"-G"
,
#endif
"--std=c++17"
};
#ifndef USE_ROCM
if
(
compile_ptx
)
{
opts
.
push_back
(
concat_strings
(
"--gpu-architecture=compute_"
,
compile_sm_arch
));
}
else
{
opts
.
push_back
(
concat_strings
(
"--gpu-architecture=sm_"
,
compile_sm_arch
));
}
opts
.
push_back
(
concat_strings
(
"-I"
,
cuda
::
include_directory
(
true
)));
#endif //USE_ROCM
std
::
vector
<
const
char
*>
opts_ptrs
;
for
(
const
auto
&
opt
:
opts
)
{
opts_ptrs
.
push_back
(
opt
.
c_str
());
...
...
@@ -166,9 +182,15 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
// Compile source
nvrtcProgram
program
;
#ifdef USE_ROCM
constexpr
int
num_headers
=
4
;
const
char
*
headers
[
num_headers
]
=
{
string_code_utils_cuh
,
string_code_util_math_h
,
string_code_amd_detail_hip_float8_h
,
string_code_amd_detail_hip_f8_impl_h
};
const
char
*
include_names
[
num_headers
]
=
{
"utils_hip.cuh"
,
"util/math.h"
,
"amd_detail/hip_float8.h"
,
"amd_detail/hip_f8_impl.h"
};
#else
constexpr
int
num_headers
=
2
;
constexpr
const
char
*
headers
[
num_headers
]
=
{
string_code_utils_cuh
,
string_code_util_math_h
};
constexpr
const
char
*
include_names
[
num_headers
]
=
{
"utils.cuh"
,
"util/math.h"
};
#endif // USE_ROCM
NVTE_CHECK_NVRTC
(
nvrtcCreateProgram
(
&
program
,
code
.
c_str
(),
filename
.
c_str
(),
num_headers
,
headers
,
include_names
));
NVTE_CHECK_NVRTC
(
nvrtcAddNameExpression
(
program
,
kernel_name
.
c_str
()));
...
...
@@ -193,6 +215,14 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
// Get compiled code
std
::
string
compiled_code
;
#ifdef USE_ROCM
{
size_t
compiled_size
;
NVTE_CHECK_NVRTC
(
hiprtcGetCodeSize
(
program
,
&
compiled_size
));
compiled_code
.
resize
(
compiled_size
);
NVTE_CHECK_NVRTC
(
hiprtcGetCode
(
program
,
compiled_code
.
data
()));
}
#else
if
(
compile_ptx
)
{
size_t
compiled_size
;
NVTE_CHECK_NVRTC
(
nvrtcGetPTXSize
(
program
,
&
compiled_size
));
...
...
@@ -204,6 +234,7 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
compiled_code
.
resize
(
compiled_size
);
NVTE_CHECK_NVRTC
(
nvrtcGetCUBIN
(
program
,
compiled_code
.
data
()));
}
#endif //USE_ROCM
// Cache compiled code
const
auto
key
=
get_kernel_cache_key
(
kernel_label
,
device_id
);
...
...
@@ -228,7 +259,11 @@ bool KernelManager::is_compiled(const std::string& kernel_label, int device_id)
std
::
string
KernelManager
::
get_kernel_cache_key
(
const
std
::
string
&
kernel_label
,
int
device_id
)
const
{
#ifdef USE_ROCM
return
concat_strings
(
cuda
::
sm_arch_name
(
device_id
),
","
,
kernel_label
);
#else
return
concat_strings
(
"sm="
,
cuda
::
sm_arch
(
device_id
),
","
,
kernel_label
);
#endif
}
}
// namespace rtc
...
...
transformer_engine/common/utils.cuh
View file @
c520cba3
...
...
@@ -7,10 +7,22 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
#define TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifdef __HIP_PLATFORM_AMD__
#ifndef __HIPCC_RTC__
#include <cstdint>
#else
using
namespace
__hip_internal
;
#endif
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#ifdef __HIP_PLATFORM_AMD__
typedef
uint16_t
hip_bfloat16x2
__attribute__
((
ext_vector_type
(
2
)));
#else
#if !defined(__CUDACC_RTC__)
#include <cstdint>
#else
...
...
@@ -25,12 +37,14 @@ static_assert(sizeof(uint32_t) == 4);
static_assert
(
sizeof
(
uint64_t
)
==
8
);
#endif
#endif // __HIP_PLATFORM_AMD__
////////////////////////////////////////////////////////////////////////////////////////////////////
constexpr
uint32_t
THREADS_PER_WARP
=
32
;
////////////////////////////////////////////////////////////////////////////////////////////////////
#if !defined(USE_HIPBLASLT) && !defined(__HIPCC_RTC__)
inline
__device__
float2
operator
+
(
const
float2
&
a
,
const
float2
&
b
)
{
// NOLINT(*)
return
{
a
.
x
+
b
.
x
,
a
.
y
+
b
.
y
};
}
...
...
@@ -41,6 +55,7 @@ inline __device__ void operator+=(float2 &a, const float2 &b) { // NOLINT(*)
a
.
x
+=
b
.
x
;
a
.
y
+=
b
.
y
;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
...
...
@@ -54,7 +69,11 @@ struct Sum {
template
<
typename
T
>
inline
__device__
T
warp_shuffle_xor
(
const
T
&
x
,
uint32_t
idx
)
{
#ifdef __HIP_PLATFORM_AMD__
return
__shfl_xor
(
x
,
idx
,
THREADS_PER_WARP
);
#else
return
__shfl_xor_sync
(
static_cast
<
uint32_t
>
(
-
1
),
x
,
idx
);
#endif
}
template
<
>
...
...
@@ -64,7 +83,11 @@ inline __device__ float2 warp_shuffle_xor<float2>(const float2 &x, uint32_t idx)
template
<
typename
T
>
inline
__device__
T
warp_shuffle_down
(
const
T
&
x
,
uint32_t
idx
)
{
#ifdef __HIP_PLATFORM_AMD__
return
__shfl_down
(
x
,
idx
,
THREADS_PER_WARP
);
#else
return
__shfl_down_sync
(
static_cast
<
uint32_t
>
(
-
1
),
x
,
idx
);
#endif
}
template
<
>
...
...
@@ -154,10 +177,17 @@ struct TypeToVec2<half> {
using
Type
=
half2
;
};
#ifdef __HIP_PLATFORM_AMD__
template
<
>
struct
TypeToVec2
<
__hip_bfloat16
>
{
using
Type
=
hip_bfloat16x2
;
};
#else
template
<
>
struct
TypeToVec2
<
nv_bfloat16
>
{
using
Type
=
nv_bfloat162
;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
...
...
@@ -222,6 +252,20 @@ struct Converter<float2, half2> {
static
inline
__device__
half2
convert
(
const
float2
&
x
)
{
return
__float22half2_rn
(
x
);
}
};
#ifdef __HIP_PLATFORM_AMD__
template
<
>
struct
Converter
<
float2
,
hip_bfloat16x2
>
{
static
inline
__device__
hip_bfloat16x2
convert
(
const
float2
&
x
)
{
union
{
hip_bfloat16x2
raw
;
hip_bfloat16
elt
[
2
];
}
tmp
;
tmp
.
elt
[
0
]
=
__hip_bfloat16
(
x
.
x
);
tmp
.
elt
[
1
]
=
__hip_bfloat16
(
x
.
y
);
return
tmp
.
raw
;
}
};
#else
template
<
>
struct
Converter
<
float2
,
nv_bfloat162
>
{
static
inline
__device__
nv_bfloat162
convert
(
const
float2
&
x
)
{
...
...
@@ -238,6 +282,7 @@ struct Converter<float2, nv_bfloat162> {
#endif
}
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
...
...
@@ -266,6 +311,12 @@ struct Vec {
};
Alias_type
data
;
#ifdef __HIP_PLATFORM_AMD__
__HOST_DEVICE__
Vec
&
operator
=
(
const
Vec
&
rhs
)
{
data
.
vec
=
rhs
.
data
.
vec
;
return
*
this
;
}
#endif
template
<
typename
S
>
inline
__device__
void
to
(
Vec
<
S
,
NUM_ELT
>
&
other
)
{
// NOLINT(*)
...
...
@@ -346,12 +397,21 @@ struct InterCTASync {
// BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
}
#ifdef __HIP_PLATFORM_AMD__
inline
__device__
void
spin_wait_
(
int
*
barrier
,
int
step
,
int
expected
)
{
__hip_atomic_fetch_add
(
barrier
,
step
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_AGENT
);
for
(
int
found
=
-
1
;
found
!=
expected
;
)
{
found
=
__hip_atomic_load
(
barrier
,
__ATOMIC_ACQUIRE
,
__HIP_MEMORY_SCOPE_AGENT
);
}
}
#else
inline
__device__
void
spin_wait_
(
int
*
barrier
,
int
step
,
int
expected
)
{
asm
volatile
(
"red.release.gpu.global.add.s32 [%0], %1;"
::
"l"
(
barrier
),
"r"
(
step
));
for
(
int
found
=
-
1
;
found
!=
expected
;)
{
asm
volatile
(
"ld.global.acquire.gpu.b32 %0, [%1];"
:
"=r"
(
found
)
:
"l"
(
barrier
));
}
}
#endif
inline
__device__
void
sync
()
{
// ALL THREADS MUST ENTER!
...
...
@@ -634,8 +694,13 @@ inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a,
m2_a
=
m2_ab
;
}
// Intra-warp broadcast (only lane 0 has valid stats).
#ifdef __HIP_PLATFORM_AMD__
m_a
=
__shfl
(
m_a
,
0
,
THREADS_PER_WARP
);
m2_a
=
__shfl
(
m2_a
,
0
,
THREADS_PER_WARP
);
#else
m_a
=
__shfl_sync
(
static_cast
<
uint32_t
>
(
-
1
),
m_a
,
0
);
m2_a
=
__shfl_sync
(
static_cast
<
uint32_t
>
(
-
1
),
m2_a
,
0
);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
...
...
@@ -811,7 +876,11 @@ __device__ __forceinline__ float warp_reduce_max(const float m) {
float
tmp
=
m
;
#pragma unroll
for
(
int
delta
=
num_elems
/
2
;
delta
>
0
;
delta
/=
2
)
{
#ifdef __HIP_PLATFORM_AMD__
const
float
other_m
=
__shfl_down
(
tmp
,
delta
,
THREADS_PER_WARP
);
#else
const
float
other_m
=
__shfl_down_sync
(
0xFFFFFFFF
,
tmp
,
delta
);
#endif
__builtin_assume
(
tmp
>=
0
);
__builtin_assume
(
other_m
>=
0
);
tmp
=
fmaxf
(
tmp
,
other_m
);
...
...
@@ -823,14 +892,22 @@ __forceinline__ __device__ float warp_reduce_max_broadcast(const float val) {
float
val_tmp
=
val
;
#pragma unroll
for
(
int
offset
=
THREADS_PER_WARP
/
2
;
offset
>
0
;
offset
/=
2
)
{
#ifdef __HIP_PLATFORM_AMD__
const
float
val_other
=
__shfl_down
(
val_tmp
,
offset
,
THREADS_PER_WARP
);
#else
const
float
val_other
=
__shfl_down_sync
(
0xFFFFFFFF
,
val_tmp
,
offset
);
#endif
__builtin_assume
(
val_tmp
>=
0
);
__builtin_assume
(
val_other
>=
0
);
val_tmp
=
fmaxf
(
val_tmp
,
val_other
);
}
// Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id
constexpr
int
subwarp_lane_zero
=
0
;
#ifdef __HIP_PLATFORM_AMD__
val_tmp
=
__shfl
(
val_tmp
,
subwarp_lane_zero
,
THREADS_PER_WARP
);
#else
val_tmp
=
__shfl_sync
(
0xFFFFFFFF
,
val_tmp
,
subwarp_lane_zero
);
#endif
return
val_tmp
;
}
...
...
@@ -864,14 +941,22 @@ __forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) {
float
val_tmp
=
val
;
#pragma unroll
for
(
int
offset
=
subwarp_width
/
2
;
offset
>
0
;
offset
/=
2
)
{
#ifdef __HIP_PLATFORM_AMD__
const
float
val_other
=
__shfl_down
(
val_tmp
,
offset
,
subwarp_width
);
#else
const
float
val_other
=
__shfl_down_sync
(
0xFFFFFFFF
,
val_tmp
,
offset
,
subwarp_width
);
#endif
__builtin_assume
(
val_tmp
>=
0
);
__builtin_assume
(
val_other
>=
0
);
val_tmp
=
fmaxf
(
val_tmp
,
val_other
);
}
// Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id
constexpr
int
subwarp_lane_zero
=
0
;
#ifdef __HIP_PLATFORM_AMD__
val_tmp
=
__shfl
(
val_tmp
,
subwarp_lane_zero
,
subwarp_width
);
#else
val_tmp
=
__shfl_sync
(
0xFFFFFFFF
,
val_tmp
,
subwarp_lane_zero
,
subwarp_width
);
#endif
return
val_tmp
;
}
...
...
@@ -897,8 +982,13 @@ __device__ __forceinline__ void reciprocal<float>(float *value_inv, const float
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifndef __HIP_PLATFORM_AMD__
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
#else
using
fp8e4m3
=
hip_f8
<
hip_f8_type
::
fp8
>
;
using
fp8e5m2
=
hip_f8
<
hip_f8_type
::
bf8
>
;
#endif
using
e8m0_t
=
uint8_t
;
constexpr
uint32_t
FP32_MANTISSA_BITS
=
23
;
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment