Unverified Commit ba861293 authored by VDV1985, committed by GitHub

[feat]Ascend NPU Gemma-3-12b and Gemma-3-27b support (#8909)

parent c112bcc4
@@ -103,6 +103,15 @@ class GeluAndMul(CustomOp):
             raise RuntimeError("GeluAndMul only support tanh or none")
         return out
 
+    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+        y_npu, gelu_npu = torch_npu.npu_geglu(
+            x,
+            dim=-1,
+            approximate=1 if self.approximate == "tanh" else 0,
+            activate_left=True,
+        )
+        return y_npu
+
 
 class NewGELU(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -137,6 +146,9 @@ class QuickGELU(CustomOp):
         gelu_quick(x, out)
         return out
 
+    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+        return torch_npu.npu_fast_gelu(x)
+
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
......
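For reference, the fused NPU calls above are expected to match the eager definitions of these activations. Below is a minimal plain-PyTorch sketch (not part of the diff), assuming `npu_geglu(..., activate_left=True)` applies the GELU to the first half of the last-dimension split, as `GeluAndMul.forward_native` does, and that `npu_fast_gelu` tracks QuickGELU's `x * sigmoid(1.702 * x)`:

import torch
import torch.nn.functional as F

def geglu_reference(x: torch.Tensor, approximate: str = "tanh") -> torch.Tensor:
    # Split the last dimension into [gate, up] and activate the left half,
    # mirroring what npu_geglu with activate_left=True is assumed to compute.
    gate, up = x.chunk(2, dim=-1)
    return F.gelu(gate, approximate=approximate) * up

def quick_gelu_reference(x: torch.Tensor) -> torch.Tensor:
    # QuickGELU's standard definition; npu_fast_gelu is assumed to approximate it.
    return x * torch.sigmoid(1.702 * x)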
@@ -64,7 +64,7 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
-            self.native_attn = TorchNativeAttnBackend(model_runner)
+        self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
@@ -180,7 +180,7 @@ class AscendAttnBackend(AttentionBackend):
             if self.use_fia:
                 """FIA will support multi-bs in the later version of CANN"""
-                q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
                 attn_output = torch.empty(
                     (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
                     device=q.device,
@@ -208,26 +208,61 @@ class AscendAttnBackend(AttentionBackend):
                 )
             else:
-                query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
-                attn_output = torch.empty(
-                    (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
-                    dtype=query.dtype,
-                    device=query.device,
-                )
-
-                torch_npu._npu_flash_attention_qlens(
-                    query=query,
-                    key_cache=k_cache,
-                    value_cache=v_cache,
-                    mask=self.mask,
-                    block_table=self.forward_metadata.block_tables,
-                    seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
-                    context_lens=self.forward_metadata.seq_lens_cpu_int,
-                    scale_value=layer.scaling,
-                    num_heads=layer.tp_q_head_num,
-                    num_kv_heads=layer.tp_k_head_num,
-                    out=attn_output,
-                )
+                if layer.qk_head_dim <= 128:
+                    query = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+                    attn_output = torch.empty(
+                        (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                        dtype=query.dtype,
+                        device=query.device,
+                    )
+
+                    torch_npu._npu_flash_attention_qlens(
+                        query=query,
+                        key_cache=k_cache,
+                        value_cache=v_cache,
+                        mask=self.mask,
+                        block_table=self.forward_metadata.block_tables,
+                        seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
+                        context_lens=self.forward_metadata.seq_lens_cpu_int,
+                        scale_value=layer.scaling,
+                        num_heads=layer.tp_q_head_num,
+                        num_kv_heads=layer.tp_k_head_num,
+                        out=attn_output,
+                    )
+                else:
+                    if layer.qk_head_dim != layer.v_head_dim:
+                        attn_output = q.new_empty(
+                            (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                        )
+                    else:
+                        attn_output = torch.empty_like(q)
+
+                    use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+                    q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                    o_ = attn_output.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+                    causal = True
+                    if (
+                        layer.is_cross_attention
+                        or layer.attn_type == AttentionType.ENCODER_ONLY
+                    ):
+                        causal = False
+
+                    self.native_attn._run_sdpa_forward_extend(
+                        q_,
+                        o_,
+                        k_cache.view(-1, layer.tp_k_head_num, layer.qk_head_dim),
+                        v_cache.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                        forward_batch.req_to_token_pool.req_to_token,
+                        forward_batch.req_pool_indices,
+                        forward_batch.seq_lens,
+                        forward_batch.extend_prefix_lens,
+                        forward_batch.extend_seq_lens,
+                        scaling=layer.scaling,
+                        enable_gqa=use_gqa,
+                        causal=causal,
+                    )
         else:
             assert (
                 layer.qk_head_dim != layer.v_head_dim
@@ -283,7 +318,7 @@ class AscendAttnBackend(AttentionBackend):
                 v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
                     layer.layer_id
                 ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
-                query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
+                query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
                 if self.forward_metadata.seq_lens_cpu_int is None:
                     actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
                 else:
@@ -439,7 +474,8 @@ class AscendAttnBackend(AttentionBackend):
                     scale=layer.scaling,
                 )
             else:
-                query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                num_tokens = query.shape[0]
                 attn_output = torch.empty(
                     (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
                     dtype=query.dtype,
......
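The new `else` branch in the prefill path falls back to torch-native scaled dot-product attention whenever the fused Ascend kernel does not cover the head layout (qk_head_dim above 128, or differing q/v head dims, as with Gemma-3). Below is a minimal single-request sketch of that fallback, not the backend's actual `_run_sdpa_forward_extend` (which additionally gathers per-request KV from the paged pool); `enable_gqa` assumes PyTorch >= 2.5:

import torch
import torch.nn.functional as F

def sdpa_prefill_reference(
    q: torch.Tensor,      # [num_q_heads, seq_len, head_dim]
    k: torch.Tensor,      # [num_kv_heads, seq_len, head_dim]
    v: torch.Tensor,      # [num_kv_heads, seq_len, v_head_dim]
    scaling: float,
    causal: bool = True,
    enable_gqa: bool = False,
) -> torch.Tensor:
    # Plain SDPA for one request; enable_gqa lets num_q_heads be a multiple of
    # num_kv_heads without manually repeating K/V.
    return F.scaled_dot_product_attention(
        q, k, v, is_causal=causal, scale=scaling, enable_gqa=enable_gqa
    )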
@@ -53,7 +53,7 @@ elif _is_hip:
 
 logger = logging.getLogger(__name__)
 
-if is_npu():
+if _is_npu:
     import torch_npu
@@ -266,23 +266,48 @@ class GemmaRMSNorm(CustomOp):
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
 
+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        if residual is not None:
+            x = x + residual
+            residual = x
+        x = x.float()
+        variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
+        x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
+        x = x * (1.0 + self.weight.float())
+        x = x.to(orig_dtype)
+        return x if residual is None else (x, residual)
+
 
-class Gemma3RMSNorm(nn.Module):
+class Gemma3RMSNorm(CustomOp):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.zeros(dim))
+        # Re-dispatch
 
     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
-    def forward(self, x):
+    def forward_native(self, x):
         output = self._norm(x.float())
         # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402
         output = output * (1.0 + self.weight.float())
         return output.type_as(x)
 
+    def forward_cuda(self, x):
+        return self.forward_native(x)
+
+    def forward_npu(self, x):
+        output, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps)
+        return output
+
     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
......
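Both NPU paths above follow the Gemma convention of normalizing in float32 and scaling by `(1 + weight)` before casting back, which is exactly what `forward_native` computes. A small eager reference (plain PyTorch, useful for checking outputs against `npu_gemma_rms_norm`; the helper name is ours):

import torch

def gemma_rms_norm_reference(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    # Normalize in float32, scale by (1 + weight), then cast back to the input
    # dtype -- the Gemma ordering, as opposed to Llama's cast-then-scale.
    orig_dtype = x.dtype
    x = x.float()
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return (x * (1.0 + weight.float())).to(orig_dtype)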
@@ -1876,7 +1876,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
-def apply_rotary_pos_emb(
+def apply_rotary_pos_emb_native(
     q: torch.Tensor,
     k: torch.Tensor,
     cos: torch.Tensor,
@@ -1899,6 +1899,33 @@ def apply_rotary_pos_emb(
     return q_embed, k_embed
 
+def apply_rotary_pos_emb_npu(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim=1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if q.shape[1] != 128:
+        return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    cos = torch.transpose(cos, 1, 2)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    sin = torch.transpose(sin, 1, 2)
+    q = torch.transpose(q, 1, 2)
+    k = torch.transpose(k, 1, 2)
+    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
+    q_embed = torch.transpose(q_embed, 1, 2)
+    k_embed = torch.transpose(k_embed, 1, 2)
+    return q_embed, k_embed
+
+
+if _is_npu:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_npu
+else:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_native
+
+
 def get_rope_cpu(
     head_size: int,
     rotary_dim: int,
......
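`apply_rotary_pos_emb_npu` only takes the fused `npu_apply_rotary_pos_emb` path for 128-dim heads (after transposing into the layout that kernel expects) and otherwise defers to `apply_rotary_pos_emb_native`. The native path's body is not shown in the hunk; the sketch below is the standard rotate-half formulation it is assumed to implement, matching the `rotate_half` helper visible in the hunk context:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap and negate the two halves of the last dimension: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_reference(q, k, cos, sin, unsqueeze_dim=1):
    # Rotate-half RoPE; cos/sin are broadcast over the head dimension.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed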
@@ -20,9 +20,11 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import flatten_nested_list, print_warning_once
+from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
 from sglang.utils import logger
 
+_is_npu = is_npu()
+
 # NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
 # to ensure consistent logging behavior across the codebase. This prevents issues with log
 # propagation that can cause some log messages (like 'server is fired up') to not appear
@@ -486,6 +488,8 @@ def get_embedding_and_mask(
     if embedding is None:
         return None, None
 
     # 2. Get mask
+    if _is_npu:
+        torch.npu.current_stream().synchronize()
     special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor)
 
     # 3. Adjust embedding length if needed
     embedding = _adjust_embedding_length(embedding, special_multimodal_mask, logger)
......
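The explicit synchronize above is presumably there so that asynchronous NPU work producing the multimodal embedding has finished before the boolean mask is built from `input_ids` on the host side. A device-agnostic sketch of the same guard (the helper name is hypothetical, not part of the diff):

import torch

def wait_for_embedding_ready(is_npu_device: bool) -> None:
    # Hypothetical helper: block the host until pending device work completes,
    # so the multimodal mask is built against fully materialized embeddings.
    if is_npu_device:
        torch.npu.current_stream().synchronize()  # requires torch_npu to be imported
    elif torch.cuda.is_available():
        torch.cuda.current_stream().synchronize()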
@@ -13,7 +13,9 @@ from PIL import Image
 from transformers import BaseImageProcessorFast
 
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
-from sglang.srt.utils import load_audio, load_image, load_video, logger
+from sglang.srt.utils import is_npu, load_audio, load_image, load_video, logger
+
+_is_npu = is_npu()
 
 
 @dataclasses.dataclass
@@ -232,7 +234,7 @@ class BaseMultimodalProcessor(ABC):
             and isinstance(processor.image_processor, BaseImageProcessorFast)
             and not self.server_args.disable_fast_image_processor
         ):
-            kwargs["device"] = "cuda"
+            kwargs["device"] = "cuda" if not _is_npu else "npu"
 
         result = processor.__call__(
             text=[input_text],
             padding=True,
......