".github/vscode:/vscode.git/clone" did not exist on "e8ef4c0820ff6457f32c17e1470fe47976b35e21"
Commit 3e806729 authored by wenjh's avatar wenjh
Browse files

Fix swizzle, swap_first_dims and RMSNorm issues


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parent 08f06b7a
......@@ -42,10 +42,17 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
GTEST_SKIP() << "zero_centered_gamma_in_weight_dtype not currently supported "
<< "in fused norm backward+add";
}
#ifdef __HIP_PLATFORM_AMD__
if (use_cudnn) {
GTEST_SKIP() << "cuDNN normalizations not supported on pre-Hopper GPUs yet!";
return;
}
#else
if (getDeviceComputeCapability() < hopperComputeCapability && use_cudnn) {
GTEST_SKIP() << "cuDNN normalizations not supported on pre-Hopper GPUs yet!";
}
#endif
using WeightType = InputType;
DType itype = TypeInfo<InputType>::dtype;
......
......@@ -536,7 +536,7 @@ else()
string_code_transpose_rtc_cast_transpose_cu)
make_string_header_from_file(transpose/rtc/transpose.hip
string_code_transpose_rtc_transpose_cu)
make_string_header_from_file(transpose/rtc/swap_first_dims.cu
make_string_header_from_file(transpose/rtc/swap_first_dims.hip
string_code_transpose_rtc_swap_first_dims_cu)
endif()
......
......@@ -21,6 +21,15 @@ namespace {
constexpr int MXFP8_BLOCK_SIZE = 32;
constexpr int NVFP4_BLOCK_SIZE = 16;
#ifdef __HIP_PLATFORM_AMD__
constexpr int TB_DIM = 32;
constexpr int NEW_SF_TILE_DIM_K = 16;
constexpr int N_SF_PER_TD_PER_TILE = 4;
// output is in ~K-major interleaved blocks
constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4;
constexpr int NEW_SF_TILE_DIM_M_I32 = 32;
#else
constexpr __device__ __host__ int TB_DIM = 32;
constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16;
constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4;
......@@ -28,6 +37,7 @@ constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4;
// output is in ~K-major interleaved blocks
constexpr __device__ __host__ int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4;
constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32;
#endif
template <typename LType>
__device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment