Unverified Commit 3289da5b authored by Fan Yin's avatar Fan Yin Committed by GitHub
Browse files

[sgl-kernel] support hadamard (#11663)

parent 868403f6
...@@ -62,7 +62,7 @@ fi ...@@ -62,7 +62,7 @@ fi
$PIP_CMD list $PIP_CMD list
# Install additional dependencies # Install additional dependencies
$PIP_CMD install mooncake-transfer-engine==0.3.6.post1 nvidia-cuda-nvrtc-cu12 py-spy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX $PIP_CMD install mooncake-transfer-engine==0.3.6.post1 nvidia-cuda-nvrtc-cu12 py-spy scipy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX
if [ "$IS_BLACKWELL" != "1" ]; then if [ "$IS_BLACKWELL" != "1" ]; then
# For lmms_evals evaluating MMMU # For lmms_evals evaluating MMMU
......
...@@ -60,6 +60,7 @@ FetchContent_Declare( ...@@ -60,6 +60,7 @@ FetchContent_Declare(
) )
FetchContent_Populate(repo-deepgemm) FetchContent_Populate(repo-deepgemm)
# fmt
FetchContent_Declare( FetchContent_Declare(
repo-fmt repo-fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt GIT_REPOSITORY https://github.com/fmtlib/fmt
...@@ -113,6 +114,15 @@ FetchContent_Declare( ...@@ -113,6 +114,15 @@ FetchContent_Declare(
) )
FetchContent_Populate(repo-mscclpp) FetchContent_Populate(repo-mscclpp)
# fast-hadamard-transform
FetchContent_Declare(
repo-fast-hadamard-transform
GIT_REPOSITORY https://github.com/sgl-project/fast-hadamard-transform.git
GIT_TAG 48f3c13764dc2ec662ade842a4696a90a137f1bc
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-fast-hadamard-transform)
# ccache option # ccache option
option(ENABLE_CCACHE "Whether to use ccache" ON) option(ENABLE_CCACHE "Whether to use ccache" ON)
find_program(CCACHE_FOUND ccache) find_program(CCACHE_FOUND ccache)
...@@ -138,6 +148,7 @@ include_directories( ...@@ -138,6 +148,7 @@ include_directories(
${repo-flashinfer_SOURCE_DIR}/include ${repo-flashinfer_SOURCE_DIR}/include
${repo-flashinfer_SOURCE_DIR}/csrc ${repo-flashinfer_SOURCE_DIR}/csrc
${repo-mscclpp_SOURCE_DIR}/include ${repo-mscclpp_SOURCE_DIR}/include
${repo-fast-hadamard-transform}/csrc
) )
set(SGL_KERNEL_CUDA_FLAGS set(SGL_KERNEL_CUDA_FLAGS
...@@ -329,6 +340,9 @@ set(SOURCES ...@@ -329,6 +340,9 @@ set(SOURCES
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
"${repo-fast-hadamard-transform_SOURCE_DIR}/csrc/fast_hadamard_transform_cuda.cu"
"${repo-fast-hadamard-transform_SOURCE_DIR}/csrc/fast_hadamard_transform.cpp"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu" "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu" "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu" "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
......
...@@ -540,6 +540,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { ...@@ -540,6 +540,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets, Tensor workspace) -> " "stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets, Tensor workspace) -> "
"()"); "()");
m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm); m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);
/*
* From hadamard-transform
*/
m.def("fast_hadamard_transform(Tensor x, float scale) -> Tensor");
m.impl("fast_hadamard_transform", torch::kCUDA, &fast_hadamard_transform);
m.def("fast_hadamard_transform_12N(Tensor x, float scale) -> Tensor");
m.impl("fast_hadamard_transform_12N", torch::kCUDA, &fast_hadamard_transform_12N);
m.def("fast_hadamard_transform_20N(Tensor x, float scale) -> Tensor");
m.impl("fast_hadamard_transform_20N", torch::kCUDA, &fast_hadamard_transform_20N);
m.def("fast_hadamard_transform_28N(Tensor x, float scale) -> Tensor");
m.impl("fast_hadamard_transform_28N", torch::kCUDA, &fast_hadamard_transform_28N);
m.def("fast_hadamard_transform_40N(Tensor x, float scale) -> Tensor");
m.impl("fast_hadamard_transform_40N", torch::kCUDA, &fast_hadamard_transform_40N);
} }
REGISTER_EXTENSION(common_ops) REGISTER_EXTENSION(common_ops)
...@@ -837,3 +837,11 @@ void es_fp8_blockwise_scaled_grouped_mm( ...@@ -837,3 +837,11 @@ void es_fp8_blockwise_scaled_grouped_mm(
const torch::Tensor& problem_sizes, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& expert_offsets,
const torch::Tensor& workspace); const torch::Tensor& workspace);
/*
* From fast-hadamard-transform
*/
torch::Tensor fast_hadamard_transform(torch::Tensor& x, double scale);
torch::Tensor fast_hadamard_transform_12N(torch::Tensor& x, double scale);
torch::Tensor fast_hadamard_transform_20N(torch::Tensor& x, double scale);
torch::Tensor fast_hadamard_transform_28N(torch::Tensor& x, double scale);
torch::Tensor fast_hadamard_transform_40N(torch::Tensor& x, double scale);
...@@ -270,6 +270,13 @@ from sgl_kernel.gemm import ( ...@@ -270,6 +270,13 @@ from sgl_kernel.gemm import (
silu_and_mul_scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant,
) )
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.hadamard import (
hadamard_transform,
hadamard_transform_12n,
hadamard_transform_20n,
hadamard_transform_28n,
hadamard_transform_40n,
)
from sgl_kernel.kvcacheio import ( from sgl_kernel.kvcacheio import (
transfer_kv_all_layer, transfer_kv_all_layer,
transfer_kv_all_layer_mla, transfer_kv_all_layer_mla,
......
import torch
def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
return torch.ops.sgl_kernel.fast_hadamard_transform.default(x, scale)
def hadamard_transform_12n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
return torch.ops.sgl_kernel.fast_hadamard_transform_12N.default(x, scale)
def hadamard_transform_20n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
return torch.ops.sgl_kernel.fast_hadamard_transform_20N.default(x, scale)
def hadamard_transform_28n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
return torch.ops.sgl_kernel.fast_hadamard_transform_28N.default(x, scale)
def hadamard_transform_40n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
return torch.ops.sgl_kernel.fast_hadamard_transform_40N.default(x, scale)
import math
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from scipy.linalg import hadamard
from sgl_kernel import hadamard_transform
def hadamard_transform_ref(x, scale=1.0):
"""
x: (..., dim)
out: (..., dim)
"""
if hadamard is None:
raise ImportError("Please install scipy")
x_shape = x.shape
dim = x.shape[-1]
x = x.reshape(-1, dim)
log_dim = math.ceil(math.log2(dim))
dim_padded = 2**log_dim
if dim != dim_padded:
x = F.pad(x, (0, dim_padded - dim))
out = F.linear(
x,
torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device),
)
out = out * scale
return out[..., :dim].reshape(*x_shape)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
"dim",
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 137, 1024, 2048, 4096, 8192, 16384, 32768],
)
def test_fast_hadamard_transform(dim, dtype):
device = "cuda"
if dtype == torch.float32:
rtol, atol = 3e-4, 3e-3
elif dtype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
else: # float16
rtol, atol = 3e-3, 5e-3
torch.random.manual_seed(0)
batch_size = 15
x = torch.randn(batch_size, dim, device=device, dtype=dtype)
x_ref = x.detach().clone().to(torch.float32)
x_pt = x.detach().clone()
scale = 1 / math.sqrt(dim)
out = hadamard_transform(x, scale=scale)
out_ref = hadamard_transform_ref(x_ref, scale=scale)
out_pt = hadamard_transform_ref(x_pt, scale=scale)
torch.testing.assert_close(
out_pt.float(),
out_ref,
rtol=rtol,
atol=atol,
msg="Reference implementations mismatch",
)
torch.testing.assert_close(
out.float(),
out_ref,
rtol=rtol,
atol=atol,
msg="fast_hadamard_transform output mismatch",
)
if __name__ == "__main__":
pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment