Unverified Commit d1d4074c authored by blzheng's avatar blzheng Committed by GitHub
Browse files

[CPU] Add gelu_and_mul kernel in sgl-kernel and add ut (#9300)

parent 718f25ae
......@@ -110,6 +110,14 @@ class GeluAndMul(CustomOp):
d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
if _is_cpu_amx_available and self.approximate == "tanh":
return torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
elif _is_cpu_amx_available and self.approximate == "none":
return torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
else:
return self.forward_native(x)
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
return self._forward_impl(x)
......
......@@ -77,3 +77,59 @@ at::Tensor silu_and_mul_cpu(at::Tensor& input) {
});
return out;
}
at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input) {
RECORD_FUNCTION("sgl-kernel::gelu_tanh_and_mul_cpu", std::vector<c10::IValue>({input}));
auto sizes = input.sizes().vec();
int64_t last_dim = input.ndimension() - 1;
int64_t d = sizes[last_dim] / 2;
sizes[last_dim] = d;
int64_t num_tokens = input.numel() / input.size(-1);
at::Tensor out = at::empty(sizes, input.options());
const float sqrt_2_div_pi = std::sqrt(2.f / M_PI);
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_tanh_and_mul", [&] {
using Vec = at::vec::Vectorized<float>;
act_and_mul_kernel_impl(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
num_tokens,
d,
[sqrt_2_div_pi](float x) {
float x3 = x * x * x;
float tanh_arg = sqrt_2_div_pi * (x + 0.044715f * x3);
return 0.5f * x * (1.f + std::tanh(tanh_arg));
},
[sqrt_2_div_pi](Vec x) {
Vec x3 = x * x * x;
Vec tanh_arg = Vec(sqrt_2_div_pi) * (x + Vec(0.044715f) * x3);
return Vec(0.5f) * x * (Vec(1.f) + tanh_arg.tanh());
});
});
return out;
}
at::Tensor gelu_and_mul_cpu(const at::Tensor& input) {
RECORD_FUNCTION("sgl-kernel::gelu_and_mul_cpu", std::vector<c10::IValue>({input}));
auto sizes = input.sizes().vec();
int64_t last_dim = input.ndimension() - 1;
int64_t d = sizes[last_dim] / 2;
sizes[last_dim] = d;
int64_t num_tokens = input.numel() / input.size(-1);
at::Tensor out = at::empty(sizes, input.options());
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul", [&] {
using Vec = at::vec::Vectorized<float>;
const float inv_sqrt2 = 1.0f / std::sqrt(2.0f);
act_and_mul_kernel_impl(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
num_tokens,
d,
[inv_sqrt2](float x) { return 0.5f * x * (1.f + std::erf(x * inv_sqrt2)); },
[inv_sqrt2](Vec x) { return Vec(0.5f) * x * (Vec(1.f) + (x * Vec(inv_sqrt2)).erf()); });
});
return out;
}
......@@ -23,6 +23,10 @@ limitations under the License.
// silu_and_mul
at::Tensor silu_and_mul_cpu(at::Tensor& input);
// gelu_and_mul
at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input);
at::Tensor gelu_and_mul_cpu(const at::Tensor& input);
// l2norm
at::Tensor l2norm_cpu(at::Tensor& input, double eps);
......@@ -233,6 +237,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
// activation
m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
m.impl("silu_and_mul_cpu", torch::kCPU, &silu_and_mul_cpu);
m.def("gelu_tanh_and_mul_cpu(Tensor input) -> Tensor");
m.impl("gelu_tanh_and_mul_cpu", torch::kCPU, &gelu_tanh_and_mul_cpu);
m.def("gelu_and_mul_cpu(Tensor input) -> Tensor");
m.impl("gelu_and_mul_cpu", torch::kCPU, &gelu_and_mul_cpu);
// norm
m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
......
......@@ -4,7 +4,7 @@ import unittest
import sgl_kernel
import torch
import torch.nn.functional as F
from utils import SiluAndMul, precision
from utils import GeluAndMul, SiluAndMul, precision
from sglang.test.test_utils import CustomTestCase
......@@ -16,7 +16,7 @@ class TestActivation(CustomTestCase):
N = [22016, 22018]
dtype = [torch.float16, torch.bfloat16]
def _activation_test(self, m, n, dtype):
def _silu_and_mul_test(self, m, n, dtype):
x = torch.randn([m, n], dtype=dtype)
out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
......@@ -25,10 +25,30 @@ class TestActivation(CustomTestCase):
atol = rtol = precision[ref_out.dtype]
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
def _gelu_and_mul_test(self, m, n, dtype):
x = torch.randn([m, n], dtype=dtype)
out = torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
ref_out = GeluAndMul(x, approximate="none")
atol = rtol = precision[ref_out.dtype]
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
def _gelu_tanh_and_mul_test(self, m, n, dtype):
x = torch.randn([m, n], dtype=dtype)
out = torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
ref_out = GeluAndMul(x, approximate="tanh")
atol = rtol = precision[ref_out.dtype]
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
def test_activation(self):
for params in itertools.product(self.M, self.N, self.dtype):
with self.subTest(m=params[0], n=params[1], dtype=params[2]):
self._activation_test(*params)
self._silu_and_mul_test(*params)
self._gelu_and_mul_test(*params)
self._gelu_tanh_and_mul_test(*params)
if __name__ == "__main__":
......
......@@ -20,6 +20,11 @@ def SiluAndMul(x: torch.Tensor) -> torch.Tensor:
return F.silu(x[..., :d]) * x[..., d:]
def GeluAndMul(x: torch.Tensor, approximate="tanh") -> torch.Tensor:
d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
def per_token_quant_int8(x):
x = x.float()
absmax = x.abs().max(dim=-1).values
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment