/*
 * Copyright (C) 2024-2025, The vLLM team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// NOTE(review): in this copy the `<...>` spans appear to have been stripped
// (bare `#include` directives, `template struct` without a parameter list).
// The tokens below are left exactly as found; restore the header names and
// template parameter lists from version control before building.
#include #include #include #include "dispatch_utils.h" #include #include #include #include

// HIP spellings for the CUDA bfloat16 type names used later in this file.
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;

namespace vllm {

// Eight scalar_t lanes packed into one 16-byte-aligned aggregate.
// rms_norm_kernel reinterprets its input/weight/output rows as arrays of this
// type so that each thread moves eight elements per load/store.
// NOTE(review): the template parameter list was stripped here; presumably
// `template <typename scalar_t>` -- confirm against upstream.
template struct __align__(16) vec8_t {
  scalar_t x, y, z, w, u, v, s, t;

  // Zero-initialises all eight lanes.
  __device__ vec8_t() : x(0), y(0), z(0), w(0), u(0), v(0), s(0), t(0) {}

  // Builds a vector from eight explicit lane values.
  __device__ vec8_t(scalar_t x, scalar_t y, scalar_t z, scalar_t w, scalar_t u,
                    scalar_t v, scalar_t s, scalar_t t)
      : x(x), y(y), z(z), w(w), u(u), v(v), s(s), t(t) {}

  // Lane-wise product of two vectors.
  __device__ vec8_t operator*(const vec8_t &other) const {
    return vec8_t(x * other.x, y * other.y, z * other.z, w * other.w,
                  u * other.u, v * other.v, s * other.s, t * other.t);
  }

  // Multiplies every lane by a float scale. Each `lane * scale` result is
  // passed to the scalar_t constructor above, i.e. converted back to scalar_t.
  __device__ vec8_t operator*(const float &scale) const {
    return vec8_t(x * scale, y * scale, z * scale, w * scale, u * scale,
                  v * scale, s * scale, t * scale);
  }

  // Lane-wise sum of two vectors.
  __device__ vec8_t operator+(const vec8_t &other) const {
    return vec8_t(x + other.x, y + other.y, z + other.z, w + other.w,
                  u + other.u, v + other.v, s + other.s, t + other.t);
  }

  // In-place lane-wise accumulation.
  __device__ void operator+=(const vec8_t &other) {
    x += other.x;
    y += other.y;
    z += other.z;
    w += other.w;
    u += other.u;
    v += other.v;
    s += other.s;
    t += other.t;
  }

  // Horizontal sum of the eight lanes, computed in scalar_t arithmetic
  // (low precision when scalar_t is a 16-bit floating-point type).
  __device__ scalar_t sum() const { return x + y + z + w + u + v + s + t; }
};

// TODO(woosuk): Further optimize this kernel.
// Sum of squares of the eight lanes of `v`, accumulated in fp32.
// Squaring/summing in scalar_t (vec8_t::operator* / vec8_t::sum) overflows for
// fp16 as soon as |x| >= 256 (256^2 > 65504) and throws away most of the
// mantissa for bf16, so the variance is always built in float.
template <typename scalar_t>
__device__ __forceinline__ float vec8_sum_squares(const vec8_t<scalar_t> &v) {
  const float a = static_cast<float>(v.x);
  const float b = static_cast<float>(v.y);
  const float c = static_cast<float>(v.z);
  const float d = static_cast<float>(v.w);
  const float e = static_cast<float>(v.u);
  const float f = static_cast<float>(v.v);
  const float g = static_cast<float>(v.s);
  const float h = static_cast<float>(v.t);
  return a * a + b * b + c * c + d * d + e * e + f * f + g * g + h * h;
}

// RMSNorm over the last dimension:
//   out[t, i] = (scalar_t)(input[t, i] * rsqrt(mean(input[t, :]^2) + eps))
//               * weight[i]
// Launch layout: one block per token row (row index = blockIdx.x), with
// blockDim.x <= 1024 (the BlockReduce capacity below). No dynamic shared
// memory is needed.
// When hidden_size is a multiple of 8 and out/input/weight are aligned for
// vec8_t, each thread handles eight elements per access; otherwise a scalar
// path is taken, so any hidden_size and any pointer alignment is correct.
// `num_tokens` is not read; it is kept for interface compatibility.
template <typename scalar_t>
__global__ void rms_norm_kernel(
    scalar_t *__restrict__ out,           // [..., hidden_size]
    const scalar_t *__restrict__ input,   // [..., hidden_size]
    const scalar_t *__restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  using vec_t = vec8_t<scalar_t>;
  constexpr int kVecWidth = 8;
  __shared__ float s_variance;

  // Depends only on kernel arguments, so it is uniform across the block and
  // the branches below never diverge around the reduction/barrier.
  const bool use_vec =
      hidden_size % kVecWidth == 0 &&
      reinterpret_cast<std::uintptr_t>(out) % alignof(vec_t) == 0 &&
      reinterpret_cast<std::uintptr_t>(input) % alignof(vec_t) == 0 &&
      reinterpret_cast<std::uintptr_t>(weight) % alignof(vec_t) == 0;
  const int vec_hidden_size = hidden_size / kVecWidth;

  // Pass 1: per-thread partial sum of squares (fp32).
  float variance = 0.0f;
  if (use_vec) {
    const vec_t *in_row =
        reinterpret_cast<const vec_t *>(input) + blockIdx.x * vec_hidden_size;
    for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
      variance += vec8_sum_squares(in_row[idx]);
    }
  } else {
    const scalar_t *in_row = input + blockIdx.x * hidden_size;
    for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
      const float x = static_cast<float>(in_row[idx]);
      variance += x * x;
    }
  }

  using BlockReduce = hipcub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance =
      BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);

  // Only thread 0 holds the reduced value; publish 1/rms through shared memory.
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: normalise and apply the weight.
  if (use_vec) {
    const vec_t *in_row =
        reinterpret_cast<const vec_t *>(input) + blockIdx.x * vec_hidden_size;
    const vec_t *w_vec = reinterpret_cast<const vec_t *>(weight);
    vec_t *out_row =
        reinterpret_cast<vec_t *>(out) + blockIdx.x * vec_hidden_size;
    for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
      out_row[idx] = in_row[idx] * s_variance * w_vec[idx];
    }
  } else {
    const scalar_t *in_row = input + blockIdx.x * hidden_size;
    scalar_t *out_row = out + blockIdx.x * hidden_size;
    for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
      const float x = static_cast<float>(in_row[idx]);
      out_row[idx] = static_cast<scalar_t>(x * s_variance) * weight[idx];
    }
  }
}

// template
// __global__ void scaled_rms_norm_kernel(
//     c10::Float8_e4m3fnuz* __restrict__ out,  // [..., hidden_size]
//     const scalar_t* __restrict__ input,      // [..., hidden_size]
//     const scalar_t* __restrict__ weight,     // [hidden_size]
//     const float scale, const float epsilon, const int num_tokens,
//     const int hidden_size) {
//   __shared__ float s_variance;
//   float variance = 0.0f;
//
//   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
//     const float x = (float)input[blockIdx.x * hidden_size + idx];
//     variance += x * x;
//   }
//
//   using BlockReduce = hipcub::BlockReduce;
//   __shared__ typename BlockReduce::TempStorage reduceStore;
//   variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
//
//   if (threadIdx.x == 0) {
//     s_variance = rsqrtf(variance / hidden_size + epsilon);
//   }
//   __syncthreads();
//
//   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
//     float x = (float)input[blockIdx.x * hidden_size + idx];
//     float r = (x * s_variance) * weight[idx] * scale;
//     out[blockIdx.x * hidden_size + idx] = c10::Float8_e4m3fnuz(
//         hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
//   }
// }

/* Converter structs for the conversion from torch types to HIP/CUDA types,
   and the associated type conversions within HIP/CUDA. These helpers need
   to be implemented for now because the relevant type conversion
   operators/constructors are not consistently implemented by HIP/CUDA, so
   a generic conversion via type casts cannot be implemented.

   Each struct should have the member static constexpr bool `exists`:
   If false, the optimized kernel is not used for the corresponding torch type.
   If true, the struct should be fully defined as shown in the examples below.
*/
template <typename torch_type>
struct _typeConvert {
  // No HIP/CUDA mapping registered: the packed/vectorized kernels are not
  // selected for this torch type.
  static constexpr bool exists = false;
};

#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
  static constexpr bool exists = true;
  using hip_type = __half;
  using packed_hip_type = __half2;

  // Widening: half -> float and half2 -> float2.
  __device__ static inline float convert(hip_type value) {
    return __half2float(value);
  }
  __device__ static inline float2 convert(packed_hip_type pair) {
    return __half22float2(pair);
  }

  // Narrowing with the round-to-nearest (`_rn`) intrinsics.
  __device__ static inline hip_type convert(float value) {
    return __float2half_rn(value);
  }
  __device__ static inline packed_hip_type convert(float2 pair) {
    return __float22half2_rn(pair);
  }
};

  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
// NOTE(review): presumably __CUDA_ARCH__ is never defined by hipcc, so on
// ROCm this specialization is compiled out and BF16 uses the generic kernels.
template <>
struct _typeConvert<c10::BFloat16> {
  static constexpr bool exists = true;
  using hip_type = __nv_bfloat16;
  using packed_hip_type = __nv_bfloat162;

  // Widening: bf16 -> float and bf16x2 -> float2.
  __device__ static inline float convert(hip_type value) {
    return __bfloat162float(value);
  }
  __device__ static inline float2 convert(packed_hip_type pair) {
    return __bfloat1622float2(pair);
  }

  // Narrowing back to bf16 / bf16x2.
  __device__ static inline hip_type convert(float value) {
    return __float2bfloat16(value);
  }
  __device__ static inline packed_hip_type convert(float2 pair) {
    return __float22bfloat162_rn(pair);
  }
};
  #endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
        // 12000))

/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
   for appropriate specializations of fused_add_rms_norm_kernel.
   Only functions that are necessary in that kernel are implemented.
   Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
  /* Not theoretically necessary that width is a power of 2 but should
     almost always be the case for optimization purposes */
  static_assert(width > 0 && (width & (width - 1)) == 0,
                "Width is not a positive power of 2!");
  using Converter = _typeConvert<scalar_t>;
  using T1 = typename Converter::hip_type;
  using T2 = typename Converter::packed_hip_type;
  T1 data[width];

  // Lane-wise addition. Even widths work on adjacent lane pairs through the
  // packed T2 type; otherwise each lane is added individually.
  __device__ _f16Vec &operator+=(const _f16Vec &other) {
    if constexpr (width % 2 != 0) {
#pragma unroll
      for (int lane = 0; lane < width; ++lane) data[lane] += other.data[lane];
    } else {
#pragma unroll
      for (int lane = 0; lane < width; lane += 2) {
        T2 packed{data[lane], data[lane + 1]};
        packed += T2{other.data[lane], other.data[lane + 1]};
        data[lane] = packed.x;
        data[lane + 1] = packed.y;
      }
    }
    return *this;
  }

  // Lane-wise multiplication by another vector (packed for even widths).
  __device__ _f16Vec &operator*=(const _f16Vec &other) {
    if constexpr (width % 2 != 0) {
#pragma unroll
      for (int lane = 0; lane < width; ++lane) data[lane] *= other.data[lane];
    } else {
#pragma unroll
      for (int lane = 0; lane < width; lane += 2) {
        T2 packed{data[lane], data[lane + 1]};
        packed *= T2{other.data[lane], other.data[lane + 1]};
        data[lane] = packed.x;
        data[lane + 1] = packed.y;
      }
    }
    return *this;
  }

  // Scales every lane by a float: widen to fp32, multiply, narrow back.
  __device__ _f16Vec &operator*=(const float scale) {
    if constexpr (width % 2 != 0) {
#pragma unroll
      for (int lane = 0; lane < width; ++lane) {
        const float scaled = Converter::convert(data[lane]) * scale;
        data[lane] = Converter::convert(scaled);
      }
    } else {
#pragma unroll
      for (int lane = 0; lane < width; lane += 2) {
        float2 wide = Converter::convert(T2{data[lane], data[lane + 1]});
        wide.x *= scale;
        wide.y *= scale;
        const T2 narrow = Converter::convert(wide);
        data[lane] = narrow.x;
        data[lane + 1] = narrow.y;
      }
    }
    return *this;
  }

  // Sum of squares of all lanes, accumulated in fp32.
  __device__ float sum_squares() const {
    float total = 0.0f;
    if constexpr (width % 2 != 0) {
#pragma unroll
      for (int lane = 0; lane < width; ++lane) {
        const float v = Converter::convert(data[lane]);
        total += v * v;
      }
    } else {
#pragma unroll
      for (int lane = 0; lane < width; lane += 2) {
        const float2 v = Converter::convert(T2{data[lane], data[lane + 1]});
        total += v.x * v.x + v.y * v.y;
      }
    }
    return total;
  }
};

/* Function specialization in
the case of FP16/BF16 tensors. Additional optimizations we can make in this case are packed and vectorized operations, which help with the memory latency bottleneck. */ template __global__ std::enable_if_t<(width > 0) && _typeConvert::exists> fused_add_rms_norm_kernel( scalar_t *__restrict__ input, // [..., hidden_size] scalar_t *__restrict__ residual, // [..., hidden_size] const scalar_t *__restrict__ weight, // [hidden_size] const float epsilon, const int num_tokens, const int hidden_size) { // Sanity checks on our vector struct and type-punned pointer arithmetic static_assert(std::is_pod_v<_f16Vec>); static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); const int vec_hidden_size = hidden_size / width; __shared__ float s_variance; float variance = 0.0f; /* These and the argument pointers are all declared `restrict` as they are not aliased in practice. Argument pointers should not be dereferenced in this kernel as that would be undefined behavior */ auto *__restrict__ input_v = reinterpret_cast<_f16Vec *>(input); auto *__restrict__ residual_v = reinterpret_cast<_f16Vec *>(residual); auto *__restrict__ weight_v = reinterpret_cast *>(weight); for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { int id = blockIdx.x * vec_hidden_size + idx; _f16Vec temp = input_v[id]; temp += residual_v[id]; variance += temp.sum_squares(); residual_v[id] = temp; } using BlockReduce = hipcub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } __syncthreads(); for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { int id = blockIdx.x * vec_hidden_size + idx; _f16Vec temp = residual_v[id]; temp *= s_variance; temp *= weight_v[idx]; input_v[id] = temp; } } /* Generic fused_add_rms_norm_kernel The width field is not used here but necessary for other specializations. 
*/
// Scalar fallback: residual <- input + residual (in scalar_t), then
// input <- (scalar_t)(residual * inv_rms) * weight. One block per token row.
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t *__restrict__ input,         // [..., hidden_size]
    scalar_t *__restrict__ residual,      // [..., hidden_size]
    const scalar_t *__restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_inv_rms;
  // Start of this block's row, as the same unsigned expression the element
  // indices are built from.
  const unsigned row_base = blockIdx.x * hidden_size;

  float sum_sq = 0.0f;
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    scalar_t z = input[row_base + idx];
    z += residual[row_base + idx];
    const float zf = (float)z;
    sum_sq += zf * zf;
    residual[row_base + idx] = z;
  }

  using BlockReduce = hipcub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduce_storage;
  sum_sq =
      BlockReduce(reduce_storage).Reduce(sum_sq, hipcub::Sum{}, blockDim.x);
  if (threadIdx.x == 0) {
    s_inv_rms = rsqrtf(sum_sq / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float)residual[row_base + idx];
    input[row_base + idx] = ((scalar_t)(x * s_inv_rms)) * weight[idx];
  }
}

// Out-of-place, vectorized variant: residual_out <- input + residual_in and
// out <- residual_out * inv_rms * weight. Same launch layout and alignment
// requirements as the in-place vectorized kernel.
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_out_kernel(
    scalar_t *__restrict__ out,                // [..., hidden_size]
    const scalar_t *__restrict__ input,        // [..., hidden_size]
    const scalar_t *__restrict__ residual_in,  // [..., hidden_size]
    scalar_t *__restrict__ residual_out,       // [..., hidden_size]
    const scalar_t *__restrict__ weight,       // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  using vec_t = _f16Vec<scalar_t, width>;
  static_assert(std::is_pod_v<vec_t>);
  static_assert(sizeof(vec_t) == sizeof(scalar_t) * width);

  const int vec_hidden_size = hidden_size / width;
  __shared__ float s_inv_rms;

  auto *__restrict__ out_v = reinterpret_cast<vec_t *>(out);
  auto *__restrict__ residual_out_v = reinterpret_cast<vec_t *>(residual_out);
  auto *__restrict__ input_v = reinterpret_cast<const vec_t *>(input);
  auto *__restrict__ residual_in_v =
      reinterpret_cast<const vec_t *>(residual_in);
  auto *__restrict__ weight_v = reinterpret_cast<const vec_t *>(weight);

  float sum_sq = 0.0f;
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int id = blockIdx.x * vec_hidden_size + idx;
    vec_t acc = input_v[id];
    acc += residual_in_v[id];
    sum_sq += acc.sum_squares();
    residual_out_v[id] = acc;
  }

  using BlockReduce = hipcub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduce_storage;
  sum_sq =
      BlockReduce(reduce_storage).Reduce(sum_sq, hipcub::Sum{}, blockDim.x);
  if (threadIdx.x == 0) {
    s_inv_rms = rsqrtf(sum_sq / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int id = blockIdx.x * vec_hidden_size + idx;
    vec_t acc = residual_out_v[id];
    acc *= s_inv_rms;
    acc *= weight_v[idx];
    out_v[id] = acc;
  }
}

// Out-of-place scalar fallback of fused_add_rms_norm_out_kernel.
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_out_kernel(
    scalar_t *__restrict__ out,                // [..., hidden_size]
    const scalar_t *__restrict__ input,        // [..., hidden_size]
    const scalar_t *__restrict__ residual_in,  // [..., hidden_size]
    scalar_t *__restrict__ residual_out,       // [..., hidden_size]
    const scalar_t *__restrict__ weight,       // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_inv_rms;
  const unsigned row_base = blockIdx.x * hidden_size;

  float sum_sq = 0.0f;
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    scalar_t z = input[row_base + idx];
    z += residual_in[row_base + idx];
    const float zf = (float)z;
    sum_sq += zf * zf;
    residual_out[row_base + idx] = z;
  }

  using BlockReduce = hipcub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduce_storage;
  sum_sq =
      BlockReduce(reduce_storage).Reduce(sum_sq, hipcub::Sum{}, blockDim.x);
  if (threadIdx.x == 0) {
    s_inv_rms = rsqrtf(sum_sq / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float)residual_out[row_base + idx];
    out[row_base + idx] = ((scalar_t)(x * s_inv_rms)) * weight[idx];
  }
}

/* Function specialization in the case of FP16/BF16 tensors.
   Additional optimizations we can make in this case are
   packed and vectorized operations, which help with the
   memory latency bottleneck.
   (Describes the commented-out scaled_fused_add_rms_norm_kernel below.) */

// template <>
// struct Vec {
//   using Type = uint2;
// };

// template <>
// struct Vec {
//   using Type = uint4;
// };

// template <>
// struct Vec {
//   using Type = bf16_8_t;
// };

