#include "hip/hip_runtime.h" #include "layernorm_kernels_impl.cuh" #include "dispatch_utils.h" void rms_norm(Tensor &out, // [..., hidden_size] Tensor &input, // [..., hidden_size] Tensor &weight, // [hidden_size] float epsilon, bool use_quant) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA(); VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { if (use_quant) { hipLaunchKernelGGL(( vllm::rms_norm_kernel), dim3(grid), dim3(block), 0, stream, out.data_ptr(), input.data_ptr(), weight.data_ptr(), epsilon, num_tokens, hidden_size); } else { hipLaunchKernelGGL(( vllm::rms_norm_kernel), dim3(grid), dim3(block), 0, stream, out.data_ptr(), input.data_ptr(), weight.data_ptr(), epsilon, num_tokens, hidden_size); } }); } void layernorm_general(Tensor out, Tensor input, Tensor weight, Tensor bias, float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 256)); block.x = 32 * ((block.x + 31) / 32); size_t size_shmem = input.scalar_size() * hidden_size; const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA(); VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "generalLayerNorm", [&] { using T = typename packed_as::type; hipLaunchKernelGGL(( vllm::generalLayerNorm), dim3(grid), dim3(block), size_shmem, stream, reinterpret_cast(input.data_ptr()), weight.valid() ? reinterpret_cast(weight.data_ptr()) : nullptr, bias.valid() ? reinterpret_cast(bias.data_ptr()) : nullptr, reinterpret_cast(out.data_ptr()), epsilon, num_tokens, hidden_size, nullptr, nullptr, nullptr, true); }); } void rms_norm_general(Tensor &out, // [..., hidden_size] Tensor &input, // [..., hidden_size] Tensor &weight, // [hidden_size] Tensor &scaling, // [tokens] or [1] float epsilon, bool use_per_token_quant) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); block.x = 32 * ((block.x + 31) / 32); const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA(); VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "generalLayerNorm", [&] { using T = scalar_t; if (use_per_token_quant) { // per-token hipLaunchKernelGGL(( vllm::generalLayerNorm) , dim3(grid), dim3(block), 0, stream, reinterpret_cast(input.data_ptr()), reinterpret_cast(weight.data_ptr()), nullptr, nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr(), out.data_ptr(), false); // input, gamma, beta, normed_output, eps, tokens, hidden_dim, per_tensor_scale, per_token_scale // normed_output_quant, use_shmem // out.data_ptr(), input.data_ptr(), // weight.data_ptr(), epsilon, num_tokens, hidden_size); } else { // per-tensor hipLaunchKernelGGL(( vllm::generalLayerNorm) , dim3(grid), dim3(block), 0, stream, reinterpret_cast(input.data_ptr()), reinterpret_cast(weight.data_ptr()), nullptr, nullptr, epsilon, num_tokens, hidden_size, scaling.data_ptr(), nullptr, out.data_ptr(), false); } }); } void rms_norm_general_fuse_sum(Tensor &out, // [..., hidden_size] Tensor &input, // [..., hidden_size] Tensor &weight, // [hidden_size] Tensor &input_sum, // [tokens] or [1] Tensor &scaling, // [tokens] or [1] float epsilon, bool use_per_token_quant) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); block.x = 32 
* ((block.x + 31) / 32); const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA(); VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "generalLayerNorm_fuse_sum", [&] { using T = scalar_t; if (use_per_token_quant) { // per-token hipLaunchKernelGGL(( vllm::generalLayerNorm_fuse_sum) , dim3(grid), dim3(block), 0, stream, reinterpret_cast(input.data_ptr()), reinterpret_cast(weight.data_ptr()), nullptr, nullptr, epsilon, num_tokens, hidden_size, input_sum.data_ptr(), nullptr, scaling.data_ptr(), out.data_ptr(), false); // input, gamma, beta, normed_output, eps, tokens, hidden_dim, per_tensor_scale, per_token_scale // normed_output_quant, use_shmem // out.data_ptr(), input.data_ptr(), // weight.data_ptr(), epsilon, num_tokens, hidden_size); } else { // per-tensor // Rasing error here // Not implemented per-tensor input_sum assert(false); hipLaunchKernelGGL(( vllm::generalLayerNorm_fuse_sum) , dim3(grid), dim3(block), 0, stream, reinterpret_cast(input.data_ptr()), reinterpret_cast(weight.data_ptr()), nullptr, nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr(), nullptr, out.data_ptr(), false); } }); } void invoke_dequant_add_residual_rms_norm_quant(Tensor &out, // [..., hidden_size] Tensor &input, // [..., hidden_size] Tensor &residual, // [..., hidden_size] Tensor &gamma, // [hidden_size] half scale, float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA(); VLLM_DISPATCH_FLOATING_TYPES(residual.scalar_type(), "dequant_add_residual_rms_norm_quant_kernel", [&] { hipLaunchKernelGGL(( vllm::dequant_add_residual_rms_norm_quant_kernel) , dim3(grid), dim3(block), 0, stream, input.data_ptr(), residual.data_ptr(), out.data_ptr(), gamma.data_ptr(), epsilon, scale, num_tokens, hidden_size); }); } void invoke_dequant_add_residual_rms_norm_quant(Tensor &out, // [..., hidden_size] Tensor &input, // [..., hidden_size] Tensor &residual, // [..., hidden_size] Tensor &gamma, // [hidden_size] Tensor &scale, // [num_tokens] float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA(); VLLM_DISPATCH_FLOATING_TYPES(residual.scalar_type(), "dequant_add_residual_rms_norm_quant_kernel", [&] { hipLaunchKernelGGL(( vllm::dequant_add_residual_rms_norm_quant_kernel) , dim3(grid), dim3(block), 0, stream, input.data_ptr(), residual.data_ptr(), out.data_ptr(), gamma.data_ptr(), epsilon, scale.data_ptr(), num_tokens, hidden_size); }); }
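
// Illustrative usage sketch (a minimal example, not part of the wrappers
// above). It assumes `Tensor` follows the torch::Tensor-style API already used
// in this file (size(), numel(), scalar_type(), data_ptr<T>()) and that the
// caller pre-allocates the int8 output and float scale buffers; names such as
// `hidden_states` and `token_scales` are placeholders.
//
//   Tensor hidden_states = ...;  // [num_tokens, hidden_size], fp16/bf16/fp32
//   Tensor rms_weight    = ...;  // [hidden_size], same dtype as hidden_states
//   Tensor quant_out     = ...;  // [num_tokens, hidden_size], int8
//   Tensor token_scales  = ...;  // [num_tokens], float32
//
//   // Quantized RMSNorm with one scale per token.
//   rms_norm_general(quant_out, hidden_states, rms_weight, token_scales,
//                    /*epsilon=*/1e-6f, /*use_per_token_quant=*/true);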