Unverified commit e22f3a5e authored by ronnie_zheng, committed by GitHub

[Ascend] Optimize Qwen3 on Ascend (#10574)


Co-authored-by: c30031083 <chenxu140@huawei.com>
parent 095093ee
@@ -50,6 +50,7 @@ from sglang.srt.utils import (
    is_hip,
    is_sm90_supported,
    is_sm100_supported,
    prepare_weight_cache,
)

_is_flashinfer_available = is_flashinfer_available()
@@ -275,7 +276,11 @@ class LayerCommunicator:
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        cache=None,
    ):
        if cache is not None:
            self._context.cache = cache
        return self._communicate_with_all_reduce_and_layer_norm_fn(
            hidden_states=hidden_states,
            residual=residual,
@@ -349,6 +354,7 @@ class CommunicateContext:
    attn_tp_size: int
    attn_dp_size: int
    tp_size: int
    cache = None

    def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
        return self.process_group_sizes[a] == self.process_group_sizes[b]
@@ -533,6 +539,8 @@ class CommunicateWithAllReduceAndLayerNormFn:
            )
        else:
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
            if context.cache is not None:
                # Launch an async weight prefetch on the CMO side stream;
                # it overlaps with the layernorm issued below.
                _ = prepare_weight_cache(hidden_states, context.cache)
            hidden_states, residual = layernorm(hidden_states, residual)
        return hidden_states, residual
......
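Taken together, the communicator changes thread the prefetch targets through to the all-reduce path: prepare_mlp stashes the weights on the shared CommunicateContext, and the all-reduce/layer-norm step launches the asynchronous prefetch so it overlaps with the layernorm. A condensed, self-contained sketch of that flow (a hypothetical simplification with stand-ins, not the actual class bodies):

import torch

class _Context:
    cache = None  # weights to prefetch, set per layer

_ctx = _Context()

def _all_reduce(x: torch.Tensor) -> torch.Tensor:
    return x  # stand-in for tensor_model_parallel_all_reduce

def _prefetch(handle: torch.Tensor, cache) -> None:
    pass  # stand-in for prepare_weight_cache (added further down in this PR)

def prepare_mlp(hidden: torch.Tensor, residual: torch.Tensor, cache=None):
    if cache is not None:
        _ctx.cache = cache  # e.g. [gate_up_proj.weight, down_proj.weight]
    hidden = _all_reduce(hidden)
    if _ctx.cache is not None:
        _prefetch(hidden, _ctx.cache)  # async on the CMO stream; overlaps the norm below
    hidden = torch.nn.functional.layer_norm(hidden + residual, hidden.shape[-1:])
    return hidden, residual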
@@ -638,6 +638,7 @@ class NPU_W8A8LinearMethodImpl:
        layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
        layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
        # 29 = ACL_FORMAT_FRACTAL_NZ, the matmul-friendly NPU weight layout.
        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)


class NPU_W8A8LinearMethodMTImpl:
@@ -830,6 +831,7 @@ class NPU_W8A8DynamicLinearMethodImpl:
        layer.weight_scale.data = layer.weight_scale.data.flatten()
        layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
        layer.weight_offset.data = layer.weight_offset.data.flatten()
        # Same NZ format cast as above, for the dynamic-quantization path.
        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)


class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
......
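For reference, format id 29 in torch_npu's npu_format_cast corresponds to ACL_FORMAT_FRACTAL_NZ, a blocked weight layout that Ascend's matmul kernels prefer; casting once at weight-load time avoids repeated layout conversions at run time. A minimal standalone check (a sketch assuming torch_npu on an Ascend device; the random tensor is a stand-in for a real layer weight):

import torch
import torch_npu  # registers the "npu" device and NPU-only ops

w = torch.randn(4096, 4096, dtype=torch.float16).npu()
w_nz = torch_npu.npu_format_cast(w, 29)  # 29 = ACL_FORMAT_FRACTAL_NZ
print(torch_npu.get_npu_format(w_nz))    # expected to print 29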
@@ -179,6 +179,13 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

logger = logging.getLogger(__name__)

if _is_npu:
    import torch_npu

    # Allow NPU-internal tensor formats (e.g. FRACTAL_NZ) to be kept on weights.
    torch.npu.config.allow_internal_format = True
    # Use pre-built binary kernels rather than JIT-compiling operators.
    torch_npu.npu.set_compile_mode(jit_compile=False)


class RankZeroFilter(logging.Filter):
    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
......
@@ -19,8 +19,10 @@ import logging
import threading
from typing import TYPE_CHECKING, Optional, Union

import numpy as np
import torch

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

logger = logging.getLogger(__name__)
......
@@ -30,12 +30,19 @@ from sglang.srt.model_loader.weight_utils import (
)
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.utils import (
    add_prefix,
    get_cmo_stream,
    is_cuda,
    is_npu,
    wait_cmo_stream,
)

Qwen3Config = None

logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
_is_npu = is_npu()


class Qwen3Attention(nn.Module):
@@ -235,9 +242,18 @@ class Qwen3DecoderLayer(nn.Module):

        # Fully Connected
        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states,
            residual,
            forward_batch,
            cache=(
                [self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
                if _is_npu
                else None
            ),
        )
        hidden_states = self.mlp(hidden_states)
        if _is_npu and get_cmo_stream():
            # Make sure the prefetch has finished before the layer is post-processed.
            wait_cmo_stream()

        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )
......
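On Ascend, the layer's MLP weights (gate_up_proj and down_proj) are handed to prepare_mlp as the prefetch targets: the prefetch is issued right after the all-reduce inside prepare_mlp, so it overlaps with the layernorm and, being asynchronous, with the start of the MLP itself; wait_cmo_stream() then joins the side stream after the MLP and before postprocess_layer. On non-NPU devices cache stays None and the forward path is unchanged.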
@@ -517,6 +517,50 @@ def make_layers(
    return modules, start_layer, end_layer


cmo_stream = None


def get_cmo_stream():
    """
    Cache Management Operation (CMO): lazily create a side stream used to
    prefetch matmul weights while other AIV or communication kernels run,
    hiding the weight-loading memory traffic behind compute.
    """
    global cmo_stream
    if cmo_stream is None:
        cmo_stream = torch.get_device_module().Stream()
    return cmo_stream


def prepare_weight_cache(handle, cache):
    import torch_npu

    NPU_PREFETCH_MAX_SIZE_BYTES = (
        1000000000  # 1 GB, large enough to prefetch an entire weight
    )
    stream = get_cmo_stream()
    # Order the prefetch after all pending writes on the current stream.
    stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(stream):
        if isinstance(cache, list):
            for weight in cache:
                torch_npu.npu_prefetch(
                    weight,
                    handle,
                    NPU_PREFETCH_MAX_SIZE_BYTES,
                )
        else:
            torch_npu.npu_prefetch(
                cache,
                handle,
                NPU_PREFETCH_MAX_SIZE_BYTES,
            )


def wait_cmo_stream():
    # Join the CMO stream so later kernels see the fully prefetched weights.
    cur_stream = torch.get_device_module().current_stream()
    cur_stream.wait_stream(get_cmo_stream())
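A minimal usage sketch of these three helpers (hypothetical; assumes an Ascend device with torch_npu installed, and the layer_norm call is a stand-in for any kernel issued on the current stream):

import torch

def overlapped_step(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Start prefetching `weight` on the CMO side stream ...
    prepare_weight_cache(x, weight)
    # ... while the current stream keeps computing; this work hides the prefetch.
    x = torch.nn.functional.layer_norm(x, x.shape[-1:])
    # Join the side stream before the matmul consumes the prefetched weight.
    wait_cmo_stream()
    return torch.matmul(x, weight)

Both synchronizations are device-side: wait_stream records an event on one stream and makes the other wait for it, so the host thread is never blocked.
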
def set_random_seed(seed: int) -> None:
    """Set the random seed for all libraries."""
    random.seed(seed)
......