Commit 251235c2 authored by maxiao1

Adapt to v0.5.4

parent 1053e1be
@@ -130,21 +130,45 @@ class RMSNorm(CustomOp):
             return output, residual_out
         return rms_norm(x, self.weight.data, self.variance_epsilon)

+    # def forward_hip(
+    #     self,
+    #     x: torch.Tensor,
+    #     residual: Optional[torch.Tensor] = None,
+    # ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    #     if not x.is_contiguous():
+    #         # NOTE: Remove this if aiter kernel supports discontinuous input
+    #         x = x.contiguous()
+    #     if residual is not None:
+    #         if _vllm_version < Version("0.9"):
+    #             fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+    #             return x, residual
+    #         else:
+    #             residual_out = torch.empty_like(x)
+    #             output = torch.empty_like(x)
+    #             fused_add_rms_norm(
+    #                 output,
+    #                 x,
+    #                 residual_out,
+    #                 residual,
+    #                 self.weight.data,
+    #                 self.variance_epsilon,
+    #             )
+    #             return output, residual_out
+    #     out = torch.empty_like(x)
+    #     rms_norm(out, x, self.weight.data, self.variance_epsilon)
+    #     return out
+
     def forward_hip(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    ):
         if not x.is_contiguous():
-            # NOTE: Remove this if aiter kernel supports discontinuous input
             x = x.contiguous()
         if residual is not None:
-            if _vllm_version < Version("0.9"):
-                fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
-                return x, residual
-            else:
-                residual_out = torch.empty_like(x)
+            try:
                 output = torch.empty_like(x)
+                residual_out = torch.empty_like(x)
                 fused_add_rms_norm(
                     output,
                     x,
@@ -154,10 +178,20 @@ class RMSNorm(CustomOp):
                     self.variance_epsilon,
                 )
                 return output, residual_out
+            except TypeError:
+                fused_add_rms_norm(
+                    x,
+                    residual,
+                    self.weight.data,
+                    self.variance_epsilon,
+                )
+                return x, residual
         out = torch.empty_like(x)
         rms_norm(out, x, self.weight.data, self.variance_epsilon)
         return out

     def forward_native(
         self,
         x: torch.Tensor,
...
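The rewritten `forward_hip` above replaces the explicit vLLM version gate with signature probing: it first tries the newer out-of-place aiter call `(output, x, residual_out, residual, weight, eps)` and falls back to the older in-place form `(x, residual, weight, eps)` when that raises `TypeError`. A minimal standalone sketch of the pattern, with `kernel` as a hypothetical stand-in for whichever `fused_add_rms_norm` binding is installed:

```python
from typing import Tuple

import torch


def fused_add_rms_norm_compat(
    kernel,  # hypothetical stand-in for the installed fused_add_rms_norm binding
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Try the newer out-of-place signature, fall back to the older in-place one."""
    try:
        output = torch.empty_like(x)
        residual_out = torch.empty_like(x)
        kernel(output, x, residual_out, residual, weight, eps)
        return output, residual_out
    except TypeError:
        # Older binding: updates x and residual in place.
        kernel(x, residual, weight, eps)
        return x, residual
```

One caveat of this approach is that a `TypeError` raised inside the kernel itself is also treated as "old signature", so the probe is best done once (e.g. at import time) and cached rather than repeated on every call.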
@@ -61,7 +61,7 @@ def inplace_fused_experts(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
@@ -79,6 +79,8 @@ def inplace_fused_experts(
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
 ) -> None:
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
     fused_experts_impl(
         hidden_states,
         w1,
@@ -117,7 +119,7 @@ def inplace_fused_experts_fake(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
@@ -154,7 +156,7 @@ def outplace_fused_experts(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
@@ -173,6 +175,8 @@ def outplace_fused_experts(
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
 ) -> torch.Tensor:
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
     return fused_experts_impl(
         hidden_states,
         w1,
@@ -211,7 +215,7 @@ def outplace_fused_experts_fake(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
@@ -263,6 +267,13 @@ def fused_experts(
     block_shape: Optional[List[int]] = None,
 ):
     topk_weights, topk_ids, _ = topk_output
+    act_id = (
+        0 if (
+            moe_runner_config.activation == 0
+            or (isinstance(moe_runner_config.activation, str)
+                and moe_runner_config.activation.lower() == "silu")
+        ) else 1
+    )
     if moe_runner_config.inplace:
         assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(
@@ -273,7 +284,7 @@ def fused_experts(
             topk_ids,
             b1,
             b2,
-            moe_runner_config.activation,
+            act_id,
             moe_runner_config.apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
@@ -301,7 +312,7 @@ def fused_experts(
             topk_ids,
             b1,
             b2,
-            moe_runner_config.activation,
+            act_id,
             moe_runner_config.apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
@@ -345,7 +356,7 @@ def fused_experts_impl(
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
@@ -364,6 +375,8 @@ def fused_experts_impl(
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
 ):
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
     padded_size = padding_size
     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
         padded_size = 0
...
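Since the `torch.ops.sglang` entry points now receive `activation` as an int (0 = silu, 1 = gelu) while `fused_experts_impl` still works with strings, both sides of the custom-op boundary have to agree on the same encoding. A small sketch of that convention, using hypothetical helper names rather than anything from the commit:

```python
# Hypothetical helpers mirroring the int <-> str convention used above.
_NAME_TO_ID = {"silu": 0, "gelu": 1}
_ID_TO_NAME = {v: k for k, v in _NAME_TO_ID.items()}


def activation_to_id(activation) -> int:
    """Accept an int id or a string name and return the int id.
    Unknown names map to 1 (gelu), matching the diff's `else 1` branch."""
    if isinstance(activation, int):
        return activation
    return _NAME_TO_ID.get(str(activation).lower(), 1)


def activation_to_name(activation) -> str:
    """Accept an int id or a string name and return the string name."""
    if isinstance(activation, str):
        return activation
    return _ID_TO_NAME.get(activation, "gelu")
```

Routing the value through helpers like these in one place would also avoid repeating the `isinstance` check in every wrapper.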
@@ -564,7 +564,7 @@ class ModelRunner:
         ):
             server_args.attention_backend = "fa3"
         elif _is_hip:
-            server_args.attention_backend = "aiter"
+            server_args.attention_backend = "triton"
         elif _is_npu:
             server_args.attention_backend = "ascend"
         else:
@@ -581,7 +581,7 @@ class ModelRunner:
             head_num = self.model_config.get_num_kv_heads(self.tp_size)
             # TODO current aiter only support head number 16 or 128 head number
             if head_num == 128 or head_num == 16:
-                server_args.attention_backend = "aiter"
+                server_args.attention_backend = "triton"
             else:
                 server_args.attention_backend = "triton"
         elif _is_npu:
...
@@ -165,10 +165,10 @@ DINLINE void start_sync(
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __scoped_atomic_store_n(
-        &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
+    __hip_atomic_store(
+        &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
-    while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
+    while (__hip_atomic_load(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT) <
            flag)
       ;
   }
@@ -211,16 +211,16 @@ DINLINE void end_sync(
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __scoped_atomic_store_n(
+    __hip_atomic_store(
         &sg.signals[threadIdx.x]->end[blockIdx.x][rank],
         flag,
         final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
-        __MEMORY_SCOPE_SYSTEM);
+        __HIP_MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
-    while (__scoped_atomic_load_n(
+    while (__hip_atomic_load(
         &self_sg->end[blockIdx.x][threadIdx.x],
         final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
-        __MEMORY_SCOPE_DEVICE) < flag)
+        __HIP_MEMORY_SCOPE_AGENT) < flag)
       ;
   }
   __syncthreads();
...
@@ -21,6 +21,7 @@ limitations under the License.
 #include "utils.h"
+#define WARP_SIZE 64
 #define VEC_SIZE 4
 using Vec = int4;
@@ -45,7 +46,7 @@ __device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffff
   int original = v;
 #pragma unroll
   for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
-    int n = __shfl_up_sync(mask, v, offset);
+    int n = __shfl_up(v, offset);
     if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
   }
   return v - original;
...
@@ -60,7 +60,7 @@ template <typename T>
 __device__ float convert_to_float(T x) {
   if constexpr (std::is_same_v<T, __half>) {
     return __half2float(x);
-  } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+  } else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
     return __bfloat162float(x);
   } else if constexpr (std::is_same_v<T, float>) {
     return x;
@@ -575,8 +575,8 @@ void topk_softmax(
         renormalize,
         stream);
   } else if (dtype == at::ScalarType::BFloat16) {
-    topkGatingSoftmaxKernelLauncher<__nv_bfloat16>(
-        reinterpret_cast<const __nv_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
+    topkGatingSoftmaxKernelLauncher<__hip_bfloat16>(
+        reinterpret_cast<const __hip_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
         topk_weights.data_ptr<float>(),
         topk_indices.data_ptr<int>(),
         softmax_workspace.data_ptr<float>(),
...
@@ -369,25 +369,25 @@ __device__ __forceinline__ dstDtype castFromFloat(float val) {
 #endif

 // add FP8 support
-#ifndef USE_ROCM
-#include <c10/util/Float8_e4m3fn.h>
-using FP8_TYPE = c10::Float8_e4m3fn;
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
-#else // USE_ROCM
-#if HIP_FP8_TYPE_FNUZ
-#include <c10/util/Float8_e4m3fnuz.h>
-using FP8_TYPE = c10::Float8_e4m3fnuz;
-constexpr auto FP8_E4M3_MAX = 224.0f;
-#else
-#if HIP_FP8_TYPE_E4M3
-#include <c10/util/Float8_e4m3fn.h>
-using FP8_TYPE = c10::Float8_e4m3fn;
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
-#else
-#error "fp8 is not supported in this processor (arch < gfx942)."
-#endif // HIP_FP8_TYPE_E4M3
-#endif // HIP_FP8_TYPE_FNUZ
-#endif // USE_ROCM
+// #ifndef USE_ROCM
+// #include <c10/util/Float8_e4m3fn.h>
+// using FP8_TYPE = c10::Float8_e4m3fn;
+// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
+// #else // USE_ROCM
+// #if HIP_FP8_TYPE_FNUZ
+// #include <c10/util/Float8_e4m3fnuz.h>
+// using FP8_TYPE = c10::Float8_e4m3fnuz;
+// constexpr auto FP8_E4M3_MAX = 224.0f;
+// #else
+// #if HIP_FP8_TYPE_E4M3
+// #include <c10/util/Float8_e4m3fn.h>
+// using FP8_TYPE = c10::Float8_e4m3fn;
+// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
+// #else
+// #error "fp8 is not supported in this processor (arch < gfx942)."
+// #endif // HIP_FP8_TYPE_E4M3
+// #endif // HIP_FP8_TYPE_FNUZ
+// #endif // USE_ROCM

 #define FULL_MASK 0xffffffff
...
# Copyright 2025 SGLang Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import platform
import sys
from pathlib import Path

import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

root = Path(__file__).parent.resolve()
arch = platform.machine().lower()


def _get_version():
    with open(root / "pyproject.toml") as f:
        for line in f:
            if line.startswith("version"):
                return line.split("=")[1].strip().strip('"')


operator_namespace = "sgl_kernel"
include_dirs = [
    root / "include",
    root / "include" / "impl",
    root / "csrc",
]

sources = [
    "csrc/allreduce/custom_all_reduce.hip",
    "csrc/allreduce/quick_all_reduce.cu",
    "csrc/common_extension_rocm.cc",
    "csrc/elementwise/activation.cu",
    "csrc/grammar/apply_token_bitmask_inplace_cuda.cu",
    "csrc/moe/moe_align_kernel.cu",
    "csrc/moe/moe_topk_softmax_kernels.cu",
    "csrc/speculative/eagle_utils.cu",
    "csrc/kvcacheio/transfer.cu",
]

cxx_flags = ["-O3", "-w"]
libraries = ["hiprtc", "amdhip64", "c10", "torch", "torch_python"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", f"-L/usr/lib/{arch}-linux-gnu"]

default_target = "gfx942"
amdgpu_target = os.environ.get("AMDGPU_TARGET", default_target)

if torch.cuda.is_available():
    try:
        amdgpu_target = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
    except Exception as e:
        print(f"Warning: Failed to detect GPU properties: {e}")
else:
    print(f"Warning: torch.cuda not available. Using default target: {amdgpu_target}")

if amdgpu_target not in ["gfx942", "gfx950", "gfx936"]:
    print(
        f"Error: Unsupported GPU architecture '{amdgpu_target}'. "
        "Expected one of 'gfx942', 'gfx950', or 'gfx936'."
    )
    sys.exit(1)

fp8_macro = (
    "-DHIP_FP8_TYPE_FNUZ" if amdgpu_target == "gfx942" else "-DHIP_FP8_TYPE_E4M3"
)

hipcc_flags = [
    "-DNDEBUG",
    f"-DOPERATOR_NAMESPACE={operator_namespace}",
    "-O3",
    "-Xcompiler",
    "-fPIC",
    "-std=c++17",
    f"--amdgpu-target={amdgpu_target}",
    "-DENABLE_BF16",
    "-DENABLE_FP8",
    fp8_macro,
]

ext_modules = [
    CUDAExtension(
        name="sgl_kernel.common_ops",
        sources=sources,
        include_dirs=include_dirs,
        extra_compile_args={
            "nvcc": hipcc_flags,
            "cxx": cxx_flags,
        },
        libraries=libraries,
        extra_link_args=extra_link_args,
        py_limited_api=False,
    ),
]

setup(
    name="sgl-kernel",
    version=_get_version(),
    packages=find_packages(where="python"),
    package_dir={"": "python"},
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
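This setup.py resolves the target architecture from the visible GPU when torch.cuda is available, otherwise from the AMDGPU_TARGET environment variable (default gfx942); gfx942 selects the FNUZ FP8 macro, while gfx950/gfx936 select E4M3. A hedged sketch of a GPU-less build that pins the target explicitly; the bdist_wheel invocation below is an assumption, not part of the commit:

```python
# Hypothetical build driver for a machine without a visible GPU.
# AMDGPU_TARGET is the env var read by setup.py above:
# "gfx942" -> -DHIP_FP8_TYPE_FNUZ, "gfx950"/"gfx936" -> -DHIP_FP8_TYPE_E4M3.
import os
import subprocess

os.environ["AMDGPU_TARGET"] = "gfx950"
subprocess.run(["python", "setup.py", "bdist_wheel"], check=True)
```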