Unverified commit 3289da5b authored by Fan Yin, committed by GitHub

[sgl-kernel] support hadamard (#11663)

parent 868403f6
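This change vendors the sgl-project fork of fast-hadamard-transform into sgl-kernel: the CUDA/C++ sources are fetched and compiled into common_ops, the kernels are registered under torch.ops.sgl_kernel, thin Python wrappers (hadamard_transform plus the _12n/_20n/_28n/_40n variants, which in the upstream library target dims of 12/20/28/40 times a power of two) are exported from sgl_kernel, and a scipy-based reference test is added. A minimal usage sketch, not part of the diff, assuming a CUDA device and a build of sgl-kernel that includes this change:

import math

import torch
from sgl_kernel import hadamard_transform

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
# scale = 1/sqrt(dim) makes the transform orthonormal, the same convention as the test added below.
y = hadamard_transform(x, scale=1.0 / math.sqrt(x.shape[-1]))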
@@ -62,7 +62,7 @@ fi
$PIP_CMD list
# Install additional dependencies
$PIP_CMD install mooncake-transfer-engine==0.3.6.post1 nvidia-cuda-nvrtc-cu12 py-spy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX
$PIP_CMD install mooncake-transfer-engine==0.3.6.post1 nvidia-cuda-nvrtc-cu12 py-spy scipy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX
if [ "$IS_BLACKWELL" != "1" ]; then
# For lmms_evals evaluating MMMU
......
@@ -60,6 +60,7 @@ FetchContent_Declare(
)
FetchContent_Populate(repo-deepgemm)

# fmt
FetchContent_Declare(
  repo-fmt
  GIT_REPOSITORY https://github.com/fmtlib/fmt
@@ -113,6 +114,15 @@ FetchContent_Declare(
)
FetchContent_Populate(repo-mscclpp)

# fast-hadamard-transform
FetchContent_Declare(
  repo-fast-hadamard-transform
  GIT_REPOSITORY https://github.com/sgl-project/fast-hadamard-transform.git
  GIT_TAG 48f3c13764dc2ec662ade842a4696a90a137f1bc
  GIT_SHALLOW OFF
)
FetchContent_Populate(repo-fast-hadamard-transform)

# ccache option
option(ENABLE_CCACHE "Whether to use ccache" ON)
find_program(CCACHE_FOUND ccache)
@@ -138,6 +148,7 @@ include_directories(
  ${repo-flashinfer_SOURCE_DIR}/include
  ${repo-flashinfer_SOURCE_DIR}/csrc
  ${repo-mscclpp_SOURCE_DIR}/include
  ${repo-fast-hadamard-transform_SOURCE_DIR}/csrc
)

set(SGL_KERNEL_CUDA_FLAGS
@@ -329,6 +340,9 @@ set(SOURCES
  "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
  "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
  "${repo-fast-hadamard-transform_SOURCE_DIR}/csrc/fast_hadamard_transform_cuda.cu"
  "${repo-fast-hadamard-transform_SOURCE_DIR}/csrc/fast_hadamard_transform.cpp"
  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
......
@@ -540,6 +540,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
      "stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets, Tensor workspace) -> "
      "()");
  m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);

  /*
   * From fast-hadamard-transform
   */
  m.def("fast_hadamard_transform(Tensor x, float scale) -> Tensor");
  m.impl("fast_hadamard_transform", torch::kCUDA, &fast_hadamard_transform);
  m.def("fast_hadamard_transform_12N(Tensor x, float scale) -> Tensor");
  m.impl("fast_hadamard_transform_12N", torch::kCUDA, &fast_hadamard_transform_12N);
  m.def("fast_hadamard_transform_20N(Tensor x, float scale) -> Tensor");
  m.impl("fast_hadamard_transform_20N", torch::kCUDA, &fast_hadamard_transform_20N);
  m.def("fast_hadamard_transform_28N(Tensor x, float scale) -> Tensor");
  m.impl("fast_hadamard_transform_28N", torch::kCUDA, &fast_hadamard_transform_28N);
  m.def("fast_hadamard_transform_40N(Tensor x, float scale) -> Tensor");
  m.impl("fast_hadamard_transform_40N", torch::kCUDA, &fast_hadamard_transform_40N);
}

REGISTER_EXTENSION(common_ops)
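Once common_ops is loaded, the schemas registered above are reachable through the torch.ops.sgl_kernel namespace; the Python wrappers added later in this diff are thin passthroughs to these entry points. A hedged sketch of invoking the raw op directly, assuming that importing the sgl_kernel package loads the extension (which is what the wrappers rely on):

import torch
import sgl_kernel  # noqa: F401  # importing the package loads common_ops and registers the ops

x = torch.randn(4, 512, device="cuda", dtype=torch.bfloat16)
# Matches the schema registered above: fast_hadamard_transform(Tensor x, float scale) -> Tensor
y = torch.ops.sgl_kernel.fast_hadamard_transform.default(x, 512 ** -0.5)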
@@ -837,3 +837,11 @@ void es_fp8_blockwise_scaled_grouped_mm(
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& workspace);

/*
 * From fast-hadamard-transform
 */
torch::Tensor fast_hadamard_transform(torch::Tensor& x, double scale);
torch::Tensor fast_hadamard_transform_12N(torch::Tensor& x, double scale);
torch::Tensor fast_hadamard_transform_20N(torch::Tensor& x, double scale);
torch::Tensor fast_hadamard_transform_28N(torch::Tensor& x, double scale);
torch::Tensor fast_hadamard_transform_40N(torch::Tensor& x, double scale);
@@ -270,6 +270,13 @@ from sgl_kernel.gemm import (
    silu_and_mul_scaled_fp4_grouped_quant,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.hadamard import (
    hadamard_transform,
    hadamard_transform_12n,
    hadamard_transform_20n,
    hadamard_transform_28n,
    hadamard_transform_40n,
)
from sgl_kernel.kvcacheio import (
    transfer_kv_all_layer,
    transfer_kv_all_layer_mla,
......
import torch


def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    return torch.ops.sgl_kernel.fast_hadamard_transform.default(x, scale)


def hadamard_transform_12n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    return torch.ops.sgl_kernel.fast_hadamard_transform_12N.default(x, scale)


def hadamard_transform_20n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    return torch.ops.sgl_kernel.fast_hadamard_transform_20N.default(x, scale)


def hadamard_transform_28n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    return torch.ops.sgl_kernel.fast_hadamard_transform_28N.default(x, scale)


def hadamard_transform_40n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    return torch.ops.sgl_kernel.fast_hadamard_transform_40N.default(x, scale)
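These wrappers forward scale straight to the kernel, which applies it after the unnormalized transform. With scale = 1/sqrt(dim) the transform is orthonormal, and since a Sylvester Hadamard matrix is symmetric with H·H = dim·I, applying it twice should approximately recover the input. A small round-trip sketch, not part of the commit, assuming a CUDA device and a power-of-two dim:

import math

import torch
from sgl_kernel import hadamard_transform

x = torch.randn(2, 256, device="cuda", dtype=torch.float32)
scale = 1.0 / math.sqrt(x.shape[-1])
# Two orthonormal transforms compose to the identity (up to floating-point error).
roundtrip = hadamard_transform(hadamard_transform(x, scale=scale), scale=scale)
torch.testing.assert_close(roundtrip, x, rtol=1e-3, atol=1e-3)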
import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from scipy.linalg import hadamard

from sgl_kernel import hadamard_transform


def hadamard_transform_ref(x, scale=1.0):
    """
    x: (..., dim)
    out: (..., dim)
    """
    if hadamard is None:
        raise ImportError("Please install scipy")
    x_shape = x.shape
    dim = x.shape[-1]
    x = x.reshape(-1, dim)
    log_dim = math.ceil(math.log2(dim))
    dim_padded = 2**log_dim
    if dim != dim_padded:
        x = F.pad(x, (0, dim_padded - dim))
    out = F.linear(
        x,
        torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device),
    )
    out = out * scale
    return out[..., :dim].reshape(*x_shape)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
    "dim",
    [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 137, 1024, 2048, 4096, 8192, 16384, 32768],
)
def test_fast_hadamard_transform(dim, dtype):
    device = "cuda"
    if dtype == torch.float32:
        rtol, atol = 3e-4, 3e-3
    elif dtype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    else:  # float16
        rtol, atol = 3e-3, 5e-3
    torch.random.manual_seed(0)
    batch_size = 15
    x = torch.randn(batch_size, dim, device=device, dtype=dtype)
    x_ref = x.detach().clone().to(torch.float32)
    x_pt = x.detach().clone()
    scale = 1 / math.sqrt(dim)
    out = hadamard_transform(x, scale=scale)
    out_ref = hadamard_transform_ref(x_ref, scale=scale)
    out_pt = hadamard_transform_ref(x_pt, scale=scale)
    torch.testing.assert_close(
        out_pt.float(),
        out_ref,
        rtol=rtol,
        atol=atol,
        msg="Reference implementations mismatch",
    )
    torch.testing.assert_close(
        out.float(),
        out_ref,
        rtol=rtol,
        atol=atol,
        msg="fast_hadamard_transform output mismatch",
    )


if __name__ == "__main__":
    pytest.main([__file__])
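For intuition on the reference path: scipy.linalg.hadamard(n) returns the ±1 Sylvester Hadamard matrix for power-of-two n, which is symmetric, so dividing by sqrt(n) gives an orthogonal matrix; that is the property the test's scale = 1/sqrt(dim) relies on. A CPU-only illustrative sketch, independent of the CUDA kernel:

import numpy as np
from scipy.linalg import hadamard

H = hadamard(8).astype(np.float64)      # 8x8 matrix with +/-1 entries
Q = H / np.sqrt(8)                      # orthonormal scaling, as used by the test
assert np.allclose(Q @ Q.T, np.eye(8))  # rows are orthonormal
x = np.random.randn(8)
assert np.allclose(Q @ (Q @ x), x)      # H is symmetric, so applying Q twice is the identity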