Unverified Commit 445b7093 authored by Fadi Arafeh's avatar Fadi Arafeh Committed by GitHub
Browse files

[perf][cpu] Accelerate BF16 GELU with LUT impl on Arm CPUs (#37469)


Signed-off-by: default avatarFadi Arafeh <fadi.arafeh@arm.com>
Co-authored-by: default avatarmergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent 18013df6
......@@ -51,6 +51,7 @@ function cpu_tests() {
set -e
pytest -x -v -s tests/kernels/test_onednn.py
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
pytest -x -v -s tests/kernels/core/test_cpu_activation.py
pytest -x -v -s tests/kernels/moe/test_moe.py -k test_cpu_fused_moe_basic"
# basic online serving
......
......@@ -360,6 +360,7 @@ set(VLLM_EXT_SRC
if (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND)
set(VLLM_EXT_SRC
"csrc/cpu/shm.cpp"
"csrc/cpu/activation_lut_bf16.cpp"
${VLLM_EXT_SRC})
endif()
......
#include "cpu_types.hpp"
#include <array>
#include <cstdint>
#include <mutex>
#include <string>
#include <ATen/ops/empty.h>
#include <ATen/ops/gelu.h>
#include <c10/util/BFloat16.h>
constexpr uint32_t ActivationLutSize = 1u << 16;
at::Tensor gelu_reference(const at::Tensor& x) { return at::gelu(x, "none"); }
void maybe_init_activation_lut_bf16(
uint16_t* lut, std::once_flag& once,
at::Tensor (*activation)(const at::Tensor&)) {
std::call_once(once, [&]() {
auto lut_input =
at::empty({static_cast<int64_t>(ActivationLutSize)},
at::TensorOptions().device(at::kCPU).dtype(at::kFloat));
auto* lut_input_ptr = lut_input.data_ptr<float>();
#pragma omp parallel for
for (uint32_t i = 0; i < ActivationLutSize; ++i) {
lut_input_ptr[i] = c10::detail::f32_from_bits(static_cast<uint16_t>(i));
}
auto lut_output = activation(lut_input);
const auto* lut_output_ptr = lut_output.data_ptr<float>();
#pragma omp parallel for
for (uint32_t i = 0; i < ActivationLutSize; ++i) {
lut[i] = c10::detail::round_to_nearest_even(lut_output_ptr[i]);
}
});
}
void activation_lut_bf16(torch::Tensor& out, torch::Tensor& input,
const uint16_t* lut, const char* op_name) {
TORCH_CHECK(input.scalar_type() == at::kBFloat16, op_name,
": input must be bfloat16");
TORCH_CHECK(out.scalar_type() == at::kBFloat16, op_name,
": out must be bfloat16");
TORCH_CHECK(input.is_contiguous(), op_name, ": input must be contiguous");
TORCH_CHECK(out.is_contiguous(), op_name, ": out must be contiguous");
const auto* src =
reinterpret_cast<const uint16_t*>(input.data_ptr<at::BFloat16>());
auto* dst = reinterpret_cast<uint16_t*>(out.data_ptr<at::BFloat16>());
const int64_t n = input.numel();
CPU_KERNEL_GUARD_IN(activation_lut_bf16_impl)
#pragma omp parallel for
for (int64_t i = 0; i < n; ++i) {
dst[i] = lut[src[i]];
}
CPU_KERNEL_GUARD_OUT(activation_lut_bf16_impl)
}
void activation_lut_bf16(torch::Tensor& out, torch::Tensor& input,
const std::string& activation) {
if (activation == "gelu") {
static std::array<uint16_t, ActivationLutSize> lut{};
static std::once_flag once;
maybe_init_activation_lut_bf16(lut.data(), once, gelu_reference);
activation_lut_bf16(out, input, lut.data(), "gelu_lut");
return;
}
TORCH_CHECK(false, "Unsupported activation: ", activation);
}
......@@ -85,6 +85,9 @@ at::Tensor int4_scaled_mm_cpu(at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros,
at::Tensor& w_scales,
std::optional<at::Tensor> bias);
void activation_lut_bf16(torch::Tensor& out, torch::Tensor& input,
const std::string& activation);
torch::Tensor get_scheduler_metadata(
const int64_t num_req, const int64_t num_heads_q,
const int64_t num_heads_kv, const int64_t head_dim,
......@@ -231,6 +234,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_quick", torch::kCPU, &gelu_quick);
#if (defined(__aarch64__) && !defined(__APPLE__))
ops.def(
"activation_lut_bf16(Tensor! out, Tensor input, str activation)"
" -> ()");
ops.impl("activation_lut_bf16", torch::kCPU, &activation_lut_bf16);
#endif // (defined(__aarch64__) && !defined(__APPLE__))
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
......
#pragma once
#include <optional>
#include <string>
#include <torch/library.h>
#include <tuple>
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import opcheck
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.torch_utils import set_random_seed
if not current_platform.is_cpu():
pytest.skip("skipping CPU-only tests", allow_module_level=True)
from vllm.model_executor.layers.activation import (
GELU,
FastGELU,
GeluAndMul,
NewGELU,
QuickGELU,
SiluAndMul,
)
DTYPES = [torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83]
D = [512, 2048]
SEEDS = [0]
@pytest.mark.parametrize(
("activation_cls", "fn"),
[
(SiluAndMul, torch.ops._C.silu_and_mul),
(GeluAndMul, torch.ops._C.gelu_and_mul),
(GeluAndMul, torch.ops._C.gelu_tanh_and_mul),
],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_cpu_act_and_mul(
default_vllm_config,
activation_cls: type[torch.nn.Module],
fn: object,
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
) -> None:
set_random_seed(seed)
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
layer = activation_cls()
out = layer(x)
ref_out = layer.forward_native(x)
torch.testing.assert_close(
out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out)
)
output_shape = x.shape[:-1] + (x.shape[-1] // 2,)
raw_out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
opcheck(fn, (raw_out, x))
@pytest.mark.parametrize(
("activation_cls", "fn", "op_args"),
[
(NewGELU, torch.ops._C.gelu_new, ()),
(FastGELU, torch.ops._C.gelu_fast, ()),
(QuickGELU, torch.ops._C.gelu_quick, ()),
pytest.param(
GELU,
getattr(torch.ops._C, "activation_lut_bf16", None),
("gelu",),
marks=pytest.mark.skipif(
current_platform.get_cpu_architecture() != CpuArchEnum.ARM,
reason="activation_lut_bf16 is only built on Arm CPU",
),
),
],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_cpu_unary_activation(
default_vllm_config,
activation_cls: type[torch.nn.Module],
fn: object,
op_args: tuple[str, ...],
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
) -> None:
set_random_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype)
layer = activation_cls()
out = layer(x)
ref_out = layer.forward_native(x)
torch.testing.assert_close(
out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out)
)
# gelu with activation_lut_bf16 only makes sense for BF16
if not (activation_cls is GELU and dtype != torch.bfloat16):
raw_out = torch.empty_like(x)
opcheck(fn, (raw_out, x, *op_args))
......@@ -3300,6 +3300,12 @@ def cpu_gemm_wna16(
return output
def cpu_activation_lut_bf16(input: torch.Tensor, activation: str) -> torch.Tensor:
out = torch.empty_like(input)
torch.ops._C.activation_lut_bf16(out, input, activation)
return out
def cpu_prepack_moe_weight(
weight: torch.Tensor,
isa: str,
......
......@@ -16,7 +16,7 @@ from vllm.distributed import (
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms import CpuArchEnum, current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.collection_utils import LazyDict
......@@ -247,6 +247,34 @@ class GeluAndMulSparse(CustomOp):
return self.forward_native(x)
# --8<-- [start:gelu]
@CustomOp.register("gelu")
class GELU(CustomOp):
# --8<-- [end:gelu]
def __init__(self):
super().__init__()
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM and hasattr(
torch.ops._C, "activation_lut_bf16"
):
self.op = torch.ops._C.activation_lut_bf16
else:
self.op = None
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
return F.gelu(x, approximate="none")
def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
if self.op and x.dtype == torch.bfloat16 and x.is_contiguous():
out = torch.empty_like(x)
self.op(out, x, "gelu")
return out
return self.forward_native(x)
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_native(x)
# --8<-- [start:gelu_and_mul]
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
......@@ -635,7 +663,7 @@ class ScaledActivation(nn.Module):
_ACTIVATION_REGISTRY = LazyDict(
{
"gelu": lambda: nn.GELU(),
"gelu": lambda: GELU(),
"gelu_fast": lambda: FastGELU(),
"gelu_new": lambda: NewGELU(),
"gelu_pytorch_tanh": lambda: (
......
......@@ -246,6 +246,13 @@ class CpuPlatform(Platform):
if vllm_config.lora_config is not None:
compilation_config.mode = CompilationMode.NONE
if (
cls.get_cpu_architecture() == CpuArchEnum.ARM
and "+gelu" not in compilation_config.custom_ops
and "-gelu" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+gelu")
vllm_config.profiler_config.torch_profiler_dump_cuda_time_total = False
assert vllm_config.device_config.device_type == "cpu"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment