Unverified Commit 5dd8c644 authored by ssshinigami, committed by GitHub

[Bug fix] Fix Gemma 2 and fix Gemma 3 multimodal with bs > 1 on NPU (#9871)


Co-authored-by: Maksim <makcum888e@mail.ru>
parent ee21817c
@@ -288,16 +288,11 @@ class GemmaRMSNorm(CustomOp):
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        orig_dtype = x.dtype
         if residual is not None:
             x = x + residual
             residual = x
-        x = x.float()
-        variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
-        x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
-        x = x * (1.0 + self.weight.float())
-        x = x.to(orig_dtype)
+        x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
         return x if residual is None else (x, residual)
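For reference, the manual path removed above computes Gemma-style RMSNorm, which scales by (1 + weight) rather than the plain RMSNorm weight multiply; the single fused torch_npu.npu_gemma_rms_norm call is assumed here to be numerically equivalent. A minimal sketch of the reference computation, using standard torch ops:

    import torch

    def gemma_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # Compute in fp32 for numerical stability, then cast back to the input dtype.
        orig_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + eps)
        # Gemma applies (1.0 + weight), not weight alone.
        x = x * (1.0 + weight.float())
        return x.to(orig_dtype)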
@@ -46,10 +46,12 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import dump_to_file, use_intel_amx_backend
+from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
 
 logger = logging.getLogger(__name__)
 
+_is_npu = is_npu()
+
 
 @dataclasses.dataclass
 class LogitsProcessorOutput:
@@ -517,7 +519,12 @@ class LogitsProcessor(nn.Module):
         logits = logits[:, : self.config.vocab_size].float()
 
         if self.final_logit_softcapping:
-            fused_softcap(logits, self.final_logit_softcapping)
+            if not _is_npu:
+                fused_softcap(logits, self.final_logit_softcapping)
+            else:
+                logits = self.final_logit_softcapping * torch.tanh(
+                    logits / self.final_logit_softcapping
+                )
 
         return logits
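Both branches implement the same soft cap, logits = cap * tanh(logits / cap), which smoothly bounds logits to (-cap, cap); fused_softcap does this in a fused kernel, while the NPU fallback uses plain torch ops. A quick sketch of the identity (the input values here are illustrative, not from the PR):

    import torch

    def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
        # Smoothly squash logits into the open interval (-cap, cap).
        return cap * torch.tanh(logits / cap)

    x = torch.tensor([-100.0, 0.0, 100.0])
    print(softcap(x, 30.0))  # values squashed into (-30, 30)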
@@ -550,7 +550,7 @@ class PrefillAdder:
             )
         else:
             # Make sure at least one page is available
-            trunc_len = self.rem_chunk_tokens - self.page_size + 1
+            trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size
             if trunc_len <= 0:
                 return AddReqResult.OTHER
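The old formula could leave trunc_len unaligned to the page size; the new one floors rem_chunk_tokens down to a whole number of pages, and yields 0 (hence AddReqResult.OTHER) when less than one full page remains. A worked check with assumed example values:

    rem_chunk_tokens, page_size = 100, 16
    old = rem_chunk_tokens - page_size + 1           # 85 -> not a page multiple (85 % 16 == 5)
    new = rem_chunk_tokens // page_size * page_size  # 96 -> exactly 6 full pages
    assert new % page_size == 0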