Unverified Commit d1d4074c authored by blzheng, committed by GitHub

[CPU] Add gelu_and_mul kernel in sgl-kernel and add ut (#9300)

parent 718f25ae
...@@ -110,6 +110,14 @@ class GeluAndMul(CustomOp):
d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
if _is_cpu_amx_available and self.approximate == "tanh":
return torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
elif _is_cpu_amx_available and self.approximate == "none":
return torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
else:
return self.forward_native(x)
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
return self._forward_impl(x)
...
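For context, the fused op computes a gated GELU: the input's last dimension is split in half, GELU is applied to the first half, and the result is multiplied elementwise by the second half, which is what forward_native does in eager mode. A minimal sketch of that reference computation (the shapes below are illustrative, not taken from the diff):

```python
import torch
import torch.nn.functional as F

def gelu_and_mul_reference(x: torch.Tensor, approximate: str = "none") -> torch.Tensor:
    # Gated GELU: split the last dimension in half, apply GELU to the first
    # half, and multiply elementwise by the second half.
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]

# Illustrative shapes only: 4 tokens with a doubled hidden size of 256.
x = torch.randn(4, 256, dtype=torch.bfloat16)
ref = gelu_and_mul_reference(x, approximate="tanh")  # shape (4, 128)
```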
...@@ -77,3 +77,59 @@ at::Tensor silu_and_mul_cpu(at::Tensor& input) {
});
return out;
}
at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input) {
RECORD_FUNCTION("sgl-kernel::gelu_tanh_and_mul_cpu", std::vector<c10::IValue>({input}));
auto sizes = input.sizes().vec();
int64_t last_dim = input.ndimension() - 1;
int64_t d = sizes[last_dim] / 2;
sizes[last_dim] = d;
int64_t num_tokens = input.numel() / input.size(-1);
at::Tensor out = at::empty(sizes, input.options());
const float sqrt_2_div_pi = std::sqrt(2.f / M_PI);
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_tanh_and_mul", [&] {
using Vec = at::vec::Vectorized<float>;
act_and_mul_kernel_impl(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
num_tokens,
d,
[sqrt_2_div_pi](float x) {
float x3 = x * x * x;
float tanh_arg = sqrt_2_div_pi * (x + 0.044715f * x3);
return 0.5f * x * (1.f + std::tanh(tanh_arg));
},
[sqrt_2_div_pi](Vec x) {
Vec x3 = x * x * x;
Vec tanh_arg = Vec(sqrt_2_div_pi) * (x + Vec(0.044715f) * x3);
return Vec(0.5f) * x * (Vec(1.f) + tanh_arg.tanh());
});
});
return out;
}
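The scalar and vectorized lambdas above implement the tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), which is the same formula PyTorch uses for F.gelu(..., approximate="tanh"). A quick sanity-check sketch of the formula (not part of the change; the sample points are arbitrary):

```python
import math
import torch
import torch.nn.functional as F

def gelu_tanh_scalar(x: float) -> float:
    # Same formula as the scalar lambda in gelu_tanh_and_mul_cpu.
    sqrt_2_div_pi = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(sqrt_2_div_pi * (x + 0.044715 * x**3)))

x = torch.linspace(-4.0, 4.0, steps=9)
approx = torch.tensor([gelu_tanh_scalar(v.item()) for v in x])
torch.testing.assert_close(approx, F.gelu(x, approximate="tanh"), atol=1e-5, rtol=1e-5)
```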
at::Tensor gelu_and_mul_cpu(const at::Tensor& input) {
RECORD_FUNCTION("sgl-kernel::gelu_and_mul_cpu", std::vector<c10::IValue>({input}));
auto sizes = input.sizes().vec();
int64_t last_dim = input.ndimension() - 1;
int64_t d = sizes[last_dim] / 2;
sizes[last_dim] = d;
int64_t num_tokens = input.numel() / input.size(-1);
at::Tensor out = at::empty(sizes, input.options());
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul", [&] {
using Vec = at::vec::Vectorized<float>;
const float inv_sqrt2 = 1.0f / std::sqrt(2.0f);
act_and_mul_kernel_impl(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
num_tokens,
d,
[inv_sqrt2](float x) { return 0.5f * x * (1.f + std::erf(x * inv_sqrt2)); },
[inv_sqrt2](Vec x) { return Vec(0.5f) * x * (Vec(1.f) + (x * Vec(inv_sqrt2)).erf()); });
});
return out;
}
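This second variant uses the exact GELU, 0.5 * x * (1 + erf(x / sqrt(2))), with Vectorized<float>::erf() on the SIMD path. The identity can be checked against PyTorch's default (erf-based) GELU with a similar sketch, again purely illustrative:

```python
import math
import torch
import torch.nn.functional as F

x = torch.linspace(-4.0, 4.0, steps=9)
# Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))), the same formula as the scalar lambda above.
exact = 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
torch.testing.assert_close(exact, F.gelu(x), atol=1e-5, rtol=1e-5)
```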
...@@ -23,6 +23,10 @@ limitations under the License.
// silu_and_mul
at::Tensor silu_and_mul_cpu(at::Tensor& input);
// gelu_and_mul
at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input);
at::Tensor gelu_and_mul_cpu(const at::Tensor& input);
// l2norm
at::Tensor l2norm_cpu(at::Tensor& input, double eps);
...@@ -233,6 +237,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
// activation
m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
m.impl("silu_and_mul_cpu", torch::kCPU, &silu_and_mul_cpu);
m.def("gelu_tanh_and_mul_cpu(Tensor input) -> Tensor");
m.impl("gelu_tanh_and_mul_cpu", torch::kCPU, &gelu_tanh_and_mul_cpu);
m.def("gelu_and_mul_cpu(Tensor input) -> Tensor");
m.impl("gelu_and_mul_cpu", torch::kCPU, &gelu_and_mul_cpu);
// norm
m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
...
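With the schema defined and the CPU implementation bound, the new kernels become reachable through the torch.ops namespace. A hedged usage sketch, assuming an sgl_kernel build that includes these CPU ops:

```python
import torch
import sgl_kernel  # noqa: F401  # registers the sgl_kernel ops; assumes a CPU build with these kernels

x = torch.randn(4, 256, dtype=torch.bfloat16)
out_tanh = torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)  # tanh-approximate GELU gate
out_erf = torch.ops.sgl_kernel.gelu_and_mul_cpu(x)        # exact (erf) GELU gate
print(out_tanh.shape, out_erf.shape)  # last dimension is halved: (4, 128) each
```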
...@@ -4,7 +4,7 @@ import unittest
import sgl_kernel
import torch
import torch.nn.functional as F
from utils import GeluAndMul, SiluAndMul, precision
from sglang.test.test_utils import CustomTestCase
...@@ -16,7 +16,7 @@ class TestActivation(CustomTestCase):
N = [22016, 22018]
dtype = [torch.float16, torch.bfloat16]
def _silu_and_mul_test(self, m, n, dtype):
x = torch.randn([m, n], dtype=dtype)
out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
...@@ -25,10 +25,30 @@ class TestActivation(CustomTestCase):
atol = rtol = precision[ref_out.dtype]
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
def _gelu_and_mul_test(self, m, n, dtype):
x = torch.randn([m, n], dtype=dtype)
out = torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
ref_out = GeluAndMul(x, approximate="none")
atol = rtol = precision[ref_out.dtype]
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
def _gelu_tanh_and_mul_test(self, m, n, dtype):
x = torch.randn([m, n], dtype=dtype)
out = torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
ref_out = GeluAndMul(x, approximate="tanh")
atol = rtol = precision[ref_out.dtype]
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
def test_activation(self):
for params in itertools.product(self.M, self.N, self.dtype):
with self.subTest(m=params[0], n=params[1], dtype=params[2]):
self._silu_and_mul_test(*params)
self._gelu_and_mul_test(*params)
self._gelu_tanh_and_mul_test(*params)
if __name__ == "__main__":
...
...@@ -20,6 +20,11 @@ def SiluAndMul(x: torch.Tensor) -> torch.Tensor:
return F.silu(x[..., :d]) * x[..., d:]
def GeluAndMul(x: torch.Tensor, approximate="tanh") -> torch.Tensor:
d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
def per_token_quant_int8(x):
x = x.float()
absmax = x.abs().max(dim=-1).values
...