"vscode:/vscode.git/clone" did not exist on "ea8489fce266d69f2fbe314c1385956b1a342e12"
Commit ab122dac authored by yuguo's avatar yuguo
Browse files

[DCU] compile pass

parent 4c6a5a27
...@@ -58,6 +58,7 @@ def setup_pytorch_extension( ...@@ -58,6 +58,7 @@ def setup_pytorch_extension(
"-U__HIP_NO_BFLOAT16_CONVERSIONS__", "-U__HIP_NO_BFLOAT16_CONVERSIONS__",
"-U__HIP_NO_BFLOAT162_OPERATORS__", "-U__HIP_NO_BFLOAT162_OPERATORS__",
"-U__HIP_NO_BFLOAT162_CONVERSIONS__", "-U__HIP_NO_BFLOAT162_CONVERSIONS__",
"-w",
] ]
else: else:
nvcc_flags = [ nvcc_flags = [
......
{ {
"custom_map" : { "custom_map" : {
"common/util/vectorized_pointwise.h" : "common/util/vectorized_pointwise_hip.h",
"common/common.h" : "common/common_hip.h",
"/userbuffers.h" : "/userbuffers_hip.h",
"/logging.h" : "/logging_hip.h",
"/system.h" : "/system_hip.h",
"<cuda_bf16.h>" : "<hip/hip_bf16.h>", "<cuda_bf16.h>" : "<hip/hip_bf16.h>",
"<cuda_fp8.h>" : "\"amd_detail/hip_float8.h\"", "<cuda_fp8.h>" : "\"amd_detail/hip_float8.h\"",
"CUfunc_cache" : "hipFuncCache_t", "CUfunc_cache" : "hipFuncCache_t",
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Installation script.""" """Installation script."""
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=0 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v
import os import os
import sys import sys
...@@ -43,7 +44,10 @@ elif "jax" in frameworks: ...@@ -43,7 +44,10 @@ elif "jax" in frameworks:
CMakeBuildExtension = get_build_ext(BuildExtension) CMakeBuildExtension = get_build_ext(BuildExtension)
archs = cuda_archs() if rocm_build():
archs = None
else:
archs = cuda_archs()
class TimedBdist(bdist_wheel): class TimedBdist(bdist_wheel):
......
...@@ -226,11 +226,9 @@ else() ...@@ -226,11 +226,9 @@ else()
add_library(transformer_engine SHARED ${te_hip_sources}) add_library(transformer_engine SHARED ${te_hip_sources})
endif() endif()
target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include")
# Configure dependencies # Configure dependencies
if (USE_CUDA) if (USE_CUDA)
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")
# Configure dependencies # Configure dependencies
target_link_libraries(transformer_engine PUBLIC target_link_libraries(transformer_engine PUBLIC
CUDA::cublas CUDA::cublas
...@@ -239,6 +237,7 @@ if (USE_CUDA) ...@@ -239,6 +237,7 @@ if (USE_CUDA)
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
else() else()
target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")
# Aotriton is currently unsupported # Aotriton is currently unsupported
set(AotritonAndCk_fused_attn "unsupported") set(AotritonAndCk_fused_attn "unsupported")
...@@ -343,7 +342,7 @@ else() ...@@ -343,7 +342,7 @@ else()
set(HIP_HCC_FLAGS "${CMAKE_HIP_FLAGS} -mavx2 -mf16c -mfma -std=c++17") set(HIP_HCC_FLAGS "${CMAKE_HIP_FLAGS} -mavx2 -mf16c -mfma -std=c++17")
# Ask hcc to generate device code during compilation so we can use # Ask hcc to generate device code during compilation so we can use
# host linker to link. # host linker to link.
set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted") set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -fno-gpu-rdc -w")
foreach(rocm_arch ${CMAKE_HIP_ARCHITECTURES}) foreach(rocm_arch ${CMAKE_HIP_ARCHITECTURES})
# if CMAKE_CXX_FLAGS has --offload-arch set already, better to rm first # if CMAKE_CXX_FLAGS has --offload-arch set already, better to rm first
set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} --offload-arch=${rocm_arch}") set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} --offload-arch=${rocm_arch}")
......
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
* License for AMD contributions = MIT. See LICENSE for more information * License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/ ************************************************************************/
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
namespace hip_f8_impl { namespace hip_f8_impl {
HIP_HOST_DEVICE inline int clz(uint32_t x) { HIP_HOST_DEVICE inline int clz(uint32_t x) {
...@@ -190,7 +193,7 @@ HIP_HOST_DEVICE ...@@ -190,7 +193,7 @@ HIP_HOST_DEVICE
T cast_from_f8(uint8_t x) { T cast_from_f8(uint8_t x) {
constexpr bool is_half = std::is_same<T,__half>::value; constexpr bool is_half = std::is_same<T,__half>::value;
constexpr bool is_float = std::is_same<T,float>::value; constexpr bool is_float = std::is_same<T,float>::value;
constexpr bool is_bf16 = std::is_same<T,hip_bfloat16>::value; constexpr bool is_bf16 = std::is_same<T,__hip_bfloat16>::value;
static_assert(is_half || is_float, "only half and float are supported"); static_assert(is_half || is_float, "only half and float are supported");
constexpr int weo = is_half ? 5 : 8; constexpr int weo = is_half ? 5 : 8;
......
...@@ -326,7 +326,7 @@ struct hip_f8 { ...@@ -326,7 +326,7 @@ struct hip_f8 {
#endif // #ifdef __gfx942__ #endif // #ifdef __gfx942__
// constructor from hip_bfloat16 // constructor from hip_bfloat16
explicit HIP_HOST_DEVICE hip_f8(hip_bfloat16 v, hip_f8_rounding_mode r=hip_f8_rounding_mode::standard, uint32_t rng=0); explicit HIP_HOST_DEVICE hip_f8(__hip_bfloat16 v, hip_f8_rounding_mode r=hip_f8_rounding_mode::standard, uint32_t rng=0);
// convert to float // convert to float
#ifdef __gfx942__ #ifdef __gfx942__
...@@ -430,7 +430,7 @@ struct hip_f8 { ...@@ -430,7 +430,7 @@ struct hip_f8 {
#endif // #ifdef __gfx942__ #endif // #ifdef __gfx942__
// convert to hip_bfloat16 // convert to hip_bfloat16
explicit inline HIP_HOST_DEVICE operator hip_bfloat16() const; explicit inline HIP_HOST_DEVICE operator __hip_bfloat16() const;
// check for zero // check for zero
inline HIP_HOST_DEVICE bool is_zero() const { inline HIP_HOST_DEVICE bool is_zero() const {
...@@ -504,7 +504,7 @@ struct hip_f8x4 { ...@@ -504,7 +504,7 @@ struct hip_f8x4 {
HIP_HOST_DEVICE hip_f8x4(halfx4 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0); HIP_HOST_DEVICE hip_f8x4(halfx4 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
// constructor from hip_bfloat16 // constructor from hip_bfloat16
HIP_HOST_DEVICE hip_f8x4(hip_bfloat16 v0, hip_bfloat16 v1=hip_bfloat16(0.0f), hip_bfloat16 v2=hip_bfloat16(0.0f), hip_bfloat16 v3=hip_bfloat16(0.0f), hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0); HIP_HOST_DEVICE hip_f8x4(__hip_bfloat16 v0, __hip_bfloat16 v1=__hip_bfloat16(0.0f), __hip_bfloat16 v2=__hip_bfloat16(0.0f), __hip_bfloat16 v3=__hip_bfloat16(0.0f), hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
HIP_HOST_DEVICE hip_f8x4(hip_bfloat16x2 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0); HIP_HOST_DEVICE hip_f8x4(hip_bfloat16x2 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
HIP_HOST_DEVICE hip_f8x4(hip_bfloat16x4 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0); HIP_HOST_DEVICE hip_f8x4(hip_bfloat16x4 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
......
...@@ -12,8 +12,13 @@ ...@@ -12,8 +12,13 @@
#include <numeric> #include <numeric>
#include "common/common.h" #include "common/common.h"
#ifdef USE_ROCM
#include "common/util/hip_driver.h"
#include "common/util/hip_runtime.h"
#else
#include "common/util/cuda_driver.h" #include "common/util/cuda_driver.h"
#include "common/util/cuda_runtime.h" #include "common/util/cuda_runtime.h"
#endif
#include "common/util/logging.h" #include "common/util/logging.h"
#include "common/util/system.h" #include "common/util/system.h"
#include "userbuffers/userbuffers.h" #include "userbuffers/userbuffers.h"
......
...@@ -19,9 +19,15 @@ ...@@ -19,9 +19,15 @@
#include <map> #include <map>
#include <utility> #include <utility>
#ifdef USE_ROCM
#include "common/util/hip_driver.h"
#include "common/util/hip_nvml.h"
#include "common/util/hip_runtime.h"
#else
#include "common/util/cuda_driver.h" #include "common/util/cuda_driver.h"
#include "common/util/cuda_nvml.h" #include "common/util/cuda_nvml.h"
#include "common/util/cuda_runtime.h" #include "common/util/cuda_runtime.h"
#endif
#include "common/util/logging.h" #include "common/util/logging.h"
#include "common/util/system.h" #include "common/util/system.h"
#include "ipcsocket.h" #include "ipcsocket.h"
...@@ -362,7 +368,7 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, ...@@ -362,7 +368,7 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
NVTE_CHECK_CUDA(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE)); NVTE_CHECK_CUDA(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
(*comm)->flags = (*comm)->flags =
#ifdef USE_ROCM #ifdef USE_ROCM
reinterpret_cast<int *>((reinterpret_cast<uintptr_t>((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK reinterpret_cast<int *>((reinterpret_cast<uintptr_t>((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
#else #else
reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
#endif #endif
......
...@@ -10,7 +10,11 @@ ...@@ -10,7 +10,11 @@
#include "./common.h" #include "./common.h"
#include "./utils.cuh" #include "./utils.cuh"
#ifdef __HIP_PLATFORM_AMD__
#include "common/util/hip_runtime.h"
#else
#include "common/util/cuda_runtime.h" #include "common/util/cuda_runtime.h"
#endif
#include "common/util/logging.h" #include "common/util/logging.h"
namespace transformer_engine { namespace transformer_engine {
......
...@@ -25,7 +25,11 @@ ...@@ -25,7 +25,11 @@
#include <vector> #include <vector>
#include "./nvtx.h" #include "./nvtx.h"
#ifdef __HIP_PLATFORM_AMD__
#include "./util/hip_driver.h"
#else
#include "./util/cuda_driver.h" #include "./util/cuda_driver.h"
#endif
#include "./util/logging.h" #include "./util/logging.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -223,7 +227,7 @@ using bf16 = nv_bfloat16; ...@@ -223,7 +227,7 @@ using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3; using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2; using fp8e5m2 = __nv_fp8_e5m2;
#else #else
using bf16 = hip_bfloat16; using bf16 = __hip_bfloat16;
using fp8e4m3 = te_hip_fp8_e4m3; using fp8e4m3 = te_hip_fp8_e4m3;
using fp8e5m2 = te_hip_fp8_e5m2; using fp8e5m2 = te_hip_fp8_e5m2;
#endif #endif
...@@ -247,7 +251,7 @@ TRANSFORMER_ENGINE_TYPE_NAME(int64_t) ...@@ -247,7 +251,7 @@ TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
TRANSFORMER_ENGINE_TYPE_NAME(float) TRANSFORMER_ENGINE_TYPE_NAME(float)
TRANSFORMER_ENGINE_TYPE_NAME(half) TRANSFORMER_ENGINE_TYPE_NAME(half)
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
TRANSFORMER_ENGINE_TYPE_NAME(hip_bfloat16) TRANSFORMER_ENGINE_TYPE_NAME(__hip_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(te_hip_fp8_e4m3) TRANSFORMER_ENGINE_TYPE_NAME(te_hip_fp8_e4m3)
TRANSFORMER_ENGINE_TYPE_NAME(te_hip_fp8_e5m2) TRANSFORMER_ENGINE_TYPE_NAME(te_hip_fp8_e5m2)
#else #else
......
...@@ -22,7 +22,11 @@ ...@@ -22,7 +22,11 @@
#include "../common.h" #include "../common.h"
#include "../util/handle_manager.h" #include "../util/handle_manager.h"
#include "../util/logging.h" #include "../util/logging.h"
#ifdef __HIP_PLATFORM_AMD__
#include "common/util/hip_runtime.h"
#else
#include "common/util/cuda_runtime.h" #include "common/util/cuda_runtime.h"
#endif
#ifndef __HIP_PLATFORM_AMD__ #ifndef __HIP_PLATFORM_AMD__
namespace { namespace {
...@@ -738,7 +742,7 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT ...@@ -738,7 +742,7 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
if(NVTE_BLAS_MULSTREAM==nullptr){ if(NVTE_BLAS_MULSTREAM==nullptr){
NVTE_FORCE_BLASLT_MULSTREAM = true; NVTE_FORCE_BLASLT_MULSTREAM = true;
} elif((NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1') && (NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1')){ } else if((NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1') && (NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1')){
NVTE_ERROR("NVTE_FORCE_BLAS_MULSTREAM and NVTE_FORCE_BLASLT can't be set at the same time."); NVTE_ERROR("NVTE_FORCE_BLAS_MULSTREAM and NVTE_FORCE_BLASLT can't be set at the same time.");
} else{ } else{
NVTE_FORCE_BLASLT_MULSTREAM = false; NVTE_FORCE_BLASLT_MULSTREAM = false;
...@@ -776,8 +780,7 @@ void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B ...@@ -776,8 +780,7 @@ void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm); NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm);
using namespace transformer_engine; using namespace transformer_engine;
static_assert(num_gemms % num_batchgemm_streams == 0, assert(num_gemms % num_batchgemm_streams == 0);
"Need num_gemms mod num_batchgemm_streams == 0.");
static int batch_count = num_gemms / num_batchgemm_streams; static int batch_count = num_gemms / num_batchgemm_streams;
// Inits streams and events (once, globally) // Inits streams and events (once, globally)
std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm); std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm);
......
...@@ -192,15 +192,15 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor ...@@ -192,15 +192,15 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
const size_t sm_count, const size_t sm_count,
const bool zero_centered_gamma, const bool zero_centered_gamma,
const NVTEScalingMode mode, bool training) const NVTEScalingMode mode, bool training)
#ifdef USE_ROCM
{ assert(false);
#else
: _fp8_out(is_fp8_dtype(otype)), : _fp8_out(is_fp8_dtype(otype)),
_zero_centered(zero_centered_gamma), _zero_centered(zero_centered_gamma),
_training(training), _training(training),
_norm_stage(NormStage), _norm_stage(NormStage),
_norm_type(NormType) { _norm_type(NormType) {
#ifdef USE_ROCM
static_assert(false,
"Cudnn backend is not surpported in rocm for normalization yet.");
#else
static_assert(CUDNN_FRONTEND_VERSION >= 10601, static_assert(CUDNN_FRONTEND_VERSION >= 10601,
"CUDNN_FRONTEND_VERSION should be at least 1.6.1!"); "CUDNN_FRONTEND_VERSION should be at least 1.6.1!");
...@@ -389,8 +389,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor ...@@ -389,8 +389,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
void CudnnNormalizationPlan::_build() { void CudnnNormalizationPlan::_build() {
#ifdef USE_ROCM #ifdef USE_ROCM
static_assert(false, assert(false);
"Cudnn backend is not surpported in rocm for normalization yet.");
#else #else
NVTE_CHECK(_graph.validate().is_good()); NVTE_CHECK(_graph.validate().is_good());
NVTE_CHECK(_graph.build_operation_graph(_handle).is_good()); NVTE_CHECK(_graph.build_operation_graph(_handle).is_good());
...@@ -406,8 +405,8 @@ void CudnnNormalizationPlan::_build() { ...@@ -406,8 +405,8 @@ void CudnnNormalizationPlan::_build() {
std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const { std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const {
#ifdef USE_ROCM #ifdef USE_ROCM
static_assert(false, assert(false);
"Cudnn backend is not surpported in rocm for normalization yet."); return {0};
#else #else
return {static_cast<size_t>(_graph.get_workspace_size())}; return {static_cast<size_t>(_graph.get_workspace_size())};
#endif #endif
...@@ -417,8 +416,7 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, ...@@ -417,8 +416,7 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr,
void* mean_dptr, void* eps_dptr, void* rsigma_dptr, void* mean_dptr, void* eps_dptr, void* rsigma_dptr,
void* workspace_dptr, cudaStream_t stream) { void* workspace_dptr, cudaStream_t stream) {
#ifdef USE_ROCM #ifdef USE_ROCM
static_assert(false, assert(false);
"Cudnn backend is not surpported in rocm for normalization yet.");
#else #else
// Binding data pointers to graph tensors // Binding data pointers to graph tensors
_variant_pack = {{_x, x_dptr}, {_eps, eps_dptr}}; _variant_pack = {{_x, x_dptr}, {_eps, eps_dptr}};
...@@ -462,8 +460,7 @@ void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_ ...@@ -462,8 +460,7 @@ void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_
void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
cudaStream_t stream) { cudaStream_t stream) {
#ifdef USE_ROCM #ifdef USE_ROCM
static_assert(false, assert(false);
"Cudnn backend is not surpported in rocm for normalization yet.");
#else #else
// Binding data pointers to graph tensors // Binding data pointers to graph tensors
_variant_pack = { _variant_pack = {
...@@ -519,7 +516,8 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan( ...@@ -519,7 +516,8 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
bool& _cudnn_norm_fwd_flag() { bool& _cudnn_norm_fwd_flag() {
#ifdef USE_ROCM #ifdef USE_ROCM
return false; static bool flag = false;
return flag;
#else #else
static bool flag = transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN"); static bool flag = transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN");
return flag; return flag;
...@@ -528,7 +526,8 @@ bool& _cudnn_norm_fwd_flag() { ...@@ -528,7 +526,8 @@ bool& _cudnn_norm_fwd_flag() {
bool& _cudnn_norm_bwd_flag() { bool& _cudnn_norm_bwd_flag() {
#ifdef USE_ROCM #ifdef USE_ROCM
return false; static bool flag = false;
return flag;
#else #else
static bool flag = transformer_engine::getenv<bool>("NVTE_NORM_BWD_USE_CUDNN"); static bool flag = transformer_engine::getenv<bool>("NVTE_NORM_BWD_USE_CUDNN");
return flag; return flag;
...@@ -544,7 +543,8 @@ bool use_cudnn_norm_bwd() { return _cudnn_norm_bwd_flag(); } ...@@ -544,7 +543,8 @@ bool use_cudnn_norm_bwd() { return _cudnn_norm_bwd_flag(); }
void nvte_enable_cudnn_norm_fwd(bool enable) { void nvte_enable_cudnn_norm_fwd(bool enable) {
NVTE_API_CALL(nvte_enable_cudnn_norm_fwd); NVTE_API_CALL(nvte_enable_cudnn_norm_fwd);
#ifdef USE_ROCM #ifdef USE_ROCM
transformer_engine::normalization::_cudnn_norm_bwd_flag() = false; bool flag = false;
transformer_engine::normalization::_cudnn_norm_bwd_flag() = flag;
#else #else
transformer_engine::normalization::_cudnn_norm_fwd_flag() = enable; transformer_engine::normalization::_cudnn_norm_fwd_flag() = enable;
#endif #endif
...@@ -553,7 +553,8 @@ void nvte_enable_cudnn_norm_fwd(bool enable) { ...@@ -553,7 +553,8 @@ void nvte_enable_cudnn_norm_fwd(bool enable) {
void nvte_enable_cudnn_norm_bwd(bool enable) { void nvte_enable_cudnn_norm_bwd(bool enable) {
NVTE_API_CALL(nvte_enable_cudnn_norm_bwd); NVTE_API_CALL(nvte_enable_cudnn_norm_bwd);
#ifdef USE_ROCM #ifdef USE_ROCM
transformer_engine::normalization::_cudnn_norm_bwd_flag() = false; bool flag = false;
transformer_engine::normalization::_cudnn_norm_bwd_flag() = flag;
#else #else
transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable; transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable;
#endif #endif
......
...@@ -30,7 +30,9 @@ namespace transformer_engine { ...@@ -30,7 +30,9 @@ namespace transformer_engine {
namespace normalization { namespace normalization {
#ifndef __HIP_PLATFORM_AMD__
namespace fe = cudnn_frontend; namespace fe = cudnn_frontend;
#endif
template <typename KernelParamsType> template <typename KernelParamsType>
struct LaunchParams { struct LaunchParams {
...@@ -277,14 +279,14 @@ class CudnnNormalizationPlan : public NormalizationPlanBase { ...@@ -277,14 +279,14 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
private: private:
void _build() override; void _build() override;
#ifndef __HIP_PLATFORM_AMD__
const bool _zero_centered, _fp8_out; const bool _zero_centered, _fp8_out;
int _ndim_scale_block; int _ndim_scale_block;
const NVTE_Norm_Stage _norm_stage; const NVTE_Norm_Stage _norm_stage;
const NVTE_Norm_Type _norm_type; const NVTE_Norm_Type _norm_type;
std::unique_ptr<char[]> _scalar_dptr; std::unique_ptr<char[]> _scalar_dptr;
std::unique_ptr<float> _one_dptr = std::make_unique<float>(1.0f); std::unique_ptr<float> _one_dptr = std::make_unique<float>(1.0f);
#ifndef __HIP_PLATFORM_AMD__
// FWD // FWD
std::shared_ptr<fe::graph::Tensor_attributes> _x, _gamma_zero, _scalar_offset, _gamma, _beta, std::shared_ptr<fe::graph::Tensor_attributes> _x, _gamma_zero, _scalar_offset, _gamma, _beta,
_eps, _mean, _rsigma, _z, _z_scale, _one_for_div, _z_scale_inv, _amax, _z_fp8; _eps, _mean, _rsigma, _z, _z_scale, _one_for_div, _z_scale_inv, _amax, _z_fp8;
......
...@@ -43,6 +43,11 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params, ...@@ -43,6 +43,11 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES)); Kernel_traits::SMEM_BYTES));
} }
#else
if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute((const void *)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES));
}
#endif #endif
auto stream = launch_params.stream; auto stream = launch_params.stream;
......
...@@ -39,6 +39,11 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params, ...@@ -39,6 +39,11 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES_FWD)); Kernel_traits::SMEM_BYTES_FWD));
} }
#else
if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute((const void *)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES_FWD));
}
#endif #endif
auto stream = launch_params.stream; auto stream = launch_params.stream;
......
...@@ -42,6 +42,11 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params, ...@@ -42,6 +42,11 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES)); Kernel_traits::SMEM_BYTES));
} }
#else
if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute((const void *)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES));
}
#endif #endif
auto stream = launch_params.stream; auto stream = launch_params.stream;
......
...@@ -40,6 +40,11 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params, ...@@ -40,6 +40,11 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES_FWD)); Kernel_traits::SMEM_BYTES_FWD));
} }
#else
if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute((const void *)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES_FWD));
}
#endif #endif
auto stream = launch_params.stream; auto stream = launch_params.stream;
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
using __nv_fp8_e4m3 = hip_f8<hip_f8_type::fp8>; using __nv_fp8_e4m3 = hip_f8<hip_f8_type::fp8>;
using __nv_fp8_e5m2 = hip_f8<hip_f8_type::bf8>; using __nv_fp8_e5m2 = hip_f8<hip_f8_type::bf8>;
#define __ldlu(x) __ldg(x)
#endif #endif
static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map, static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map,
...@@ -214,7 +215,11 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac ...@@ -214,7 +215,11 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac
if (k == topK) break; if (k == topK) break;
// Warp-level reduction // Warp-level reduction
for (int mask = 16; mask > 0; mask /= 2) { for (int mask = 16; mask > 0; mask /= 2) {
#ifdef __HIP_PLATFORM_AMD__
accum[k] = accum[k] + __shfl_xor(accum[k], mask, 32);
#else
accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32); accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32);
#endif
} }
} }
......
...@@ -15,6 +15,10 @@ ...@@ -15,6 +15,10 @@
#include "../util/vectorized_pointwise.h" #include "../util/vectorized_pointwise.h"
#include "recipe_common.cuh" #include "recipe_common.cuh"
#ifdef __HIP_PLATFORM_AMD__
using __nv_bfloat16 = __hip_bfloat16;
#endif
namespace transformer_engine { namespace transformer_engine {
namespace { namespace {
......
...@@ -11,7 +11,11 @@ ...@@ -11,7 +11,11 @@
#include <string> #include <string>
#include "../common.h" #include "../common.h"
#ifdef __HIP_PLATFORM_AMD__
#include "../util/hip_runtime.h"
#else
#include "../util/cuda_runtime.h" #include "../util/cuda_runtime.h"
#endif
#include "../util/logging.h" #include "../util/logging.h"
namespace transformer_engine { namespace transformer_engine {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment