Unverified commit e22f3a5e, authored by ronnie_zheng, committed by GitHub

[Ascend]optimize Qwen3 on Ascend (#10574)


Co-authored-by: c30031083 <chenxu140@huawei.com>
parent 095093ee
@@ -50,6 +50,7 @@ from sglang.srt.utils import (
     is_hip,
     is_sm90_supported,
     is_sm100_supported,
+    prepare_weight_cache,
 )

 _is_flashinfer_available = is_flashinfer_available()
@@ -275,7 +276,11 @@ class LayerCommunicator:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
+        cache=None,
     ):
+        if cache is not None:
+            self._context.cache = cache
         return self._communicate_with_all_reduce_and_layer_norm_fn(
             hidden_states=hidden_states,
             residual=residual,
@@ -349,6 +354,7 @@ class CommunicateContext:
     attn_tp_size: int
     attn_dp_size: int
     tp_size: int
+    cache = None

     def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
         return self.process_group_sizes[a] == self.process_group_sizes[b]
@@ -533,6 +539,8 @@ class CommunicateWithAllReduceAndLayerNormFn:
             )
         else:
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+            if context.cache is not None:
+                _ = prepare_weight_cache(hidden_states, context.cache)
         hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual
...
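Taken together, the communicator hunks thread the upcoming MLP's weights through prepare_mlp into the shared CommunicateContext, and the all-reduce branch then launches an asynchronous prefetch of those weights so the memory traffic overlaps with the layernorm that follows. A condensed sketch of that data flow, with stub functions standing in for the real sglang collectives (not the actual API):

import torch

class Ctx:
    cache = None  # stashed by prepare_mlp, consumed right after the all-reduce

def all_reduce(x):  # stub for tensor_model_parallel_all_reduce
    return x

def prepare_weight_cache(handle, cache):  # stub: the real helper prefetches on a side stream
    pass

def prepare_mlp(ctx, hidden, residual, cache=None):
    if cache is not None:
        ctx.cache = cache
    hidden = all_reduce(hidden)
    if ctx.cache is not None:
        prepare_weight_cache(hidden, ctx.cache)  # async; overlaps the layernorm below
    new_residual = hidden + residual
    return torch.nn.functional.layer_norm(new_residual, hidden.shape[-1:]), new_residual

h, r = prepare_mlp(Ctx(), torch.randn(2, 8), torch.randn(2, 8), cache=[torch.randn(8, 16)])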
@@ -638,6 +638,7 @@ class NPU_W8A8LinearMethodImpl:
         layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)

 class NPU_W8A8LinearMethodMTImpl:
@@ -830,6 +831,7 @@ class NPU_W8A8DynamicLinearMethodImpl:
         layer.weight_scale.data = layer.weight_scale.data.flatten()
         layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)

 class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
...
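Both quantized linear implementations now cast the transposed int8 weight into an Ascend-internal layout once at load time. Format id 29 is, to the best of my knowledge, ACL_FORMAT_FRACTAL_NZ, the blocked layout that Ascend matmul kernels consume directly, so the cast avoids a per-call layout conversion. A minimal sketch (requires an Ascend device with torch_npu; the constant name is mine, not sglang's):

import torch
import torch_npu  # Ascend extension; provides npu_format_cast

ACL_FORMAT_FRACTAL_NZ = 29  # assumed meaning of format id 29

weight = torch.randint(-128, 127, (4096, 1024), dtype=torch.int8).npu()
weight_nz = torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)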
@@ -179,6 +179,13 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

 logger = logging.getLogger(__name__)

+if _is_npu:
+    import torch_npu
+
+    torch.npu.config.allow_internal_format = True
+    torch_npu.npu.set_compile_mode(jit_compile=False)
+

 class RankZeroFilter(logging.Filter):
     """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
...
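For context: allow_internal_format = True lets torch_npu keep tensors in device-internal layouts such as FRACTAL_NZ (which the weight casts above produce), and set_compile_mode(jit_compile=False) selects pre-built binary kernels over just-in-time compilation, typically the faster path for inference. A guarded setup sketch that is a no-op on machines without torch_npu (the function name is mine, not sglang's):

import torch

def init_ascend_backend() -> None:
    try:
        import torch_npu
    except ImportError:
        return  # not an Ascend machine; nothing to configure
    # Permit device-internal tensor layouts (e.g. FRACTAL_NZ) across ops.
    torch.npu.config.allow_internal_format = True
    # Prefer pre-compiled binary kernels over JIT compilation.
    torch_npu.npu.set_compile_mode(jit_compile=False)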
@@ -19,8 +19,10 @@ import logging
 import threading
 from typing import TYPE_CHECKING, Optional, Union

+import numpy as np
 import torch

+from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

 logger = logging.getLogger(__name__)
...
@@ -30,12 +30,19 @@ from sglang.srt.model_loader.weight_utils import (
 )
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import add_prefix, is_cuda
+from sglang.srt.utils import (
+    add_prefix,
+    get_cmo_stream,
+    is_cuda,
+    is_npu,
+    wait_cmo_stream,
+)

 Qwen3Config = None

 logger = logging.getLogger(__name__)

 _is_cuda = is_cuda()
+_is_npu = is_npu()

 class Qwen3Attention(nn.Module):
@@ -235,9 +242,18 @@ class Qwen3DecoderLayer(nn.Module):
         # Fully Connected
         hidden_states, residual = self.layer_communicator.prepare_mlp(
-            hidden_states, residual, forward_batch
+            hidden_states,
+            residual,
+            forward_batch,
+            cache=(
+                [self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
+                if _is_npu
+                else None
+            ),
         )
         hidden_states = self.mlp(hidden_states)
+        if _is_npu and get_cmo_stream():
+            wait_cmo_stream()

         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )
...
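The decoder-layer change closes the loop: prepare_mlp receives the gate_up_proj and down_proj weights, so their prefetch is issued while the all-reduce and layernorm run, and wait_cmo_stream() after the MLP keeps the side stream from drifting ahead of the compute stream. The same overlap window can be illustrated with ordinary streams; the sketch below uses CUDA streams as a stand-in (the actual CMO prefetch is Ascend-only, so a dummy read plays its part):

import torch

def decoder_layer_step(h, w_gate_up, w_down, side_stream):
    main = torch.cuda.current_stream()
    side_stream.wait_stream(main)           # order the prefetch after prior work
    with torch.cuda.stream(side_stream):
        _ = w_gate_up.sum() + w_down.sum()  # stand-in for npu_prefetch
    h = torch.relu(h @ w_gate_up)           # "MLP" overlaps the side stream
    h = h @ w_down
    main.wait_stream(side_stream)           # wait_cmo_stream() equivalent
    return h

if torch.cuda.is_available():
    s = torch.cuda.Stream()
    h = torch.randn(8, 64, device="cuda")
    out = decoder_layer_step(
        h,
        torch.randn(64, 256, device="cuda"),
        torch.randn(256, 64, device="cuda"),
        s,
    )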
@@ -517,6 +517,50 @@ def make_layers(
     return modules, start_layer, end_layer

+cmo_stream = None
+
+
+def get_cmo_stream():
+    """
+    Cache Management Operation (CMO).
+    Launch a side stream that prefetches matmul weights while other AIV or
+    communication kernels are running, aiming to overlap the memory-access
+    time.
+    """
+    global cmo_stream
+    if cmo_stream is None:
+        cmo_stream = torch.get_device_module().Stream()
+    return cmo_stream
+
+
+def prepare_weight_cache(handle, cache):
+    import torch_npu
+
+    NPU_PREFETCH_MAX_SIZE_BYTES = (
+        1000000000  # 1 GB, large enough to prefetch an entire weight
+    )
+    stream = get_cmo_stream()
+    stream.wait_stream(torch.npu.current_stream())
+    with torch.npu.stream(stream):
+        if isinstance(cache, list):
+            for weight in cache:
+                torch_npu.npu_prefetch(
+                    weight,
+                    handle,
+                    NPU_PREFETCH_MAX_SIZE_BYTES,
+                )
+        else:
+            torch_npu.npu_prefetch(
+                cache,
+                handle,
+                NPU_PREFETCH_MAX_SIZE_BYTES,
+            )
+
+
+def wait_cmo_stream():
+    cur_stream = torch.get_device_module().current_stream()
+    cur_stream.wait_stream(get_cmo_stream())
+
+
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
...
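A minimal usage sketch of these helpers (assumes an Ascend device with torch_npu installed; the linear layer is a stand-in, not sglang code):

import torch
import torch_npu  # required for npu_prefetch

from sglang.srt.utils import prepare_weight_cache, wait_cmo_stream

x = torch.randn(32, 4096).npu()
linear = torch.nn.Linear(4096, 4096, bias=False).npu()

prepare_weight_cache(x, linear.weight)  # async prefetch on the CMO stream
y = x @ x.t()                           # main-stream work hides the prefetch
wait_cmo_stream()                       # join the side stream
out = linear(x)                         # weight should now be hot in cache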