// template
// __global__ std::enable_if_t<(width > 0) && _typeConvert::exists>
// scaled_fused_add_rms_norm_kernel(
//     c10::Float8_e4m3fnuz* __restrict__ out,  // [..., hidden_size]
//     scalar_t* __restrict__ input,            // [..., hidden_size]
//     scalar_t* __restrict__ residual,         // [..., hidden_size]
//     const scalar_t* __restrict__ weight,     // [hidden_size]
//     const float epsilon, const float scale, const int num_tokens,
//     const int hidden_size) {
//   using in_v_t = typename Vec::Type;
//   using out_v_t = typename Vec::Type;
//   // Sanity checks on our vector struct and type-punned pointer arithmetic
//   static_assert(std::is_pod_v<_f16Vec>);
//   static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width);
//
//   const int vec_hidden_size = hidden_size / width;
//   __shared__ float s_variance;
//   float variance = 0.0f;
//   /* These and the argument pointers are all declared `restrict` as they are
//      not aliased in practice. Argument pointers should not be dereferenced
//      in this kernel as that would be undefined behavior */
//   auto* __restrict__ out_v = reinterpret_cast(out);
//   auto* __restrict__ input_v =
//       reinterpret_cast<_f16Vec*>(input);
//   auto* __restrict__ residual_v =
//       reinterpret_cast<_f16Vec*>(residual);
//   auto* __restrict__ weight_v =
//       reinterpret_cast*>(weight);
//
//   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
//     int id = blockIdx.x * vec_hidden_size + idx;
//     _f16Vec temp = input_v[id];
//     temp += residual_v[id];
//     variance += temp.sum_squares();
//     residual_v[id] = temp;
//   }
//
//   using BlockReduce = hipcub::BlockReduce;
//   __shared__ typename BlockReduce::TempStorage reduceStore;
//   variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
//
//   if (threadIdx.x == 0) {
//     s_variance = rsqrtf(variance / hidden_size + epsilon);
//   }
//   __syncthreads();
//
//   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
//     int id = blockIdx.x * vec_hidden_size + idx;
//     _f16Vec temp = residual_v[id];
//     temp *= s_variance;
//     temp *= weight_v[idx];
//     out_v_t temp_quant = fp8::scaled_vec_conversion(
//         *reinterpret_cast(&temp), scale);
//     out_v[id] = temp_quant;
//   }
// }

/* Generic scaled_fused_add_rms_norm_kernel
   The width field is not used here but necessary for other specializations.
*/
// template
// __global__ std::enable_if_t<(width == 0) || !_typeConvert::exists>
// scaled_fused_add_rms_norm_kernel(
//     c10::Float8_e4m3fnuz* __restrict__ out,  // [..., hidden_size]
//     scalar_t* __restrict__ input,            // [..., hidden_size]
//     scalar_t* __restrict__ residual,         // [..., hidden_size]
//     const scalar_t* __restrict__ weight,     // [hidden_size]
//     const float epsilon, const float scale, const int num_tokens,
//     const int hidden_size) {
//   __shared__ float s_variance;
//   float variance = 0.0f;
//
//   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
//     scalar_t z = input[blockIdx.x * hidden_size + idx];
//     z += residual[blockIdx.x * hidden_size + idx];
//     float x = (float)z;
//     variance += x * x;
//     residual[blockIdx.x * hidden_size + idx] = z;
//   }
//
//   using BlockReduce = hipcub::BlockReduce;
//   __shared__ typename BlockReduce::TempStorage reduceStore;
//   variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
//
//   if (threadIdx.x == 0) {
//     s_variance = rsqrtf(variance / hidden_size + epsilon);
//   }
//   __syncthreads();
//
//   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
//     float x = (float)residual[blockIdx.x * hidden_size + idx];
//     float r = (x * s_variance) * (float)weight[idx] / scale;
//     out[blockIdx.x * hidden_size + idx] = c10::Float8_e4m3fnuz(
//         hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
//   }
// }

}  // namespace vllm

