Unverified Commit 9f8f2c7f authored by Yineng Zhang, committed by GitHub

update norm cu (#3048)

parent 6fc37bd8
@@ -91,7 +91,7 @@ ext_modules = [
             "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
             "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
             "src/sgl-kernel/csrc/rotary_embedding.cu",
-            "src/sgl-kernel/csrc/norm.cu",
+            "3rdparty/flashinfer/csrc/norm.cu",
         ],
         include_dirs=include_dirs,
         extra_compile_args={
#include <cstdint>

#include <flashinfer/norm.cuh>

#include "pytorch_extension_utils.h"

using namespace flashinfer;

// Thin PyTorch binding around flashinfer's fused RMSNorm kernel.
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream) {
  CHECK_INPUT(input);
  CHECK_INPUT(weight);
  auto device = input.device();
  CHECK_EQ(weight.device(), device);
  CHECK_DIM(2, input);   // input: (batch_size, hidden_size)
  CHECK_DIM(1, weight);  // weight: (hidden_size)
  CHECK_EQ(input.size(1), weight.size(0));
  unsigned int batch_size = input.size(0);
  unsigned int hidden_size = input.size(1);
  CHECK_EQ(output.size(0), batch_size);
  CHECK_EQ(output.size(1), hidden_size);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  // Dispatch on the 16-bit floating-point dtype and launch the kernel on the caller's stream.
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
    cudaError_t status = norm::RMSNorm(static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
                                       static_cast<c_type*>(output.data_ptr()), batch_size, hidden_size, eps, stream);
    TORCH_CHECK(status == cudaSuccess, "RMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
    return true;
  });
}
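For orientation, here is a minimal Python sketch of how a binding with the signature above could be invoked and sanity-checked. The import path (sgl_kernel.ops) and the exposed name rmsnorm are assumptions about how sgl_kernel_ops.cu registers the function; only the argument order (output, input, weight, eps, raw CUDA stream handle) follows from the C++ signature shown.

import torch
from sgl_kernel import ops  # hypothetical import path for the compiled extension

batch_size, hidden_size = 4, 4096
x = torch.randn(batch_size, hidden_size, dtype=torch.float16, device="cuda")
w = torch.randn(hidden_size, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)

# The binding takes the raw cudaStream_t as an int64, matching the C++ signature.
stream = torch.cuda.current_stream().cuda_stream
ops.rmsnorm(out, x, w, 1e-6, stream)

# Sanity check against a pure-PyTorch RMSNorm computed in fp32:
# y = x / sqrt(mean(x^2) + eps) * w, applied per row.
variance = x.float().pow(2).mean(dim=-1, keepdim=True)
ref = (x.float() * torch.rsqrt(variance + 1e-6)).to(x.dtype) * w
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)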