// SPDX-License-Identifier: MIT #include "py_itfs_common.h" #include #include #include #include "rmsnorm2d_fwd.hpp" void fused_add_rms_norm_out(torch::Tensor& out, torch::Tensor& input, torch::Tensor& residual_in, torch::Tensor& residual_out, torch::Tensor& weight, double epsilon); namespace { bool is_dense_row_major_2d(const torch::Tensor& tensor) { return tensor.dim() == 2 && tensor.stride(-1) == 1 && tensor.stride(0) == tensor.size(-1); } bool is_dense_last_dim(const torch::Tensor& tensor) { return tensor.dim() >= 1 && tensor.stride(-1) == 1 && tensor.is_contiguous(); } } // namespace void rmsnorm2d( torch::Tensor& out, // [m, n] torch::Tensor& input, // [m, n] torch::Tensor& weight, // [1, n] double epsilon) // 0: Use default RMSNorm; 1: Use T5-like implementation { auto dtype = input.dtype(); TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16, "ck rmsnorm2d only support fp16 and bf16 data type"); std::string dtype_str = torchDTypeToStr(dtype); int n = input.size(-1); int m = input.numel() / n; int stride = input.stride(0); int xr_stride = -1; int y_stride = out.stride(0); int yr_stride = -1; bool SaveRms = false; const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); const hipStream_t stream = at::hip::getCurrentHIPStream(); rmsnorm2d_fwd({dtype_str, // input precision dtype_str, // output precision dtype_str, // x-scale, used for [1*N] input smooth quant dtype_str, // y-scale, used for [M*1] output for next layer SaveRms, false, // save_unquant 0, // fused_add 0, // fused_quant }, {input.data_ptr(), nullptr, // p_x_residual nullptr, // p_x_scale weight.data_ptr(), out.data_ptr(), nullptr, // p_y_residual nullptr, // p_y_scale nullptr, // p_invRms nullptr, // p_y_unquant static_cast(epsilon), m, n, stride, xr_stride, y_stride, yr_stride}, {stream}); } torch::Tensor rmsnorm2d( torch::Tensor& input, // [m, n] torch::Tensor& weight, // [1, n] double epsilon) // 0: Use default RMSNorm; 1: Use T5-like implementation { torch::Tensor out = torch::empty_like(input); rmsnorm2d(out, input, weight, epsilon); return out; } void rmsnorm2d_with_add( torch::Tensor& out, // [m ,n] torch::Tensor& input, // [m ,n] torch::Tensor& residual_in, // [m ,n] torch::Tensor& residual_out, // [m ,n] torch::Tensor& weight, // [1 ,n] double epsilon) // 0: Use default RMSNorm; 1: Use T5-like implementation { auto dtype = input.dtype(); TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16, "ck rmsnorm2d only support fp16 and bf16 data type"); TORCH_CHECK(out.dtype() == dtype && residual_in.dtype() == dtype && residual_out.dtype() == dtype && weight.dtype() == dtype, "rmsnorm2d_with_add expects input/output/residual/weight to share the same dtype"); TORCH_CHECK(input.sizes() == out.sizes() && input.sizes() == residual_in.sizes() && input.sizes() == residual_out.sizes(), "rmsnorm2d_with_add expects input/out/residual tensors to have the same shape"); std::string dtype_str = torchDTypeToStr(input.dtype()); int n = input.size(-1); int m = input.numel() / n; int stride = input.stride(0); int xr_stride = residual_in.stride(0); int y_stride = out.stride(0); int yr_stride = residual_out.stride(0); bool SaveRms = false; const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); const hipStream_t stream = at::hip::getCurrentHIPStream(); // 1. For bf16, we can also choose the vllm-like solution. // const bool can_use_vllm_bf16_bypass = dtype == torch::kBFloat16 && is_dense_row_major_2d(input) && // is_dense_row_major_2d(out) && // is_dense_row_major_2d(residual_in) && // is_dense_row_major_2d(residual_out) && // is_dense_last_dim(weight); // if(can_use_vllm_bf16_bypass) // { // fused_add_rms_norm_out(out, input, residual_in, residual_out, weight, epsilon); // return; // } // 2. CK solution rmsnorm2d_fwd({dtype_str, // input precision dtype_str, // output precision dtype_str, // x-scale, used for [1*N] input smooth quant dtype_str, // y-scale, used for [M*1] output for next layer SaveRms, false, // save_unquant 1, // fused_add 0, // fused_quant }, {input.data_ptr(), // p_x residual_in.data_ptr(), // p_x_residual nullptr, // p_x_scale weight.data_ptr(), // p_gamma out.data_ptr(), // p_y residual_out.data_ptr(), // p_y_residual nullptr, // p_y_scale nullptr, // p_invRms nullptr, // p_y_unquant static_cast(epsilon), m, n, stride, xr_stride, y_stride, yr_stride}, {stream}); } void rmsnorm2d_with_smoothquant( torch::Tensor& out, // [m ,n] torch::Tensor& input, // [m ,n] torch::Tensor& xscale, // [1 ,n] torch::Tensor& yscale, // [m ,1] torch::Tensor& weight, // [1 ,n] double epsilon) // 0: Use default RMSNorm; 1: Use T5-like implementation { auto dtype = input.dtype(); TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16, "ck rmsnorm2d only support fp16 and bf16 data type"); std::string dtype_str = torchDTypeToStr(input.dtype()); std::string out_dtype_str = torchDTypeToStr(out.dtype()); std::string xscale_dtype_str = torchDTypeToStr(xscale.dtype()); std::string yscale_dtype_str = torchDTypeToStr(yscale.dtype()); int n = input.size(-1); int m = input.numel() / n; int stride = input.stride(0); int xr_stride = -1; int y_stride = out.stride(0); int yr_stride = -1; bool SaveRms = false; const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); const hipStream_t stream = at::hip::getCurrentHIPStream(); rmsnorm2d_fwd({dtype_str, // input precision out_dtype_str, // output precision xscale_dtype_str, // x-scale, used for [1*N] input smooth quant yscale_dtype_str, // y-scale, used for [M*1] output for next layer SaveRms, false, // save_unquant 0, // fused_add 1, // fused_quant }, {input.data_ptr(), // p_x nullptr, // p_x_residual xscale.data_ptr(), // p_x_scale weight.data_ptr(), // p_gamma out.data_ptr(), // p_y nullptr, // p_y_residual yscale.data_ptr(), // p_y_scale nullptr, // p_invRms nullptr, // p_y_unquant static_cast(epsilon), m, n, stride, xr_stride, y_stride, yr_stride}, {stream}); } void rmsnorm2d_with_add_smoothquant( torch::Tensor& out, // [m ,n] torch::Tensor& input, // [m ,n] torch::Tensor& residual_in, // [m ,n] torch::Tensor& residual_out, // [m ,n] torch::Tensor& xscale, // [1 ,n] torch::Tensor& yscale, // [m ,1] torch::Tensor& weight, // [1 ,n] double epsilon, std::optional out_before_quant) // 0: Use default RMSNorm; 1: Use T5-like implementation { auto dtype = input.dtype(); TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16, "ck rmsnorm2d only support fp16 and bf16 data type"); std::string dtype_str = torchDTypeToStr(input.dtype()); std::string out_dtype_str = torchDTypeToStr(out.dtype()); std::string xscale_dtype_str = torchDTypeToStr(xscale.dtype()); std::string yscale_dtype_str = torchDTypeToStr(yscale.dtype()); int n = input.size(-1); int m = input.numel() / n; int stride = input.stride(0); int xr_stride = residual_in.stride(0); int y_stride = out.stride(0); int yr_stride = residual_out.stride(0); bool SaveRms = false; const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); const hipStream_t stream = at::hip::getCurrentHIPStream(); rmsnorm2d_fwd({dtype_str, // input precision out_dtype_str, // output precision xscale_dtype_str, // x-scale, used for [1*N] input smooth quant yscale_dtype_str, // y-scale, used for [M*1] output for next layer SaveRms, out_before_quant.has_value(), // save_unquant 1, // fused_add 1, // fused_quant }, {input.data_ptr(), // p_x residual_in.data_ptr(), // p_x_residual xscale.data_ptr(), // p_x_scale weight.data_ptr(), // p_gamma out.data_ptr(), // p_y residual_out.data_ptr(), // p_y_residual yscale.data_ptr(), // p_y_scale nullptr, // p_invRms out_before_quant.has_value() ? out_before_quant.value().data_ptr() : nullptr, // p_y_unquant static_cast(epsilon), m, n, stride, xr_stride, y_stride, yr_stride}, {stream}); } void rmsnorm2d_with_dynamicquant( torch::Tensor& out, // [m ,n] torch::Tensor& input, // [m ,n] torch::Tensor& yscale, // [m ,1] torch::Tensor& weight, // [1 ,n] double epsilon) // 0: Use default RMSNorm; 1: Use T5-like implementation { auto dtype = input.dtype(); TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16, "ck rmsnorm2d only support fp16 and bf16 data type"); std::string dtype_str = torchDTypeToStr(input.dtype()); std::string out_dtype_str = torchDTypeToStr(out.dtype()); std::string yscale_dtype_str = torchDTypeToStr(yscale.dtype()); int n = input.size(-1); int m = input.numel() / n; int stride = input.stride(0); int xr_stride = -1; int y_stride = out.stride(0); int yr_stride = -1; bool SaveRms = false; const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); const hipStream_t stream = at::hip::getCurrentHIPStream(); rmsnorm2d_fwd({dtype_str, // input precision out_dtype_str, // output precision dtype_str, // x-scale, used for [1*N] input smooth quant yscale_dtype_str, // y-scale, used for [M*1] output for next layer SaveRms, false, // save_unquant 0, // fused_add 2, // fused_quant }, {input.data_ptr(), // p_x nullptr, // p_x_residual nullptr, // p_x_scale weight.data_ptr(), // p_gamma out.data_ptr(), // p_y nullptr, // p_y_residual yscale.data_ptr(), // p_y_scale nullptr, // p_invRms nullptr, // p_y_unquant static_cast(epsilon), m, n, stride, xr_stride, y_stride, yr_stride}, {stream}); } void rmsnorm2d_with_add_dynamicquant( torch::Tensor& out, // [m ,n] torch::Tensor& input, // [m ,n] torch::Tensor& residual_in, // [m ,n] torch::Tensor& residual_out, // [m ,n] torch::Tensor& yscale, // [m ,1] torch::Tensor& weight, // [1 ,n] double epsilon) // 0: Use default RMSNorm; 1: Use T5-like implementation { auto dtype = input.dtype(); TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16, "ck rmsnorm2d only support fp16 and bf16 data type"); std::string dtype_str = torchDTypeToStr(input.dtype()); std::string out_dtype_str = torchDTypeToStr(out.dtype()); std::string yscale_dtype_str = torchDTypeToStr(yscale.dtype()); int n = input.size(-1); int m = input.numel() / n; int stride = input.stride(0); int xr_stride = residual_in.stride(0); int y_stride = out.stride(0); int yr_stride = residual_out.stride(0); bool SaveRms = false; const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); const hipStream_t stream = at::hip::getCurrentHIPStream(); rmsnorm2d_fwd({dtype_str, // input precision out_dtype_str, // output precision dtype_str, // x-scale, used for [1*N] input smooth quant yscale_dtype_str, // y-scale, used for [M*1] output for next layer SaveRms, false, // save_unquant 1, // fused_add 2, // fused_quant }, {input.data_ptr(), // p_x residual_in.data_ptr(), // p_x_residual nullptr, // p_x_scale weight.data_ptr(), // p_gamma out.data_ptr(), // p_y residual_out.data_ptr(), // p_y_residual yscale.data_ptr(), // p_y_scale nullptr, // p_invRms nullptr, // p_y_unquant static_cast(epsilon), m, n, stride, xr_stride, y_stride, yr_stride}, {stream}); }