// True when `ptr` lies on a 16-byte boundary, which the width-8 vectorized
// kernels need for 128-bit global memory accesses.
static inline bool is_16_byte_aligned(const void *ptr) {
  return reinterpret_cast<std::uintptr_t>(ptr) % 16 == 0;
}

// out = RMSNorm(input) * weight over the last dimension.
// One block per token row, launched on the current HIP stream of input's
// device. An empty input is a no-op.
void rms_norm(torch::Tensor &out,     // [..., hidden_size]
              torch::Tensor &input,   // [..., hidden_size]
              torch::Tensor &weight,  // [hidden_size]
              double epsilon) {
  int hidden_size = input.size(-1);
  // Nothing to normalise. Returning here also avoids dividing by a zero
  // hidden_size below and launching a zero-sized grid, which is an invalid
  // launch configuration.
  if (input.numel() == 0) {
    return;
  }
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
  const hipStream_t stream = at::hip::getCurrentHIPStream();
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
  });
}

// void scaled_rms_norm(torch::Tensor& out,     // [..., hidden_size]
//                      torch::Tensor& input,   // [..., hidden_size]
//                      torch::Tensor& weight,  // [hidden_size]
//                      torch::Tensor& scale, double epsilon) {
//   int hidden_size = input.size(-1);
//   int num_tokens = input.numel() / hidden_size;
//
//   dim3 grid(num_tokens);
//   dim3 block(std::min(hidden_size, 1024));
//   const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
//   const hipStream_t stream = at::hip::getCurrentHIPStream();
//   VLLM_DISPATCH_FLOATING_TYPES(
//       input.scalar_type(), "scaled_rms_norm_kernel", [&] {
//         vllm::scaled_rms_norm_kernel<<>>(
//             out.data_ptr(), input.data_ptr(),
//             weight.data_ptr(), 1.0 / (*scale.data_ptr()),
//             epsilon, num_tokens, hidden_size);
//       });
// }

#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \
  VLLM_DISPATCH_FLOATING_TYPES(                                                \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
        vllm::fused_add_rms_norm_kernel<scalar_t, width>                       \
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),           \
                                         residual.data_ptr<scalar_t>(),        \
                                         weight.data_ptr<scalar_t>(), epsilon, \
                                         num_tokens, hidden_size);             \
      });

#define LAUNCH_FUSED_ADD_RMS_NORM_OUT(width)                                   \
  VLLM_DISPATCH_FLOATING_TYPES(                                                \
      input.scalar_type(), "fused_add_rms_norm_out_kernel", [&] {              \
        vllm::fused_add_rms_norm_out_kernel<scalar_t, width>                   \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),             \
                                         input.data_ptr<scalar_t>(),           \
                                         residual_in.data_ptr<scalar_t>(),     \
                                         residual_out.data_ptr<scalar_t>(),    \
                                         weight.data_ptr<scalar_t>(),          \
                                         epsilon, num_tokens, hidden_size);    \
      });

// In place: residual <- input + residual, then
// input <- RMSNorm(residual) * weight. An empty input is a no-op.
void fused_add_rms_norm(torch::Tensor &input,     // [..., hidden_size]
                        torch::Tensor &residual,  // [..., hidden_size]
                        torch::Tensor &weight,    // [hidden_size]
                        double epsilon) {
  int hidden_size = input.size(-1);
  // Avoid dividing by a zero hidden_size and launching a zero-sized grid.
  if (input.numel() == 0) {
    return;
  }
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
  const hipStream_t stream = at::hip::getCurrentHIPStream();
  /* If the tensor types are FP16/BF16, try to use the optimized kernel
     with packed + vectorized ops.
     Max optimization is achieved with a width-8 vector of FP16/BF16s
     since we can load at most 128 bits at once in a global memory op.
     However, this requires each tensor's data to be aligned to 16
     bytes.
   */
  const bool ptrs_are_aligned = is_16_byte_aligned(input.data_ptr()) &&
                                is_16_byte_aligned(residual.data_ptr()) &&
                                is_16_byte_aligned(weight.data_ptr());
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_FUSED_ADD_RMS_NORM(8);
  } else {
    LAUNCH_FUSED_ADD_RMS_NORM(0);
  }
}

// Out-of-place variant: residual_out <- input + residual_in, then
// out <- RMSNorm(residual_out) * weight. An empty input is a no-op.
void fused_add_rms_norm_out(torch::Tensor &out,           // [..., hidden_size]
                            torch::Tensor &input,         // [..., hidden_size]
                            torch::Tensor &residual_in,   // [..., hidden_size]
                            torch::Tensor &residual_out,  // [..., hidden_size]
                            torch::Tensor &weight,        // [hidden_size]
                            double epsilon) {
  int hidden_size = input.size(-1);
  // Avoid dividing by a zero hidden_size and launching a zero-sized grid.
  if (input.numel() == 0) {
    return;
  }
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  // Same occupancy heuristic as fused_add_rms_norm.
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
  const hipStream_t stream = at::hip::getCurrentHIPStream();

  // The width-8 kernel needs all five buffers 16-byte aligned.
  const bool ptrs_are_aligned = is_16_byte_aligned(out.data_ptr()) &&
                                is_16_byte_aligned(input.data_ptr()) &&
                                is_16_byte_aligned(residual_in.data_ptr()) &&
                                is_16_byte_aligned(residual_out.data_ptr()) &&
                                is_16_byte_aligned(weight.data_ptr());
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_FUSED_ADD_RMS_NORM_OUT(8);
  } else {
    LAUNCH_FUSED_ADD_RMS_NORM_OUT(0);
  }
}

// #define LAUNCH_SCALED_FUSED_ADD_RMS_NORM(width)                        \
//   VLLM_DISPATCH_FLOATING_TYPES(                                        \
//       input.scalar_type(), "scaled_fused_add_rms_norm_kernel", [&] {   \
//         vllm::scaled_fused_add_rms_norm_kernel                         \
//             <<>>(                                                      \
//                 out.data_ptr(),                                        \
//                 input.data_ptr(), residual.data_ptr(),                 \
//                 weight.data_ptr(), epsilon,                            \
//                 *scale.data_ptr(), num_tokens, hidden_size);           \
//       });

// void scaled_fused_add_rms_norm(torch::Tensor& out,       // [..., hidden_size]
//                                torch::Tensor& input,     // [..., hidden_size]
//                                torch::Tensor& residual,  // [..., hidden_size]
//                                torch::Tensor& weight,    // [hidden_size]
//                                torch::Tensor& scale, double epsilon) {
//   int hidden_size = input.size(-1);
//   int num_tokens = input.numel() / hidden_size;
//
//   dim3 grid(num_tokens);
//   /* This kernel is memory-latency bound in many scenarios.
//      When num_tokens is large, a smaller block size allows
//      for increased block occupancy on CUs and better latency
//      hiding on global mem ops. */
//   const int max_block_size = (num_tokens < 256) ? 1024 : 256;
//   dim3 block(std::min(hidden_size, max_block_size));
//   const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
//   const hipStream_t stream = at::hip::getCurrentHIPStream();
//   /*If the tensor types are FP16/BF16, try to use the optimized kernel
//     with packed + vectorized ops.
//     Max optimization is achieved with a width-8 vector of FP16/BF16s
//     since we can load at most 128 bits at once in a global memory op.
//     However, this requires each tensor's data to be aligned to 16
//     bytes.
//    */
//   auto inp_ptr = reinterpret_cast(input.data_ptr());
//   auto res_ptr = reinterpret_cast(residual.data_ptr());
//   auto wt_ptr = reinterpret_cast(weight.data_ptr());
//   bool ptrs_are_aligned =
//       inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
//   if (ptrs_are_aligned && hidden_size % 8 == 0) {
//     LAUNCH_SCALED_FUSED_ADD_RMS_NORM(8);
//   } else {
//     LAUNCH_SCALED_FUSED_ADD_RMS_NORM(0);
//   }
// }