Unverified Commit 93c6fb12 authored by michael-amd's avatar michael-amd Committed by GitHub
Browse files

Fix: deepseek forward absorb (#5723)


Co-authored-by: default avatarispobock <ispobaoke@163.com>
parent 11e27d09
......@@ -20,9 +20,10 @@ import torch
import torch.nn as nn
from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda
from sglang.srt.utils import is_cuda, is_hip
_is_cuda = is_cuda()
_is_hip = is_hip()
if _is_cuda:
from sgl_kernel import (
......@@ -32,6 +33,8 @@ if _is_cuda:
rmsnorm,
)
if _is_hip:
from vllm._custom_ops import fused_add_rms_norm, rms_norm
logger = logging.getLogger(__name__)
......@@ -46,23 +49,49 @@ class RMSNorm(CustomOp):
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, *args, **kwargs):
if torch.compiler.is_compiling():
return self.forward_native(*args, **kwargs)
if _is_cuda:
return self.forward_cuda(*args, **kwargs)
elif _is_hip:
return self.forward_hip(*args, **kwargs)
else:
return self.forward_native(*args, **kwargs)
def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
return x, residual
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
return out
def forward_hip(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not x.is_contiguous():
# NOTE: Romove this if aiter kernel supports discontinuous input
x = x.contiguous()
if residual is not None:
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
return x, residual
out = torch.empty_like(x)
rms_norm(out, x, self.weight.data, self.variance_epsilon)
return out
def forward_native(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not x.is_contiguous():
x = x.contiguous()
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
......@@ -88,6 +117,14 @@ class GemmaRMSNorm(CustomOp):
self.weight = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, *args, **kwargs):
if torch.compiler.is_compiling():
return self.forward_native(*args, **kwargs)
if _is_cuda:
return self.forward_cuda(*args, **kwargs)
else:
return self.forward_native(*args, **kwargs)
def forward_native(
self,
x: torch.Tensor,
......@@ -139,8 +176,8 @@ class Gemma3RMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
if not _is_cuda:
if not (_is_cuda or _is_hip):
logger.info(
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
"sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
)
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment