Commit c520cba3 authored by yuguo

[DCU] Preliminary adaptation

parent 5b6ef054
......@@ -7,9 +7,11 @@
#ifndef TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_
#ifndef __HIP_PLATFORM_AMD__
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#endif
#include <transformer_engine/transformer_engine.h>
#include <functional>
......@@ -282,6 +284,7 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
const NVTE_Norm_Type _norm_type;
std::unique_ptr<char[]> _scalar_dptr;
std::unique_ptr<float> _one_dptr = std::make_unique<float>(1.0f);
#ifndef __HIP_PLATFORM_AMD__
// FWD
std::shared_ptr<fe::graph::Tensor_attributes> _x, _gamma_zero, _scalar_offset, _gamma, _beta,
_eps, _mean, _rsigma, _z, _z_scale, _one_for_div, _z_scale_inv, _amax, _z_fp8;
......@@ -294,6 +297,7 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
fe::graph::Graph _graph;
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> _variant_pack;
cudnnHandle_t _handle;
#endif
};
class NormalizationPlanRegistry {
......@@ -322,9 +326,15 @@ using byte = uint8_t;
using int32 = int32_t;
using fp32 = float;
using fp16 = half;
#ifndef __HIP_PLATFORM_AMD__
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
#else
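// On ROCm, hip_f8_type::fp8 is the e4m3 format and hip_f8_type::bf8 the e5m2
// format, so the alias names below stay identical on both platforms.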
using bf16 = __hip_bfloat16;
using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
#endif
template <typename T>
struct TypeToDType;
......
......@@ -57,7 +57,14 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
#ifdef USE_ROCM
  NVTE_CHECK(!is_block_scaling(z->scaling_mode),
             "Block scaling mode for normalization requires the cuDNN backend, "
             "which is not supported on ROCm yet.");
#endif
  bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode);
if (cudnn_backend) {
// TODO: add check for GPU ARCH
......
......@@ -38,10 +38,13 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
return;
}
#ifndef __HIP_PLATFORM_AMD__
if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES));
}
#endif
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row;
......
......@@ -34,10 +34,13 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
return;
}
#ifndef __HIP_PLATFORM_AMD__
if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES_FWD));
}
#endif
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row;
......
......@@ -47,7 +47,14 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
#ifdef USE_ROCM
  NVTE_CHECK(!is_block_scaling(z->scaling_mode),
             "Block scaling mode for normalization requires the cuDNN backend, "
             "which is not supported on ROCm yet.");
#endif
  bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode);
bool training =
is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr;
......
......@@ -37,10 +37,13 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
return;
}
#ifndef __HIP_PLATFORM_AMD__
if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES));
}
#endif
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row;
......
......@@ -35,10 +35,13 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
return;
}
#ifndef __HIP_PLATFORM_AMD__
if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES_FWD));
}
#endif
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row;
......
......@@ -8,6 +8,11 @@
#include "../common.h"
#ifdef __HIP_PLATFORM_AMD__
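// Alias the NV fp8 type names used by the kernels below to their HIP f8
// equivalents so the code compiles unchanged on ROCm.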
using __nv_fp8_e4m3 = hip_f8<hip_f8_type::fp8>;
using __nv_fp8_e5m2 = hip_f8<hip_f8_type::bf8>;
#endif
static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map,
const int num_rows, const int topK,
const int num_out_tokens) {
......
......@@ -5,7 +5,9 @@
************************************************************************/
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>
......
......@@ -12,7 +12,9 @@
#define TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
......@@ -723,6 +725,10 @@ template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP
float (*DActOP)(float, const ParamOP &)>
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
  NVTE_ERROR("cast_fp8_gated is not supported on ROCm yet.");
#else
if (output->has_data()) {
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
}
......@@ -796,12 +802,17 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows,
cols);); // NOLINT(*)
); // NOLINT(*)
#endif
}
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
  NVTE_ERROR("cast_mxfp8_gated is not supported on ROCm yet.");
#else
const bool USE_ROWWISE_SCALING = output->has_data();
const bool USE_COLWISE_SCALING = output->has_columnwise_data();
......@@ -919,6 +930,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
#endif
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
......
......@@ -12,7 +12,9 @@
#define TRANSFORMER_ENGINE_CAST_KERNELS_CUH_
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>
......@@ -853,6 +855,10 @@ static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias,
Tensor *workspace, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
  NVTE_ERROR("cast_fp8_2D is not supported on ROCm yet.");
#else
checkCuDriverContext(stream);
const size_t rows = input.flat_first_dim();
......@@ -916,6 +922,7 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T
reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
}); // NOLINT(*)
); // NOLINT(*)
#endif
}
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
......@@ -923,6 +930,10 @@ template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
const Tensor *noop, // TODO (ksivamani)
Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
  NVTE_ERROR("mxfp8_quantize is not supported on ROCm yet.");
#else
bool use_rowwise_scaling = output->has_data();
bool use_colwise_scaling = output->has_columnwise_data();
checkCuDriverContext(stream);
......@@ -1027,6 +1038,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
#endif
}
namespace detail {
......
......@@ -15,10 +15,17 @@ namespace cuda_driver {
void *get_symbol(const char *symbol) {
void *entry_point;
#ifdef USE_ROCM
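  // hipGetProcAddress plays the role of cudaGetDriverEntryPoint here; the
  // version argument is encoded as major * 100 + minor.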
  hipDriverProcAddressQueryResult driver_result;
  NVTE_CHECK_CUDA(hipGetProcAddress(symbol, &entry_point,
                                    HIP_VERSION_MAJOR * 100 + HIP_VERSION_MINOR,
                                    0, &driver_result));
  NVTE_CHECK(driver_result == HIP_GET_PROC_ADDRESS_SUCCESS,
             "Could not find HIP driver entry point for ", symbol);
#else
cudaDriverEntryPointQueryResult driver_result;
NVTE_CHECK_CUDA(cudaGetDriverEntryPoint(symbol, &entry_point, cudaEnableDefault, &driver_result));
NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess,
"Could not find CUDA driver entry point for ", symbol);
#endif
return entry_point;
}
......
......@@ -11,6 +11,7 @@
namespace transformer_engine {
namespace cuda_nvml {
#ifndef __HIP_PLATFORM_AMD__
/*! \brief Lazily-initialized shared library for CUDA NVML */
Library &cuda_nvml_lib() {
......@@ -20,7 +21,7 @@ Library &cuda_nvml_lib() {
}
void *get_symbol(const char *symbol) { return cuda_nvml_lib().get_symbol(symbol); }
#endif
} // namespace cuda_nvml
} // namespace transformer_engine
......@@ -7,7 +7,9 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_
#ifndef __HIP_PLATFORM_AMD__
#include <nvml.h>
#endif
#include <string>
......@@ -17,6 +19,7 @@
namespace transformer_engine {
namespace cuda_nvml {
#ifndef __HIP_PLATFORM_AMD__
/*! \brief Get pointer corresponding to symbol in CUDA NVML library */
void *get_symbol(const char *symbol);
......@@ -45,11 +48,13 @@ inline const char *get_nvml_error_string(nvmlReturn_t rc) {
FuncT *func = reinterpret_cast<FuncT *>(get_symbol("nvmlErrorString"));
return (*func)(rc);
}
#endif
} // namespace cuda_nvml
} // namespace transformer_engine
#ifndef __HIP_PLATFORM_AMD__
#define NVTE_CHECK_CUDA_NVML(expr) \
do { \
const nvmlReturn_t status_NVTE_CHECK_CUDA_NVML = (expr); \
......@@ -65,5 +70,6 @@ inline const char *get_nvml_error_string(nvmlReturn_t rc) {
do { \
NVTE_CHECK_CUDA_NVML(::transformer_engine::cuda_nvml::call(#symbol VA_ARGS(__VA_ARGS__))); \
} while (false)
#endif
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_
......@@ -18,12 +18,14 @@ namespace transformer_engine {
namespace cuda {
#ifndef __HIP_PLATFORM_AMD__
namespace {
// String with build-time CUDA include path
#include "string_path_cuda_include.h"
} // namespace
#endif // __HIP_PLATFORM_AMD__
int num_devices() {
auto query_num_devices = []() -> int {
......@@ -81,6 +83,24 @@ int sm_count(int device_id) {
return cache[device_id];
}
#ifdef __HIP_PLATFORM_AMD__
const std::string &sm_arch_name(int device_id) {
static std::vector<std::string> cache(num_devices(), "");
static std::vector<std::once_flag> flags(num_devices());
if (device_id < 0) {
device_id = current_device();
}
NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid HIP device ID");
auto init = [&] () {
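    // gcnArchName is the full target string (e.g. "gfx90a:sramecc+:xnack-");
    // the RTC kernel cache (rtc.cpp) keys compiled code on it.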
cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
cache[device_id] = prop.gcnArchName;
};
std::call_once(flags[device_id], init);
return cache[device_id];
}
#endif // __HIP_PLATFORM_AMD__
void stream_priority_range(int *low_priority, int *high_priority, int device_id) {
static std::vector<std::pair<int, int>> cache(num_devices());
static std::vector<std::once_flag> flags(num_devices());
......@@ -126,6 +146,7 @@ bool supports_multicast(int device_id) {
#endif
}
#ifndef __HIP_PLATFORM_AMD__
const std::string &include_directory(bool required) {
static std::string path;
......@@ -190,6 +211,7 @@ const std::string &include_directory(bool required) {
// Return cached path
return path;
}
#endif // __HIP_PLATFORM_AMD__
} // namespace cuda
......
......@@ -30,6 +30,16 @@ int current_device();
*/
int sm_arch(int device_id = -1);
#ifdef __HIP_PLATFORM_AMD__
/* \brief GCN arch name of device
 *
 * \param[in] device_id HIP device (default is current device)
 *
 * \return Full GCN arch name string, including target feature flags
*/
const std::string &sm_arch_name(int device_id = -1);
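/* A hypothetical usage sketch (sm_arch_name is the API above; the gfx9 check
 * is only illustrative):
 *
 *   const std::string &arch = transformer_engine::cuda::sm_arch_name();
 *   const bool is_gfx9 = arch.rfind("gfx9", 0) == 0;  // e.g. gfx90a
 */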
#endif
/* \brief Number of multiprocessors on a device
*
* \param[in] device_id CUDA device (default is current device)
......@@ -56,6 +66,7 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id
*/
bool supports_multicast(int device_id = -1);
#ifndef __HIP_PLATFORM_AMD__
/* \brief Path to CUDA Toolkit headers
*
* The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the
......@@ -66,6 +77,7 @@ bool supports_multicast(int device_id = -1);
* \return Path to include directory, or an empty string if not found
*/
const std::string &include_directory(bool required = false);
#endif
} // namespace cuda
......
......@@ -12,7 +12,9 @@
#define TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>
......@@ -250,6 +252,10 @@ static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t str
}
static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
  NVTE_ERROR("mxfp8_dequantize is not supported on ROCm yet.");
#else
bool use_rowwise_scaling = input.has_data();
bool use_colwise_scaling = input.has_columnwise_data();
checkCuDriverContext(stream);
......@@ -332,6 +338,7 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
); // NOLINT(*)
); // NOLINT(*)
#endif
}
} // namespace dequantization
namespace detail {
......
......@@ -7,9 +7,19 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
#include <cuda_runtime_api.h>
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
#include <hipblaslt/hipblaslt.h>
#endif
#ifdef USE_ROCBLAS
#define ROCBLAS_BETA_FEATURES_API
#include <rocblas/rocblas.h>
#endif
#else
#include <cublas_v2.h>
#include <cudnn.h>
#endif // __HIP_PLATFORM_AMD__
#include <nvrtc.h>
#include <stdexcept>
......@@ -39,6 +49,28 @@
} \
} while (false)
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT  // hipBLASLt
#define NVTE_CHECK_HIPBLASLT(expr)                                              \
  do {                                                                          \
    const hipblasStatus_t status_NVTE_CHECK_HIPBLASLT = (expr);                 \
    if (status_NVTE_CHECK_HIPBLASLT != HIPBLAS_STATUS_SUCCESS) {                \
      NVTE_ERROR("hipBLASLt Error: ",                                           \
                 std::to_string(static_cast<int>(status_NVTE_CHECK_HIPBLASLT))); \
    }                                                                           \
  } while (false)
#endif
#ifdef USE_ROCBLAS  // rocBLAS
#define NVTE_CHECK_ROCBLAS(expr)                                                \
  do {                                                                          \
    const rocblas_status status_NVTE_CHECK_ROCBLAS = (expr);                    \
    if (status_NVTE_CHECK_ROCBLAS != rocblas_status_success) {                  \
      NVTE_ERROR("rocBLAS Error: ",                                             \
                 rocblas_status_to_string(status_NVTE_CHECK_ROCBLAS));          \
    }                                                                           \
  } while (false)
#endif
#else  // cuBLAS
#define NVTE_CHECK_CUBLAS(expr) \
do { \
const cublasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \
......@@ -46,6 +78,7 @@
NVTE_ERROR("cuBLAS Error: ", cublasGetStatusString(status_NVTE_CHECK_CUBLAS)); \
} \
} while (false)
#endif
#define NVTE_CHECK_CUDNN(expr) \
do { \
......
......@@ -25,6 +25,12 @@ namespace {
#include "string_code_util_math_h.h"
#include "string_code_utils_cuh.h"
#ifdef USE_ROCM
#include "string_code_amd_detail_hip_float8_h.h"
#include "string_code_amd_detail_hip_f8_impl_h.h"
#else  // USE_ROCM
/*! \brief Latest compute capability that NVRTC supports
*
* \return Compute capability as int. Last digit is minor revision,
......@@ -42,6 +48,7 @@ inline int max_supported_sm_arch() {
}
return arch_;
}
#endif // USE_ROCM
} // namespace
......@@ -66,6 +73,9 @@ Kernel::~Kernel() {
for (int device_id = 0; device_id < static_cast<int>(modules_.size()); ++device_id) {
// Unload CUDA modules if needed
if (modules_[device_id] != null_module) {
#ifdef USE_ROCM
(void)cuda_driver::call("hipModuleUnload", modules_[device_id]);
#else
CUdevice device;
CUcontext context;
if (cuda_driver::call("cuDeviceGet", &device, device_id) != CUDA_SUCCESS) {
......@@ -79,6 +89,7 @@ Kernel::~Kernel() {
}
cuda_driver::call("cuModuleUnload", modules_[device_id]);
cuda_driver::call("cuDevicePrimaryCtxRelease", device);
#endif // USE_ROCM
}
}
}
......@@ -143,9 +154,11 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
// Choose whether to compile to PTX or cubin
const int device_id = cuda::current_device();
#ifndef USE_ROCM
const int sm_arch_ = cuda::sm_arch(device_id);
const int compile_sm_arch = std::min(sm_arch_, max_supported_sm_arch());
const bool compile_ptx = (CUDA_VERSION <= 11000) || (sm_arch_ != compile_sm_arch);
#endif // USE_ROCM
// Compilation flags
std::vector<std::string> opts = {
......@@ -153,12 +166,15 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
"-G",
#endif
"--std=c++17"};
#ifndef USE_ROCM
if (compile_ptx) {
opts.push_back(concat_strings("--gpu-architecture=compute_", compile_sm_arch));
} else {
opts.push_back(concat_strings("--gpu-architecture=sm_", compile_sm_arch));
}
opts.push_back(concat_strings("-I", cuda::include_directory(true)));
#endif //USE_ROCM
std::vector<const char*> opts_ptrs;
for (const auto& opt : opts) {
opts_ptrs.push_back(opt.c_str());
......@@ -166,9 +182,15 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
// Compile source
nvrtcProgram program;
#ifdef USE_ROCM
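  // No toolkit include path is passed to hiprtc (the -I flag is skipped on
  // ROCm below), so the hip_f8 headers are supplied as in-memory sources.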
  constexpr int num_headers = 4;
  const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h,
                                      string_code_amd_detail_hip_float8_h,
                                      string_code_amd_detail_hip_f8_impl_h};
  const char* include_names[num_headers] = {"utils_hip.cuh", "util/math.h",
                                            "amd_detail/hip_float8.h",
                                            "amd_detail/hip_f8_impl.h"};
#else
constexpr int num_headers = 2;
constexpr const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h};
constexpr const char* include_names[num_headers] = {"utils.cuh", "util/math.h"};
#endif // USE_ROCM
NVTE_CHECK_NVRTC(nvrtcCreateProgram(&program, code.c_str(), filename.c_str(), num_headers,
headers, include_names));
NVTE_CHECK_NVRTC(nvrtcAddNameExpression(program, kernel_name.c_str()));
......@@ -193,6 +215,14 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
// Get compiled code
std::string compiled_code;
#ifdef USE_ROCM
{
size_t compiled_size;
NVTE_CHECK_NVRTC(hiprtcGetCodeSize(program, &compiled_size));
compiled_code.resize(compiled_size);
NVTE_CHECK_NVRTC(hiprtcGetCode(program, compiled_code.data()));
}
#else
if (compile_ptx) {
size_t compiled_size;
NVTE_CHECK_NVRTC(nvrtcGetPTXSize(program, &compiled_size));
......@@ -204,6 +234,7 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
compiled_code.resize(compiled_size);
NVTE_CHECK_NVRTC(nvrtcGetCUBIN(program, compiled_code.data()));
}
#endif //USE_ROCM
// Cache compiled code
const auto key = get_kernel_cache_key(kernel_label, device_id);
......@@ -228,7 +259,11 @@ bool KernelManager::is_compiled(const std::string& kernel_label, int device_id)
std::string KernelManager::get_kernel_cache_key(const std::string& kernel_label,
int device_id) const {
#ifdef USE_ROCM
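  // Key on the full gfx arch string; HIP devices are not identified by a
  // numeric sm_<arch> compute capability.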
return concat_strings(cuda::sm_arch_name(device_id), ",", kernel_label);
#else
return concat_strings("sm=", cuda::sm_arch(device_id), ",", kernel_label);
#endif
}
} // namespace rtc
......
......@@ -7,10 +7,22 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
#define TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifdef __HIP_PLATFORM_AMD__
#ifndef __HIPCC_RTC__
#include <cstdint>
#else
using namespace __hip_internal;
#endif
#endif
#include <cuda_fp8.h>
#ifdef __HIP_PLATFORM_AMD__
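// Two bf16 values packed as a 2-wide ext vector of 16-bit storage; this
// stands in for nv_bfloat162 on the HIP path.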
typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2)));
#else
#if !defined(__CUDACC_RTC__)
#include <cstdint>
#else
......@@ -25,12 +37,14 @@ static_assert(sizeof(uint32_t) == 4);
static_assert(sizeof(uint64_t) == 8);
#endif
#endif // __HIP_PLATFORM_AMD__
////////////////////////////////////////////////////////////////////////////////////////////////////
constexpr uint32_t THREADS_PER_WARP = 32;
////////////////////////////////////////////////////////////////////////////////////////////////////
#if !defined(USE_HIPBLASLT) && !defined(__HIPCC_RTC__)
inline __device__ float2 operator+(const float2 &a, const float2 &b) { // NOLINT(*)
return {a.x + b.x, a.y + b.y};
}
......@@ -41,6 +55,7 @@ inline __device__ void operator+=(float2 &a, const float2 &b) { // NOLINT(*)
a.x += b.x;
a.y += b.y;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -54,7 +69,11 @@ struct Sum {
template <typename T>
inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) {
#ifdef __HIP_PLATFORM_AMD__
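  // AMD wavefronts are 64 lanes wide; passing THREADS_PER_WARP (32) as the
  // width keeps the shuffle within 32-lane groups, matching CUDA warp
  // semantics.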
return __shfl_xor(x, idx, THREADS_PER_WARP);
#else
return __shfl_xor_sync(static_cast<uint32_t>(-1), x, idx);
#endif
}
template <>
......@@ -64,7 +83,11 @@ inline __device__ float2 warp_shuffle_xor<float2>(const float2 &x, uint32_t idx)
template <typename T>
inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) {
#ifdef __HIP_PLATFORM_AMD__
return __shfl_down(x, idx, THREADS_PER_WARP);
#else
return __shfl_down_sync(static_cast<uint32_t>(-1), x, idx);
#endif
}
template <>
......@@ -154,10 +177,17 @@ struct TypeToVec2<half> {
using Type = half2;
};
#ifdef __HIP_PLATFORM_AMD__
template <>
struct TypeToVec2<__hip_bfloat16> {
using Type = hip_bfloat16x2;
};
#else
template <>
struct TypeToVec2<nv_bfloat16> {
using Type = nv_bfloat162;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -222,6 +252,20 @@ struct Converter<float2, half2> {
static inline __device__ half2 convert(const float2 &x) { return __float22half2_rn(x); }
};
#ifdef __HIP_PLATFORM_AMD__
template <>
struct Converter<float2, hip_bfloat16x2> {
static inline __device__ hip_bfloat16x2 convert(const float2 &x) {
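    // No direct float2 conversion exists for the ext-vector alias, so convert
    // lane by lane and reinterpret the pair through a union.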
union {
hip_bfloat16x2 raw;
__hip_bfloat16 elt[2];
} tmp;
tmp.elt[0] = __hip_bfloat16(x.x);
tmp.elt[1] = __hip_bfloat16(x.y);
return tmp.raw;
}
};
#else
template <>
struct Converter<float2, nv_bfloat162> {
static inline __device__ nv_bfloat162 convert(const float2 &x) {
......@@ -238,6 +282,7 @@ struct Converter<float2, nv_bfloat162> {
#endif
}
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -266,6 +311,12 @@ struct Vec {
};
Alias_type data;
#ifdef __HIP_PLATFORM_AMD__
__HOST_DEVICE__ Vec& operator=(const Vec& rhs) {
data.vec = rhs.data.vec;
return *this;
}
#endif
template <typename S>
inline __device__ void to(Vec<S, NUM_ELT> &other) { // NOLINT(*)
......@@ -346,12 +397,21 @@ struct InterCTASync {
// BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
}
#ifdef __HIP_PLATFORM_AMD__
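  // HIP counterpart of the PTX spin-wait below: a release fetch_add stands in
  // for red.release.gpu.global.add.s32 and an acquire load for
  // ld.global.acquire.gpu.b32; agent scope corresponds to the gpu scope.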
inline __device__ void spin_wait_(int *barrier, int step, int expected) {
__hip_atomic_fetch_add(barrier, step, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
for (int found = -1; found != expected; ) {
found = __hip_atomic_load(barrier, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT);
}
}
#else
inline __device__ void spin_wait_(int *barrier, int step, int expected) {
asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
for (int found = -1; found != expected;) {
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
}
}
#endif
inline __device__ void sync() {
// ALL THREADS MUST ENTER!
......@@ -634,8 +694,13 @@ inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a,
m2_a = m2_ab;
}
// Intra-warp broadcast (only lane 0 has valid stats).
#ifdef __HIP_PLATFORM_AMD__
m_a = __shfl(m_a, 0, THREADS_PER_WARP);
m2_a = __shfl(m2_a, 0, THREADS_PER_WARP);
#else
m_a = __shfl_sync(static_cast<uint32_t>(-1), m_a, 0);
m2_a = __shfl_sync(static_cast<uint32_t>(-1), m2_a, 0);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -811,7 +876,11 @@ __device__ __forceinline__ float warp_reduce_max(const float m) {
float tmp = m;
#pragma unroll
for (int delta = num_elems / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
const float other_m = __shfl_down(tmp, delta, THREADS_PER_WARP);
#else
const float other_m = __shfl_down_sync(0xFFFFFFFF, tmp, delta);
#endif
__builtin_assume(tmp >= 0);
__builtin_assume(other_m >= 0);
tmp = fmaxf(tmp, other_m);
......@@ -823,14 +892,22 @@ __forceinline__ __device__ float warp_reduce_max_broadcast(const float val) {
float val_tmp = val;
#pragma unroll
for (int offset = THREADS_PER_WARP / 2; offset > 0; offset /= 2) {
#ifdef __HIP_PLATFORM_AMD__
const float val_other = __shfl_down(val_tmp, offset, THREADS_PER_WARP);
#else
const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset);
#endif
__builtin_assume(val_tmp >= 0);
__builtin_assume(val_other >= 0);
val_tmp = fmaxf(val_tmp, val_other);
}
// Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id
constexpr int subwarp_lane_zero = 0;
#ifdef __HIP_PLATFORM_AMD__
val_tmp = __shfl(val_tmp, subwarp_lane_zero, THREADS_PER_WARP);
#else
val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero);
#endif
return val_tmp;
}
......@@ -864,14 +941,22 @@ __forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) {
float val_tmp = val;
#pragma unroll
for (int offset = subwarp_width / 2; offset > 0; offset /= 2) {
#ifdef __HIP_PLATFORM_AMD__
const float val_other = __shfl_down(val_tmp, offset, subwarp_width);
#else
const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset, subwarp_width);
#endif
__builtin_assume(val_tmp >= 0);
__builtin_assume(val_other >= 0);
val_tmp = fmaxf(val_tmp, val_other);
}
// Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id
constexpr int subwarp_lane_zero = 0;
#ifdef __HIP_PLATFORM_AMD__
val_tmp = __shfl(val_tmp, subwarp_lane_zero, subwarp_width);
#else
val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero, subwarp_width);
#endif
return val_tmp;
}
......@@ -897,8 +982,13 @@ __device__ __forceinline__ void reciprocal<float>(float *value_inv, const float
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifndef __HIP_PLATFORM_AMD__
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
#else
using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
#endif
using e8m0_t = uint8_t;
constexpr uint32_t FP32_MANTISSA_BITS = 23;
......