Commit c520cba3 authored by yuguo
Browse files

[DCU] Preliminary adaptation

parent 5b6ef054
...@@ -7,9 +7,11 @@ ...@@ -7,9 +7,11 @@
#ifndef TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_ #ifndef TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_ #define TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_
#ifndef __HIP_PLATFORM_AMD__
#include <cudnn.h> #include <cudnn.h>
#include <cudnn_frontend.h> #include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h> #include <cudnn_frontend_utils.h>
#endif
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#include <functional> #include <functional>
...@@ -282,6 +284,7 @@ class CudnnNormalizationPlan : public NormalizationPlanBase { ...@@ -282,6 +284,7 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
const NVTE_Norm_Type _norm_type; const NVTE_Norm_Type _norm_type;
std::unique_ptr<char[]> _scalar_dptr; std::unique_ptr<char[]> _scalar_dptr;
std::unique_ptr<float> _one_dptr = std::make_unique<float>(1.0f); std::unique_ptr<float> _one_dptr = std::make_unique<float>(1.0f);
#ifndef __HIP_PLATFORM_AMD__
// FWD // FWD
std::shared_ptr<fe::graph::Tensor_attributes> _x, _gamma_zero, _scalar_offset, _gamma, _beta, std::shared_ptr<fe::graph::Tensor_attributes> _x, _gamma_zero, _scalar_offset, _gamma, _beta,
_eps, _mean, _rsigma, _z, _z_scale, _one_for_div, _z_scale_inv, _amax, _z_fp8; _eps, _mean, _rsigma, _z, _z_scale, _one_for_div, _z_scale_inv, _amax, _z_fp8;
...@@ -294,6 +297,7 @@ class CudnnNormalizationPlan : public NormalizationPlanBase { ...@@ -294,6 +297,7 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
fe::graph::Graph _graph; fe::graph::Graph _graph;
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> _variant_pack; std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> _variant_pack;
cudnnHandle_t _handle; cudnnHandle_t _handle;
#endif
}; };
class NormalizationPlanRegistry { class NormalizationPlanRegistry {
...@@ -322,9 +326,15 @@ using byte = uint8_t; ...@@ -322,9 +326,15 @@ using byte = uint8_t;
using int32 = int32_t; using int32 = int32_t;
using fp32 = float; using fp32 = float;
using fp16 = half; using fp16 = half;
#ifndef __HIP_PLATFORM_AMD__
using bf16 = nv_bfloat16; using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3; using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2; using fp8e5m2 = __nv_fp8_e5m2;
#else
using bf16 = __hip_bfloat16;
using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
#endif
template <typename T> template <typename T>
struct TypeToDType; struct TypeToDType;
......
...@@ -57,7 +57,14 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -57,7 +57,14 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
NVTE_Norm_Backend norm_backend; NVTE_Norm_Backend norm_backend;
bool is_aligned = true; bool is_aligned = true;
#ifdef USE_ROCM
NVTE_CHECK(
!is_block_scaling(z->scaling_mode),
"Cudnn backend is need by block scaling mode for normalization! Not surpported in rocm yet.");
bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode);
#else
bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode);
#endif
if (cudnn_backend) { if (cudnn_backend) {
// TODO: add check for GPU ARCH // TODO: add check for GPU ARCH
......
...@@ -38,10 +38,13 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params, ...@@ -38,10 +38,13 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
return; return;
} }
#ifndef __HIP_PLATFORM_AMD__
if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES)); Kernel_traits::SMEM_BYTES));
} }
#endif
auto stream = launch_params.stream; auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col; auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row; auto ctas_per_row = launch_params.params.ctas_per_row;
......
...@@ -34,10 +34,13 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params, ...@@ -34,10 +34,13 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
return; return;
} }
#ifndef __HIP_PLATFORM_AMD__
if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES_FWD)); Kernel_traits::SMEM_BYTES_FWD));
} }
#endif
auto stream = launch_params.stream; auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col; auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row; auto ctas_per_row = launch_params.params.ctas_per_row;
......
...@@ -47,7 +47,14 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -47,7 +47,14 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
NVTE_Norm_Backend norm_backend; NVTE_Norm_Backend norm_backend;
bool is_aligned = true; bool is_aligned = true;
#ifdef USE_ROCM
NVTE_CHECK(
!is_block_scaling(z->scaling_mode),
"Cudnn backend is need by block scaling mode for normalization! Not surpported in rocm yet.");
bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode);
#else
bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode);
#endif
bool training = bool training =
is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr;
......
...@@ -37,10 +37,13 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params, ...@@ -37,10 +37,13 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
return; return;
} }
#ifndef __HIP_PLATFORM_AMD__
if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES)); Kernel_traits::SMEM_BYTES));
} }
#endif
auto stream = launch_params.stream; auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col; auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row; auto ctas_per_row = launch_params.params.ctas_per_row;
......
...@@ -35,10 +35,13 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params, ...@@ -35,10 +35,13 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
return; return;
} }
#ifndef __HIP_PLATFORM_AMD__
if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES_FWD)); Kernel_traits::SMEM_BYTES_FWD));
} }
#endif
auto stream = launch_params.stream; auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col; auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row; auto ctas_per_row = launch_params.params.ctas_per_row;
......
...@@ -8,6 +8,11 @@ ...@@ -8,6 +8,11 @@
#include "../common.h" #include "../common.h"
#ifdef __HIP_PLATFORM_AMD__
using __nv_fp8_e4m3 = hip_f8<hip_f8_type::fp8>;
using __nv_fp8_e5m2 = hip_f8<hip_f8_type::bf8>;
#endif
static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map, static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map,
const int num_rows, const int topK, const int num_rows, const int topK,
const int num_out_tokens) { const int num_out_tokens) {
......
...@@ -5,7 +5,9 @@ ...@@ -5,7 +5,9 @@
************************************************************************/ ************************************************************************/
#include <cuda.h> #include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <transformer_engine/cast.h> #include <transformer_engine/cast.h>
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
#define TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ #define TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_
#include <cuda.h> #include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <transformer_engine/activation.h> #include <transformer_engine/activation.h>
#include <transformer_engine/cast.h> #include <transformer_engine/cast.h>
...@@ -723,6 +725,10 @@ template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP ...@@ -723,6 +725,10 @@ template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
cudaStream_t stream) { cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
static_assert(false,
"Cast_fp8_gated is not surpported in rocm yet.");
#else
if (output->has_data()) { if (output->has_data()) {
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
} }
...@@ -796,12 +802,17 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu ...@@ -796,12 +802,17 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows,
cols);); // NOLINT(*) cols);); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
#endif
} }
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &), template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
cudaStream_t stream) { cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
static_assert(false,
"Cast_mxfp8_gated is not surpported in rocm yet.");
#else
const bool USE_ROWWISE_SCALING = output->has_data(); const bool USE_ROWWISE_SCALING = output->has_data();
const bool USE_COLWISE_SCALING = output->has_columnwise_data(); const bool USE_COLWISE_SCALING = output->has_columnwise_data();
...@@ -919,6 +930,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -919,6 +930,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
); // NOLINT(*) ); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
#endif
} }
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)> template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
#define TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ #define TRANSFORMER_ENGINE_CAST_KERNELS_CUH_
#include <cuda.h> #include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <transformer_engine/cast.h> #include <transformer_engine/cast.h>
...@@ -853,6 +855,10 @@ static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream ...@@ -853,6 +855,10 @@ static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)> template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias, void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias,
Tensor *workspace, cudaStream_t stream) { Tensor *workspace, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
static_assert(false,
"Cast_fp8_2D is not surpported in rocm yet.");
#else
checkCuDriverContext(stream); checkCuDriverContext(stream);
const size_t rows = input.flat_first_dim(); const size_t rows = input.flat_first_dim();
...@@ -916,6 +922,7 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T ...@@ -916,6 +922,7 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T
reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
}); // NOLINT(*) }); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
#endif
} }
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP, template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
...@@ -923,6 +930,10 @@ template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP, ...@@ -923,6 +930,10 @@ template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
void mxfp8_quantize(const Tensor &input, const Tensor *act_input, void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
const Tensor *noop, // TODO (ksivamani) const Tensor *noop, // TODO (ksivamani)
Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
static_assert(false,
"Mxfp8_quantize is not surpported in rocm yet.");
#else
bool use_rowwise_scaling = output->has_data(); bool use_rowwise_scaling = output->has_data();
bool use_colwise_scaling = output->has_columnwise_data(); bool use_colwise_scaling = output->has_columnwise_data();
checkCuDriverContext(stream); checkCuDriverContext(stream);
...@@ -1027,6 +1038,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -1027,6 +1038,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
); // NOLINT(*) ); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
#endif
} }
namespace detail { namespace detail {
......
...@@ -15,10 +15,17 @@ namespace cuda_driver { ...@@ -15,10 +15,17 @@ namespace cuda_driver {
void *get_symbol(const char *symbol) { void *get_symbol(const char *symbol) {
void *entry_point; void *entry_point;
#ifdef USE_ROCM
hipDriverProcAddressQueryResult driver_result;
NVTE_CHECK_CUDA(hipGetProcAddress(symbol, &entry_point, HIP_VERSION_MAJOR*100+HIP_VERSION_MINOR, 0, &driver_result));
NVTE_CHECK(driver_result == HIP_GET_PROC_ADDRESS_SUCCESS,
"Could not find CUDA driver entry point for ", symbol);
#else
cudaDriverEntryPointQueryResult driver_result; cudaDriverEntryPointQueryResult driver_result;
NVTE_CHECK_CUDA(cudaGetDriverEntryPoint(symbol, &entry_point, cudaEnableDefault, &driver_result)); NVTE_CHECK_CUDA(cudaGetDriverEntryPoint(symbol, &entry_point, cudaEnableDefault, &driver_result));
NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess, NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess,
"Could not find CUDA driver entry point for ", symbol); "Could not find CUDA driver entry point for ", symbol);
#endif
return entry_point; return entry_point;
} }
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
namespace transformer_engine { namespace transformer_engine {
namespace cuda_nvml { namespace cuda_nvml {
#ifndef __HIP_PLATFORM_AMD__
/*! \brief Lazily-initialized shared library for CUDA NVML */ /*! \brief Lazily-initialized shared library for CUDA NVML */
Library &cuda_nvml_lib() { Library &cuda_nvml_lib() {
...@@ -20,7 +21,7 @@ Library &cuda_nvml_lib() { ...@@ -20,7 +21,7 @@ Library &cuda_nvml_lib() {
} }
void *get_symbol(const char *symbol) { return cuda_nvml_lib().get_symbol(symbol); } void *get_symbol(const char *symbol) { return cuda_nvml_lib().get_symbol(symbol); }
#endif
} // namespace cuda_nvml } // namespace cuda_nvml
} // namespace transformer_engine } // namespace transformer_engine
...@@ -7,7 +7,9 @@ ...@@ -7,7 +7,9 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_ #ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_
#ifndef __HIP_PLATFORM_AMD__
#include <nvml.h> #include <nvml.h>
#endif
#include <string> #include <string>
...@@ -17,6 +19,7 @@ ...@@ -17,6 +19,7 @@
namespace transformer_engine { namespace transformer_engine {
namespace cuda_nvml { namespace cuda_nvml {
#ifndef __HIP_PLATFORM_AMD__
/*! \brief Get pointer corresponding to symbol in CUDA NVML library */ /*! \brief Get pointer corresponding to symbol in CUDA NVML library */
void *get_symbol(const char *symbol); void *get_symbol(const char *symbol);
...@@ -45,11 +48,13 @@ inline const char *get_nvml_error_string(nvmlReturn_t rc) { ...@@ -45,11 +48,13 @@ inline const char *get_nvml_error_string(nvmlReturn_t rc) {
FuncT *func = reinterpret_cast<FuncT *>(get_symbol("nvmlErrorString")); FuncT *func = reinterpret_cast<FuncT *>(get_symbol("nvmlErrorString"));
return (*func)(rc); return (*func)(rc);
} }
#endif
} // namespace cuda_nvml } // namespace cuda_nvml
} // namespace transformer_engine } // namespace transformer_engine
#ifndef __HIP_PLATFORM_AMD__
#define NVTE_CHECK_CUDA_NVML(expr) \ #define NVTE_CHECK_CUDA_NVML(expr) \
do { \ do { \
const nvmlReturn_t status_NVTE_CHECK_CUDA_NVML = (expr); \ const nvmlReturn_t status_NVTE_CHECK_CUDA_NVML = (expr); \
...@@ -65,5 +70,6 @@ inline const char *get_nvml_error_string(nvmlReturn_t rc) { ...@@ -65,5 +70,6 @@ inline const char *get_nvml_error_string(nvmlReturn_t rc) {
do { \ do { \
NVTE_CHECK_CUDA_NVML(::transformer_engine::cuda_nvml::call(#symbol VA_ARGS(__VA_ARGS__))); \ NVTE_CHECK_CUDA_NVML(::transformer_engine::cuda_nvml::call(#symbol VA_ARGS(__VA_ARGS__))); \
} while (false) } while (false)
#endif
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_ #endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_
...@@ -18,12 +18,14 @@ namespace transformer_engine { ...@@ -18,12 +18,14 @@ namespace transformer_engine {
namespace cuda { namespace cuda {
#ifndef __HIP_PLATFORM_AMD__
namespace { namespace {
// String with build-time CUDA include path // String with build-time CUDA include path
#include "string_path_cuda_include.h" #include "string_path_cuda_include.h"
} // namespace } // namespace
#endif // __HIP_PLATFORM_AMD__
int num_devices() { int num_devices() {
auto query_num_devices = []() -> int { auto query_num_devices = []() -> int {
...@@ -81,6 +83,24 @@ int sm_count(int device_id) { ...@@ -81,6 +83,24 @@ int sm_count(int device_id) {
return cache[device_id]; return cache[device_id];
} }
#ifdef __HIP_PLATFORM_AMD__
/* \brief GCN architecture name of a HIP device (e.g. "gfx90a:sramecc+:xnack-").
 *
 * The name is queried from the device properties exactly once per device
 * and memoized in a static cache; subsequent calls return the cached string.
 *
 * \param[in] device_id HIP device (negative means current device)
 *
 * \return Reference to the cached architecture-name string.
 */
const std::string &sm_arch_name(int device_id) {
  static std::vector<std::string> names(num_devices(), "");
  static std::vector<std::once_flag> init_flags(num_devices());
  if (device_id < 0) {
    device_id = current_device();
  }
  NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid HIP device ID");
  // Query and cache the arch name at most once per device.
  std::call_once(init_flags[device_id], [&]() {
    cudaDeviceProp prop;
    NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
    names[device_id] = prop.gcnArchName;
  });
  return names[device_id];
}
#endif // __HIP_PLATFORM_AMD__
void stream_priority_range(int *low_priority, int *high_priority, int device_id) { void stream_priority_range(int *low_priority, int *high_priority, int device_id) {
static std::vector<std::pair<int, int>> cache(num_devices()); static std::vector<std::pair<int, int>> cache(num_devices());
static std::vector<std::once_flag> flags(num_devices()); static std::vector<std::once_flag> flags(num_devices());
...@@ -126,6 +146,7 @@ bool supports_multicast(int device_id) { ...@@ -126,6 +146,7 @@ bool supports_multicast(int device_id) {
#endif #endif
} }
#ifndef __HIP_PLATFORM_AMD__
const std::string &include_directory(bool required) { const std::string &include_directory(bool required) {
static std::string path; static std::string path;
...@@ -190,6 +211,7 @@ const std::string &include_directory(bool required) { ...@@ -190,6 +211,7 @@ const std::string &include_directory(bool required) {
// Return cached path // Return cached path
return path; return path;
} }
#endif // __HIP_PLATFORM_AMD__
} // namespace cuda } // namespace cuda
......
...@@ -30,6 +30,16 @@ int current_device(); ...@@ -30,6 +30,16 @@ int current_device();
*/ */
int sm_arch(int device_id = -1); int sm_arch(int device_id = -1);
#ifdef __HIP_PLATFORM_AMD__
/* \brief Compute capability of device
*
* \param[in] device_id HIP device (default is current device)
*
* \return GPU arch name and compute capabilities string.
*/
const std::string &sm_arch_name(int device_id = -1);
#endif
/* \brief Number of multiprocessors on a device /* \brief Number of multiprocessors on a device
* *
* \param[in] device_id CUDA device (default is current device) * \param[in] device_id CUDA device (default is current device)
...@@ -56,6 +66,7 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id ...@@ -56,6 +66,7 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id
*/ */
bool supports_multicast(int device_id = -1); bool supports_multicast(int device_id = -1);
#ifndef __HIP_PLATFORM_AMD__
/* \brief Path to CUDA Toolkit headers /* \brief Path to CUDA Toolkit headers
* *
* The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the * The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the
...@@ -66,6 +77,7 @@ bool supports_multicast(int device_id = -1); ...@@ -66,6 +77,7 @@ bool supports_multicast(int device_id = -1);
* \return Path to include directory, or an empty string if not found * \return Path to include directory, or an empty string if not found
*/ */
const std::string &include_directory(bool required = false); const std::string &include_directory(bool required = false);
#endif
} // namespace cuda } // namespace cuda
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
#define TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ #define TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_
#include <cuda.h> #include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <transformer_engine/cast.h> #include <transformer_engine/cast.h>
...@@ -250,6 +252,10 @@ static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t str ...@@ -250,6 +252,10 @@ static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t str
} }
static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
static_assert(false,
"Mxfp8_dequantize is not surpported in rocm yet.");
#else
bool use_rowwise_scaling = input.has_data(); bool use_rowwise_scaling = input.has_data();
bool use_colwise_scaling = input.has_columnwise_data(); bool use_colwise_scaling = input.has_columnwise_data();
checkCuDriverContext(stream); checkCuDriverContext(stream);
...@@ -332,6 +338,7 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s ...@@ -332,6 +338,7 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
); // NOLINT(*) ); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
} }
#endif
} // namespace dequantization } // namespace dequantization
namespace detail { namespace detail {
......
...@@ -7,9 +7,19 @@ ...@@ -7,9 +7,19 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ #ifndef TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
#include <cublas_v2.h>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
#include <hipblaslt/hipblaslt.h>
#endif
#ifdef USE_ROCBLAS
#define ROCBLAS_BETA_FEATURES_API
#include <rocblas/rocblas.h>
#endif
#else
#include <cublas_v2.h>
#include <cudnn.h> #include <cudnn.h>
#endif // __HIP_PLATFORM_AMD__
#include <nvrtc.h> #include <nvrtc.h>
#include <stdexcept> #include <stdexcept>
...@@ -39,6 +49,28 @@ ...@@ -39,6 +49,28 @@
} \ } \
} while (false) } while (false)
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT //hipblaslt
/* Evaluate a hipBLASLt call and abort with a descriptive error on failure.
 * Bug fix: the status must be compared against HIPBLAS_STATUS_SUCCESS —
 * CUBLAS_STATUS_SUCCESS is not declared by hipBLASLt headers when building
 * with __HIP_PLATFORM_AMD__. The temporary is also renamed so it no longer
 * shadows/mimics the cuBLAS macro's identifier. */
#define NVTE_CHECK_HIPBLASLT(expr)                                   \
  do {                                                               \
    const hipblasStatus_t status_NVTE_CHECK_HIPBLASLT = (expr);      \
    if (status_NVTE_CHECK_HIPBLASLT != HIPBLAS_STATUS_SUCCESS) {     \
      NVTE_ERROR("HIPBLASLT Error: ",                                \
                 std::to_string((int)status_NVTE_CHECK_HIPBLASLT));  \
    }                                                                \
  } while (false)
#endif
#ifdef USE_ROCBLAS //rocblas
/* Evaluate a rocBLAS call and abort with a descriptive error on failure.
 * The numeric status is translated to text via rocblas_status_to_string. */
#define NVTE_CHECK_ROCBLAS(expr)                                                \
  do {                                                                          \
    const rocblas_status status_NVTE_CHECK_ROCBLAS = (expr);                    \
    if (status_NVTE_CHECK_ROCBLAS != rocblas_status_success) {                  \
      NVTE_ERROR("ROCBLAS Error: " +                                            \
                 std::string(rocblas_status_to_string(status_NVTE_CHECK_ROCBLAS))); \
    }                                                                           \
  } while (false)
#endif
#else //cublas
#define NVTE_CHECK_CUBLAS(expr) \ #define NVTE_CHECK_CUBLAS(expr) \
do { \ do { \
const cublasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \ const cublasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \
...@@ -46,6 +78,7 @@ ...@@ -46,6 +78,7 @@
NVTE_ERROR("cuBLAS Error: ", cublasGetStatusString(status_NVTE_CHECK_CUBLAS)); \ NVTE_ERROR("cuBLAS Error: ", cublasGetStatusString(status_NVTE_CHECK_CUBLAS)); \
} \ } \
} while (false) } while (false)
#endif
#define NVTE_CHECK_CUDNN(expr) \ #define NVTE_CHECK_CUDNN(expr) \
do { \ do { \
......
...@@ -25,6 +25,12 @@ namespace { ...@@ -25,6 +25,12 @@ namespace {
#include "string_code_util_math_h.h" #include "string_code_util_math_h.h"
#include "string_code_utils_cuh.h" #include "string_code_utils_cuh.h"
#ifdef USE_ROCM
#include "string_code_amd_detail_hip_float8_h.h"
#include "string_code_amd_detail_hip_f8_impl_h.h"
#endif // USE_ROCM
#ifndef USE_ROCM
/*! \brief Latest compute capability that NVRTC supports /*! \brief Latest compute capability that NVRTC supports
* *
* \return Compute capability as int. Last digit is minor revision, * \return Compute capability as int. Last digit is minor revision,
...@@ -42,6 +48,7 @@ inline int max_supported_sm_arch() { ...@@ -42,6 +48,7 @@ inline int max_supported_sm_arch() {
} }
return arch_; return arch_;
} }
#endif // USE_ROCM
} // namespace } // namespace
...@@ -66,6 +73,9 @@ Kernel::~Kernel() { ...@@ -66,6 +73,9 @@ Kernel::~Kernel() {
for (int device_id = 0; device_id < static_cast<int>(modules_.size()); ++device_id) { for (int device_id = 0; device_id < static_cast<int>(modules_.size()); ++device_id) {
// Unload CUDA modules if needed // Unload CUDA modules if needed
if (modules_[device_id] != null_module) { if (modules_[device_id] != null_module) {
#ifdef USE_ROCM
(void)cuda_driver::call("hipModuleUnload", modules_[device_id]);
#else
CUdevice device; CUdevice device;
CUcontext context; CUcontext context;
if (cuda_driver::call("cuDeviceGet", &device, device_id) != CUDA_SUCCESS) { if (cuda_driver::call("cuDeviceGet", &device, device_id) != CUDA_SUCCESS) {
...@@ -79,6 +89,7 @@ Kernel::~Kernel() { ...@@ -79,6 +89,7 @@ Kernel::~Kernel() {
} }
cuda_driver::call("cuModuleUnload", modules_[device_id]); cuda_driver::call("cuModuleUnload", modules_[device_id]);
cuda_driver::call("cuDevicePrimaryCtxRelease", device); cuda_driver::call("cuDevicePrimaryCtxRelease", device);
#endif // USE_ROCM
} }
} }
} }
...@@ -143,9 +154,11 @@ void KernelManager::compile(const std::string& kernel_label, const std::string& ...@@ -143,9 +154,11 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
// Choose whether to compile to PTX or cubin // Choose whether to compile to PTX or cubin
const int device_id = cuda::current_device(); const int device_id = cuda::current_device();
#ifndef USE_ROCM
const int sm_arch_ = cuda::sm_arch(device_id); const int sm_arch_ = cuda::sm_arch(device_id);
const int compile_sm_arch = std::min(sm_arch_, max_supported_sm_arch()); const int compile_sm_arch = std::min(sm_arch_, max_supported_sm_arch());
const bool compile_ptx = (CUDA_VERSION <= 11000) || (sm_arch_ != compile_sm_arch); const bool compile_ptx = (CUDA_VERSION <= 11000) || (sm_arch_ != compile_sm_arch);
#endif // USE_ROCM
// Compilation flags // Compilation flags
std::vector<std::string> opts = { std::vector<std::string> opts = {
...@@ -153,12 +166,15 @@ void KernelManager::compile(const std::string& kernel_label, const std::string& ...@@ -153,12 +166,15 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
"-G", "-G",
#endif #endif
"--std=c++17"}; "--std=c++17"};
#ifndef USE_ROCM
if (compile_ptx) { if (compile_ptx) {
opts.push_back(concat_strings("--gpu-architecture=compute_", compile_sm_arch)); opts.push_back(concat_strings("--gpu-architecture=compute_", compile_sm_arch));
} else { } else {
opts.push_back(concat_strings("--gpu-architecture=sm_", compile_sm_arch)); opts.push_back(concat_strings("--gpu-architecture=sm_", compile_sm_arch));
} }
opts.push_back(concat_strings("-I", cuda::include_directory(true))); opts.push_back(concat_strings("-I", cuda::include_directory(true)));
#endif //USE_ROCM
std::vector<const char*> opts_ptrs; std::vector<const char*> opts_ptrs;
for (const auto& opt : opts) { for (const auto& opt : opts) {
opts_ptrs.push_back(opt.c_str()); opts_ptrs.push_back(opt.c_str());
...@@ -166,9 +182,15 @@ void KernelManager::compile(const std::string& kernel_label, const std::string& ...@@ -166,9 +182,15 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
// Compile source // Compile source
nvrtcProgram program; nvrtcProgram program;
#ifdef USE_ROCM
constexpr int num_headers = 4;
const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h, string_code_amd_detail_hip_float8_h, string_code_amd_detail_hip_f8_impl_h};
const char* include_names[num_headers] = {"utils_hip.cuh", "util/math.h", "amd_detail/hip_float8.h", "amd_detail/hip_f8_impl.h"};
#else
constexpr int num_headers = 2; constexpr int num_headers = 2;
constexpr const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h}; constexpr const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h};
constexpr const char* include_names[num_headers] = {"utils.cuh", "util/math.h"}; constexpr const char* include_names[num_headers] = {"utils.cuh", "util/math.h"};
#endif // USE_ROCM
NVTE_CHECK_NVRTC(nvrtcCreateProgram(&program, code.c_str(), filename.c_str(), num_headers, NVTE_CHECK_NVRTC(nvrtcCreateProgram(&program, code.c_str(), filename.c_str(), num_headers,
headers, include_names)); headers, include_names));
NVTE_CHECK_NVRTC(nvrtcAddNameExpression(program, kernel_name.c_str())); NVTE_CHECK_NVRTC(nvrtcAddNameExpression(program, kernel_name.c_str()));
...@@ -193,6 +215,14 @@ void KernelManager::compile(const std::string& kernel_label, const std::string& ...@@ -193,6 +215,14 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
// Get compiled code // Get compiled code
std::string compiled_code; std::string compiled_code;
#ifdef USE_ROCM
{
size_t compiled_size;
NVTE_CHECK_NVRTC(hiprtcGetCodeSize(program, &compiled_size));
compiled_code.resize(compiled_size);
NVTE_CHECK_NVRTC(hiprtcGetCode(program, compiled_code.data()));
}
#else
if (compile_ptx) { if (compile_ptx) {
size_t compiled_size; size_t compiled_size;
NVTE_CHECK_NVRTC(nvrtcGetPTXSize(program, &compiled_size)); NVTE_CHECK_NVRTC(nvrtcGetPTXSize(program, &compiled_size));
...@@ -204,6 +234,7 @@ void KernelManager::compile(const std::string& kernel_label, const std::string& ...@@ -204,6 +234,7 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
compiled_code.resize(compiled_size); compiled_code.resize(compiled_size);
NVTE_CHECK_NVRTC(nvrtcGetCUBIN(program, compiled_code.data())); NVTE_CHECK_NVRTC(nvrtcGetCUBIN(program, compiled_code.data()));
} }
#endif //USE_ROCM
// Cache compiled code // Cache compiled code
const auto key = get_kernel_cache_key(kernel_label, device_id); const auto key = get_kernel_cache_key(kernel_label, device_id);
...@@ -228,7 +259,11 @@ bool KernelManager::is_compiled(const std::string& kernel_label, int device_id) ...@@ -228,7 +259,11 @@ bool KernelManager::is_compiled(const std::string& kernel_label, int device_id)
std::string KernelManager::get_kernel_cache_key(const std::string& kernel_label, std::string KernelManager::get_kernel_cache_key(const std::string& kernel_label,
int device_id) const { int device_id) const {
#ifdef USE_ROCM
return concat_strings(cuda::sm_arch_name(device_id), ",", kernel_label);
#else
return concat_strings("sm=", cuda::sm_arch(device_id), ",", kernel_label); return concat_strings("sm=", cuda::sm_arch(device_id), ",", kernel_label);
#endif
} }
} // namespace rtc } // namespace rtc
......
...@@ -7,10 +7,22 @@ ...@@ -7,10 +7,22 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ #ifndef TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
#define TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ #define TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
#include <cuda_bf16.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#ifdef __HIP_PLATFORM_AMD__
#ifndef __HIPCC_RTC__
#include <cstdint>
#else
using namespace __hip_internal;
#endif
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#ifdef __HIP_PLATFORM_AMD__
typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2)));
#else
#if !defined(__CUDACC_RTC__) #if !defined(__CUDACC_RTC__)
#include <cstdint> #include <cstdint>
#else #else
...@@ -25,12 +37,14 @@ static_assert(sizeof(uint32_t) == 4); ...@@ -25,12 +37,14 @@ static_assert(sizeof(uint32_t) == 4);
static_assert(sizeof(uint64_t) == 8); static_assert(sizeof(uint64_t) == 8);
#endif #endif
#endif // __HIP_PLATFORM_AMD__
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
constexpr uint32_t THREADS_PER_WARP = 32; constexpr uint32_t THREADS_PER_WARP = 32;
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
#if !defined(USE_HIPBLASLT) && !defined(__HIPCC_RTC__)
inline __device__ float2 operator+(const float2 &a, const float2 &b) { // NOLINT(*) inline __device__ float2 operator+(const float2 &a, const float2 &b) { // NOLINT(*)
return {a.x + b.x, a.y + b.y}; return {a.x + b.x, a.y + b.y};
} }
...@@ -41,6 +55,7 @@ inline __device__ void operator+=(float2 &a, const float2 &b) { // NOLINT(*) ...@@ -41,6 +55,7 @@ inline __device__ void operator+=(float2 &a, const float2 &b) { // NOLINT(*)
a.x += b.x; a.x += b.x;
a.y += b.y; a.y += b.y;
} }
#endif
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
...@@ -54,7 +69,11 @@ struct Sum { ...@@ -54,7 +69,11 @@ struct Sum {
template <typename T> template <typename T>
inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) { inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) {
#ifdef __HIP_PLATFORM_AMD__
return __shfl_xor(x, idx, THREADS_PER_WARP);
#else
return __shfl_xor_sync(static_cast<uint32_t>(-1), x, idx); return __shfl_xor_sync(static_cast<uint32_t>(-1), x, idx);
#endif
} }
template <> template <>
...@@ -64,7 +83,11 @@ inline __device__ float2 warp_shuffle_xor<float2>(const float2 &x, uint32_t idx) ...@@ -64,7 +83,11 @@ inline __device__ float2 warp_shuffle_xor<float2>(const float2 &x, uint32_t idx)
template <typename T> template <typename T>
inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) { inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) {
#ifdef __HIP_PLATFORM_AMD__
return __shfl_down(x, idx, THREADS_PER_WARP);
#else
return __shfl_down_sync(static_cast<uint32_t>(-1), x, idx); return __shfl_down_sync(static_cast<uint32_t>(-1), x, idx);
#endif
} }
template <> template <>
...@@ -154,10 +177,17 @@ struct TypeToVec2<half> { ...@@ -154,10 +177,17 @@ struct TypeToVec2<half> {
using Type = half2; using Type = half2;
}; };
#ifdef __HIP_PLATFORM_AMD__
template <>
struct TypeToVec2<__hip_bfloat16> {
using Type = hip_bfloat16x2;
};
#else
template <> template <>
struct TypeToVec2<nv_bfloat16> { struct TypeToVec2<nv_bfloat16> {
using Type = nv_bfloat162; using Type = nv_bfloat162;
}; };
#endif
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
...@@ -222,6 +252,20 @@ struct Converter<float2, half2> { ...@@ -222,6 +252,20 @@ struct Converter<float2, half2> {
static inline __device__ half2 convert(const float2 &x) { return __float22half2_rn(x); } static inline __device__ half2 convert(const float2 &x) { return __float22half2_rn(x); }
}; };
#ifdef __HIP_PLATFORM_AMD__
template <>
struct Converter<float2, hip_bfloat16x2> {
static inline __device__ hip_bfloat16x2 convert(const float2 &x) {
union {
hip_bfloat16x2 raw;
hip_bfloat16 elt[2];
} tmp;
tmp.elt[0] = __hip_bfloat16(x.x);
tmp.elt[1] = __hip_bfloat16(x.y);
return tmp.raw;
}
};
#else
template <> template <>
struct Converter<float2, nv_bfloat162> { struct Converter<float2, nv_bfloat162> {
static inline __device__ nv_bfloat162 convert(const float2 &x) { static inline __device__ nv_bfloat162 convert(const float2 &x) {
...@@ -238,6 +282,7 @@ struct Converter<float2, nv_bfloat162> { ...@@ -238,6 +282,7 @@ struct Converter<float2, nv_bfloat162> {
#endif #endif
} }
}; };
#endif
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
...@@ -266,6 +311,12 @@ struct Vec { ...@@ -266,6 +311,12 @@ struct Vec {
}; };
Alias_type data; Alias_type data;
#ifdef __HIP_PLATFORM_AMD__
__HOST_DEVICE__ Vec& operator=(const Vec& rhs) {
data.vec = rhs.data.vec;
return *this;
}
#endif
template <typename S> template <typename S>
inline __device__ void to(Vec<S, NUM_ELT> &other) { // NOLINT(*) inline __device__ void to(Vec<S, NUM_ELT> &other) { // NOLINT(*)
...@@ -346,12 +397,21 @@ struct InterCTASync { ...@@ -346,12 +397,21 @@ struct InterCTASync {
// BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
} }
#ifdef __HIP_PLATFORM_AMD__
inline __device__ void spin_wait_(int *barrier, int step, int expected) {
__hip_atomic_fetch_add(barrier, step, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
for (int found = -1; found != expected; ) {
found = __hip_atomic_load(barrier, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT);
}
}
#else
inline __device__ void spin_wait_(int *barrier, int step, int expected) { inline __device__ void spin_wait_(int *barrier, int step, int expected) {
asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
for (int found = -1; found != expected;) { for (int found = -1; found != expected;) {
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
} }
} }
#endif
inline __device__ void sync() { inline __device__ void sync() {
// ALL THREADS MUST ENTER! // ALL THREADS MUST ENTER!
...@@ -634,8 +694,13 @@ inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, ...@@ -634,8 +694,13 @@ inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a,
m2_a = m2_ab; m2_a = m2_ab;
} }
// Intra-warp broadcast (only lane 0 has valid stats). // Intra-warp broadcast (only lane 0 has valid stats).
#ifdef __HIP_PLATFORM_AMD__
m_a = __shfl(m_a, 0, THREADS_PER_WARP);
m2_a = __shfl(m2_a, 0, THREADS_PER_WARP);
#else
m_a = __shfl_sync(static_cast<uint32_t>(-1), m_a, 0); m_a = __shfl_sync(static_cast<uint32_t>(-1), m_a, 0);
m2_a = __shfl_sync(static_cast<uint32_t>(-1), m2_a, 0); m2_a = __shfl_sync(static_cast<uint32_t>(-1), m2_a, 0);
#endif
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
...@@ -811,7 +876,11 @@ __device__ __forceinline__ float warp_reduce_max(const float m) { ...@@ -811,7 +876,11 @@ __device__ __forceinline__ float warp_reduce_max(const float m) {
float tmp = m; float tmp = m;
#pragma unroll #pragma unroll
for (int delta = num_elems / 2; delta > 0; delta /= 2) { for (int delta = num_elems / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
const float other_m = __shfl_down(tmp, delta, THREADS_PER_WARP);
#else
const float other_m = __shfl_down_sync(0xFFFFFFFF, tmp, delta); const float other_m = __shfl_down_sync(0xFFFFFFFF, tmp, delta);
#endif
__builtin_assume(tmp >= 0); __builtin_assume(tmp >= 0);
__builtin_assume(other_m >= 0); __builtin_assume(other_m >= 0);
tmp = fmaxf(tmp, other_m); tmp = fmaxf(tmp, other_m);
...@@ -823,14 +892,22 @@ __forceinline__ __device__ float warp_reduce_max_broadcast(const float val) { ...@@ -823,14 +892,22 @@ __forceinline__ __device__ float warp_reduce_max_broadcast(const float val) {
float val_tmp = val; float val_tmp = val;
#pragma unroll #pragma unroll
for (int offset = THREADS_PER_WARP / 2; offset > 0; offset /= 2) { for (int offset = THREADS_PER_WARP / 2; offset > 0; offset /= 2) {
#ifdef __HIP_PLATFORM_AMD__
const float val_other = __shfl_down(val_tmp, offset, THREADS_PER_WARP);
#else
const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset); const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset);
#endif
__builtin_assume(val_tmp >= 0); __builtin_assume(val_tmp >= 0);
__builtin_assume(val_other >= 0); __builtin_assume(val_other >= 0);
val_tmp = fmaxf(val_tmp, val_other); val_tmp = fmaxf(val_tmp, val_other);
} }
// Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id // Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id
constexpr int subwarp_lane_zero = 0; constexpr int subwarp_lane_zero = 0;
#ifdef __HIP_PLATFORM_AMD__
val_tmp = __shfl(val_tmp, subwarp_lane_zero, THREADS_PER_WARP);
#else
val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero); val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero);
#endif
return val_tmp; return val_tmp;
} }
...@@ -864,14 +941,22 @@ __forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) { ...@@ -864,14 +941,22 @@ __forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) {
float val_tmp = val; float val_tmp = val;
#pragma unroll #pragma unroll
for (int offset = subwarp_width / 2; offset > 0; offset /= 2) { for (int offset = subwarp_width / 2; offset > 0; offset /= 2) {
#ifdef __HIP_PLATFORM_AMD__
const float val_other = __shfl_down(val_tmp, offset, subwarp_width);
#else
const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset, subwarp_width); const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset, subwarp_width);
#endif
__builtin_assume(val_tmp >= 0); __builtin_assume(val_tmp >= 0);
__builtin_assume(val_other >= 0); __builtin_assume(val_other >= 0);
val_tmp = fmaxf(val_tmp, val_other); val_tmp = fmaxf(val_tmp, val_other);
} }
// Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id // Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id
constexpr int subwarp_lane_zero = 0; constexpr int subwarp_lane_zero = 0;
#ifdef __HIP_PLATFORM_AMD__
val_tmp = __shfl(val_tmp, subwarp_lane_zero, subwarp_width);
#else
val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero, subwarp_width); val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero, subwarp_width);
#endif
return val_tmp; return val_tmp;
} }
...@@ -897,8 +982,13 @@ __device__ __forceinline__ void reciprocal<float>(float *value_inv, const float ...@@ -897,8 +982,13 @@ __device__ __forceinline__ void reciprocal<float>(float *value_inv, const float
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
#ifndef __HIP_PLATFORM_AMD__
using fp8e4m3 = __nv_fp8_e4m3; using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2; using fp8e5m2 = __nv_fp8_e5m2;
#else
using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
#endif
using e8m0_t = uint8_t; using e8m0_t = uint8_t;
constexpr uint32_t FP32_MANTISSA_BITS = 23; constexpr uint32_t FP32_MANTISSA_BITS = 23;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment