Commit b8160878 authored by Yongye Zhu's avatar Yongye Zhu Committed by khluu
Browse files

[DSV4] Add silu clamp limit to shared expert (#40950)


Signed-off-by: default avatarYongye Zhu <zyy1102000@gmail.com>
(cherry picked from commit 706a04d34ba64ea23d430d5e50038791aacfae96)
parent 84c276d7
......@@ -11,29 +11,74 @@
namespace vllm {
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first>
bool act_first, bool HAS_CLAMP>
__device__ __forceinline__ scalar_t compute(const scalar_t& x,
const scalar_t& y) {
return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
const scalar_t& y,
const float limit) {
if constexpr (act_first) {
scalar_t gate = x;
scalar_t up = y;
if constexpr (HAS_CLAMP) {
gate = (scalar_t)fminf((float)gate, limit);
up = (scalar_t)fmaxf(fminf((float)up, limit), -limit);
}
return ACT_FN(gate) * up;
} else {
scalar_t gate = x;
scalar_t up = y;
if constexpr (HAS_CLAMP) {
gate = (scalar_t)fmaxf(fminf((float)gate, limit), -limit);
up = (scalar_t)fminf((float)up, limit);
}
return gate * ACT_FN(up);
}
}
template <typename packed_t, packed_t (*PACKED_ACT_FN)(const packed_t&),
bool act_first>
bool act_first, bool HAS_CLAMP>
__device__ __forceinline__ packed_t packed_compute(const packed_t& x,
const packed_t& y) {
return act_first ? packed_mul(PACKED_ACT_FN(x), y)
: packed_mul(x, PACKED_ACT_FN(y));
const packed_t& y,
const float limit) {
if constexpr (act_first) {
packed_t gate = x;
packed_t up = y;
if constexpr (HAS_CLAMP) {
float2 g = cast_to_float2(gate);
float2 u = cast_to_float2(up);
g.x = fminf(g.x, limit);
g.y = fminf(g.y, limit);
u.x = fmaxf(fminf(u.x, limit), -limit);
u.y = fmaxf(fminf(u.y, limit), -limit);
gate = cast_to_packed<packed_t>(g);
up = cast_to_packed<packed_t>(u);
}
return packed_mul(PACKED_ACT_FN(gate), up);
} else {
packed_t gate = x;
packed_t up = y;
if constexpr (HAS_CLAMP) {
float2 g = cast_to_float2(gate);
float2 u = cast_to_float2(up);
g.x = fmaxf(fminf(g.x, limit), -limit);
g.y = fmaxf(fminf(g.y, limit), -limit);
u.x = fminf(u.x, limit);
u.y = fminf(u.y, limit);
gate = cast_to_packed<packed_t>(g);
up = cast_to_packed<packed_t>(u);
}
return packed_mul(gate, PACKED_ACT_FN(up));
}
}
// Activation and gating kernel template.
template <typename scalar_t, typename packed_t,
scalar_t (*ACT_FN)(const scalar_t&),
packed_t (*PACKED_ACT_FN)(const packed_t&), bool act_first,
bool use_vec, bool use_256b = false>
bool use_vec, bool HAS_CLAMP, bool use_256b = false>
__global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
const int d, const float limit) {
const scalar_t* x_ptr = input + blockIdx.x * 2 * d;
const scalar_t* y_ptr = x_ptr + d;
scalar_t* out_ptr = out + blockIdx.x * d;
......@@ -58,8 +103,9 @@ __global__ void act_and_mul_kernel(
}
#pragma unroll
for (int j = 0; j < pvec_t::NUM_ELTS; j++) {
x.elts[j] = packed_compute<packed_t, PACKED_ACT_FN, act_first>(
x.elts[j], y.elts[j]);
x.elts[j] =
packed_compute<packed_t, PACKED_ACT_FN, act_first, HAS_CLAMP>(
x.elts[j], y.elts[j], limit);
}
if constexpr (use_256b) {
st256(x, &out_vec[i]);
......@@ -72,7 +118,8 @@ __global__ void act_and_mul_kernel(
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
out_ptr[idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
out_ptr[idx] =
compute<scalar_t, ACT_FN, act_first, HAS_CLAMP>(x, y, limit);
}
}
}
......@@ -151,8 +198,11 @@ packed_gelu_tanh_kernel(const packed_t& val) {
// Launch activation and gating kernel.
// Use ACT_FIRST (bool) indicating whether to apply the activation function
// first.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \
// first. HAS_CLAMP (bool) enables pre-activation clamping: gate input is
// clamped (max only) and up input is clamped (both sides) before the
// activation function is applied.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST, \
HAS_CLAMP, LIMIT) \
auto dtype = input.scalar_type(); \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
......@@ -177,8 +227,8 @@ packed_gelu_tanh_kernel(const packed_t& val) {
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
ACT_FIRST, true, true><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
ACT_FIRST, true, HAS_CLAMP, true><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT); \
}); \
} else { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
......@@ -186,8 +236,8 @@ packed_gelu_tanh_kernel(const packed_t& val) {
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
ACT_FIRST, true, false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
ACT_FIRST, true, HAS_CLAMP, false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT); \
}); \
} \
} else { \
......@@ -197,8 +247,8 @@ packed_gelu_tanh_kernel(const packed_t& val) {
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
ACT_FIRST, false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
ACT_FIRST, false, HAS_CLAMP><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT); \
}); \
}
......@@ -206,7 +256,14 @@ void silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
true);
true, false, 0.0f);
}
void silu_and_mul_clamp(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
double limit) {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
true, true, (float)limit);
}
void mul_and_silu(torch::Tensor& out, // [..., d]
......@@ -215,21 +272,21 @@ void mul_and_silu(torch::Tensor& out, // [..., d]
// The difference between mul_and_silu and silu_and_mul is that mul_and_silu
// applies the silu to the latter half of the input.
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
false);
false, false, 0.0f);
}
void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, vllm::packed_gelu_kernel,
true);
true, false, 0.0f);
}
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel,
vllm::packed_gelu_tanh_kernel, true);
LAUNCH_ACTIVATION_GATE_KERNEL(
vllm::gelu_tanh_kernel, vllm::packed_gelu_tanh_kernel, true, false, 0.0f);
}
namespace vllm {
......
......@@ -163,6 +163,8 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
void silu_and_mul_clamp(torch::Tensor& out, torch::Tensor& input, double limit);
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);
......
......@@ -106,6 +106,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
// SwiGLU activation with input clamping.
ops.def(
"silu_and_mul_with_clamp(Tensor! result, Tensor input, float limit) "
"-> ()");
ops.impl("silu_and_mul_with_clamp", torch::kCUDA, &silu_and_mul_clamp);
ops.def(
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
......
......@@ -16,6 +16,7 @@ from vllm.model_executor.layers.activation import (
NewGELU,
QuickGELU,
SiluAndMul,
SiluAndMulWithClamp,
SwigluOAIAndMul,
SwigluStepAndMul,
swiglustep_and_mul_triton,
......@@ -116,6 +117,85 @@ def test_act_and_mul(
opcheck(fn, (out, x))
SWIGLU_LIMITS = [3.0, 7.0, 15.0]
@pytest.mark.parametrize("swiglu_limit", SWIGLU_LIMITS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_silu_and_mul_with_clamp(
default_vllm_config,
swiglu_limit: float,
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
"""SiluAndMulWithClamp: cuda kernel must match native reference."""
set_random_seed(seed)
torch.set_default_device(device)
# Use large values to ensure clamping is exercised.
x = torch.randn(num_tokens, 2 * d, dtype=dtype) * swiglu_limit * 2
layer = SiluAndMulWithClamp(swiglu_limit, compile_native=False)
out = layer(x)
ref_out = layer.forward_native(x)
rtol = {
torch.float16: 2e-3,
torch.bfloat16: 2e-2,
torch.float: 1.3e-6,
}
torch.testing.assert_close(
out, ref_out, atol=get_default_atol(out), rtol=rtol[out.dtype]
)
# Verify clamping is actually being applied: the clamped output should
# differ from the unclamped SiluAndMul output when inputs are large.
unclamped_out = SiluAndMul.forward_native(x)
assert not torch.equal(ref_out.float(), unclamped_out.float()), (
"Input was not large enough to exercise the clamp; increase scale"
)
# Verify gate clamping semantics with a controlled scalar case.
# gate=large_val is clamped to limit first, then silu(limit) * 1.0.
x_gate = torch.tensor(
[[swiglu_limit * 20.0, 1.0]], dtype=torch.float32, device=device
)
out_gate = SiluAndMulWithClamp(swiglu_limit, compile_native=False)(x_gate)
expected_gate = torch.nn.functional.silu(
torch.tensor(swiglu_limit, dtype=torch.float32)
).item()
torch.testing.assert_close(
out_gate,
torch.tensor([[expected_gate]], dtype=torch.float32, device=device),
atol=1e-3,
rtol=1e-3,
)
# Verify up clamping semantics: up >> limit gets clamped to limit.
x_up = torch.tensor(
[[1.0, swiglu_limit * 20.0]], dtype=torch.float32, device=device
)
out_up = SiluAndMulWithClamp(swiglu_limit, compile_native=False)(x_up)
silu_1 = torch.nn.functional.silu(torch.tensor(1.0)).item()
torch.testing.assert_close(
out_up,
torch.tensor([[silu_1 * swiglu_limit]], dtype=torch.float32, device=device),
atol=1e-3,
rtol=1e-3,
)
# opcheck
out_buf = torch.empty(x.shape[:-1] + (d,), dtype=dtype, device=device)
opcheck(torch.ops._C.silu_and_mul_with_clamp, (out_buf, x, swiglu_limit))
@pytest.mark.parametrize(
"activation",
[
......
......@@ -151,6 +151,46 @@ class SiluAndMul(CustomOp):
return self.forward_cuda(x)
@CustomOp.register("silu_and_mul_with_clamp")
class SiluAndMulWithClamp(CustomOp):
"""SwiGLU activation with input clamping (used by some MoE shared experts).
Computes:
gate = clamp(x[..., :d], max=swiglu_limit)
up = clamp(x[..., d:], min=-swiglu_limit, max=swiglu_limit)
out = silu(gate) * up
where d = x.shape[-1] // 2.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def __init__(self, swiglu_limit: float, *, compile_native: bool = True):
super().__init__(compile_native=compile_native)
self.swiglu_limit = float(swiglu_limit)
if current_platform.is_cuda_alike() or current_platform.is_xpu():
self.op = torch.ops._C.silu_and_mul_with_clamp
elif current_platform.is_cpu():
self._forward_method = self.forward_native
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
gate = torch.clamp(x[..., :d], max=self.swiglu_limit)
up = torch.clamp(x[..., d:], min=-self.swiglu_limit, max=self.swiglu_limit)
return F.silu(gate) * up
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
self.op(out, x, self.swiglu_limit)
return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_cuda(x)
# --8<-- [start:mul_and_silu]
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
......
......@@ -45,7 +45,7 @@ def _gelu_and_mul(
# Uses static methods or standalone functions to avoid instantiating CustomOp
# classes, which would call get_current_vllm_config() before config is set.
_CPU_MOE_ACT_FN: dict[MoEActivation, Callable[[torch.Tensor], torch.Tensor]] = {
MoEActivation.SILU: SiluAndMul.forward_native,
MoEActivation.SILU: lambda x: SiluAndMul(compile_native=False).forward_native(x),
MoEActivation.SWIGLUOAI: _swigluoai_forward_native,
MoEActivation.GELU: _gelu_and_mul,
}
......
......@@ -17,6 +17,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClamp
from vllm.model_executor.layers.deepseek_v4_attention import (
DeepseekV4Indexer,
DeepseekV4MLAModules,
......@@ -34,7 +35,10 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization import (
QuantizationConfig,
QuantizationMethods,
)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4MoEMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
......@@ -46,7 +50,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.deepseek_v2 import DeepseekV2MLP
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
......@@ -63,6 +66,57 @@ from .utils import (
)
class DeepseekV4MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
swiglu_limit: float | None = None,
quant_config: QuantizationConfig | None = None,
reduce_results: bool = True,
is_sequence_parallel: bool = False,
prefix: str = "",
) -> None:
super().__init__()
# If is_sequence_parallel, the input and output tensors are sharded
# across the ranks within the tp_group. In this case the weights are
# replicated and no collective ops are needed.
# Otherwise we use standard TP with an allreduce at the end.
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
if swiglu_limit is not None:
self.act_fn = SiluAndMulWithClamp(swiglu_limit)
else:
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class DeepseekV4FP8Config(Fp8Config):
"""FP8 config that routes MoE layers to MXFP4 quantization.
......@@ -672,10 +726,11 @@ class DeepseekV4MoE(nn.Module):
else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2MLP(
self.shared_experts = DeepseekV4MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
swiglu_limit=self.swiglu_limit,
quant_config=quant_config,
reduce_results=self.use_mega_moe,
prefix=f"{prefix}.shared_experts",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment