Commit 3e806729 authored by wenjh's avatar wenjh
Browse files

Fix swizzle, swap_first_dims and RMSNorm issues


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parent 08f06b7a
...@@ -43,9 +43,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, ...@@ -43,9 +43,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
<< "in fused norm backward+add"; << "in fused norm backward+add";
} }
#ifdef __HIP_PLATFORM_AMD__
if (use_cudnn) {
GTEST_SKIP() << "cuDNN normalizations not supported on pre-Hopper GPUs yet!";
return;
}
#else
if (getDeviceComputeCapability() < hopperComputeCapability && use_cudnn) { if (getDeviceComputeCapability() < hopperComputeCapability && use_cudnn) {
GTEST_SKIP() << "cuDNN normalizations not supported on pre-Hopper GPUs yet!"; GTEST_SKIP() << "cuDNN normalizations not supported on pre-Hopper GPUs yet!";
} }
#endif
using WeightType = InputType; using WeightType = InputType;
DType itype = TypeInfo<InputType>::dtype; DType itype = TypeInfo<InputType>::dtype;
......
...@@ -536,7 +536,7 @@ else() ...@@ -536,7 +536,7 @@ else()
string_code_transpose_rtc_cast_transpose_cu) string_code_transpose_rtc_cast_transpose_cu)
make_string_header_from_file(transpose/rtc/transpose.hip make_string_header_from_file(transpose/rtc/transpose.hip
string_code_transpose_rtc_transpose_cu) string_code_transpose_rtc_transpose_cu)
make_string_header_from_file(transpose/rtc/swap_first_dims.cu make_string_header_from_file(transpose/rtc/swap_first_dims.hip
string_code_transpose_rtc_swap_first_dims_cu) string_code_transpose_rtc_swap_first_dims_cu)
endif() endif()
......
...@@ -21,6 +21,15 @@ namespace { ...@@ -21,6 +21,15 @@ namespace {
constexpr int MXFP8_BLOCK_SIZE = 32; constexpr int MXFP8_BLOCK_SIZE = 32;
constexpr int NVFP4_BLOCK_SIZE = 16; constexpr int NVFP4_BLOCK_SIZE = 16;
#ifdef __HIP_PLATFORM_AMD__
constexpr int TB_DIM = 32;
constexpr int NEW_SF_TILE_DIM_K = 16;
constexpr int N_SF_PER_TD_PER_TILE = 4;
// output is in ~K-major interleaved blocks
constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4;
constexpr int NEW_SF_TILE_DIM_M_I32 = 32;
#else
constexpr __device__ __host__ int TB_DIM = 32; constexpr __device__ __host__ int TB_DIM = 32;
constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16; constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16;
constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4; constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4;
...@@ -28,6 +37,7 @@ constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4; ...@@ -28,6 +37,7 @@ constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4;
// output is in ~K-major interleaved blocks // output is in ~K-major interleaved blocks
constexpr __device__ __host__ int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4; constexpr __device__ __host__ int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4;
constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32; constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32;
#endif
template <typename LType> template <typename LType>
__device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment