Unverified commit 264dc6e7, authored by Yi Zhang and committed by GitHub

[optimize] add two stream norm for qwen3 (#7740)


Co-authored-by: ispobock <ispobaoke@gmail.com>
parent 646cef2e
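
The diff below threads an optional `alt_stream` from the Qwen2 / Qwen2-MoE model constructors down into the Qwen3 and Qwen3-MoE attention modules, and, while a CUDA graph is being captured, runs the per-head q/k RMSNorm on two CUDA streams instead of one. As a rough, self-contained sketch of that overlap pattern (plain PyTorch; the `rms_norm` helper, the `overlapped_qk_norm` name, and the shapes are illustrative, not code from this PR):

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm: x / sqrt(mean(x^2) + eps) * weight (illustrative helper).
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight

head_dim = 128
q_weight = torch.ones(head_dim, device="cuda")
k_weight = torch.ones(head_dim, device="cuda")
alt_stream = torch.cuda.Stream()

def overlapped_qk_norm(q: torch.Tensor, k: torch.Tensor):
    current_stream = torch.cuda.current_stream()
    # Fork: the alternate stream must see all work already queued on the
    # current stream before it touches k.
    alt_stream.wait_stream(current_stream)
    # q-norm runs on the current stream ...
    q_out = rms_norm(q.reshape(-1, head_dim), q_weight).view(q.shape)
    # ... while k-norm runs concurrently on the alternate stream.
    with torch.cuda.stream(alt_stream):
        k_out = rms_norm(k.reshape(-1, head_dim), k_weight).view(k.shape)
    # Join: downstream attention kernels on the current stream must wait
    # for k_out to be ready.
    current_stream.wait_stream(alt_stream)
    return q_out, k_out

q = torch.randn(8, 32 * head_dim, device="cuda")
k = torch.randn(8, 8 * head_dim, device="cuda")
q, k = overlapped_qk_norm(q, k)
```

The per-head norms are small kernels, so the gain comes mostly from hiding one launch/kernel behind the other. The PR only takes this path under CUDA graph capture (see the `get_is_capture_mode()` check below) and keeps the original sequential q-norm/k-norm otherwise.
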
python/sglang/srt/models/qwen2.py

@@ -190,6 +190,7 @@ class Qwen2DecoderLayer(nn.Module):
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -253,6 +254,7 @@ class Qwen2Model(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -280,6 +282,7 @@ class Qwen2Model(nn.Module):
                 config=config,
                 quant_config=quant_config,
                 prefix=prefix,
+                alt_stream=alt_stream,
             ),
             pp_rank=self.pp_group.rank_in_group,
             pp_size=self.pp_group.world_size,
......
python/sglang/srt/models/qwen2_moe.py

@@ -291,6 +291,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -393,6 +394,7 @@ class Qwen2MoeModel(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         decoder_layer_type: type[nn.Module] = Qwen2MoeDecoderLayer,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -418,6 +420,7 @@ class Qwen2MoeModel(nn.Module):
                 config=config,
                 quant_config=quant_config,
                 prefix=prefix,
+                alt_stream=alt_stream,
             ),
             pp_rank=self.pp_group.rank_in_group,
             pp_size=self.pp_group.world_size,
......
python/sglang/srt/models/qwen3.py

@@ -25,15 +25,17 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, is_cuda

 Qwen3Config = None

 logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()


 class Qwen3Attention(nn.Module):
@@ -51,6 +53,7 @@ class Qwen3Attention(nn.Module):
         rms_norm_eps: float = None,
         attention_bias: bool = False,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -119,15 +122,27 @@ class Qwen3Attention(nn.Module):
             layer_id=layer_id,
             prefix=add_prefix("attn", prefix),
         )
+        self.alt_stream = alt_stream

     def _apply_qk_norm(
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        q_by_head = q.reshape(-1, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
+        # overlap qk norm
+        if self.alt_stream is not None and get_is_capture_mode():
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            with torch.cuda.stream(self.alt_stream):
+                k_by_head = k.reshape(-1, self.head_dim)
+                k_by_head = self.k_norm(k_by_head)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            k_by_head = k.reshape(-1, self.head_dim)
+            k_by_head = self.k_norm(k_by_head)
         q = q_by_head.view(q.shape)
-        k_by_head = k.reshape(-1, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         return q, k
@@ -153,6 +168,7 @@ class Qwen3DecoderLayer(nn.Module):
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -173,6 +189,7 @@ class Qwen3DecoderLayer(nn.Module):
             rms_norm_eps=config.rms_norm_eps,
             attention_bias=config.attention_bias,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
         self.mlp = Qwen3MLP(
             hidden_size=self.hidden_size,
@@ -234,11 +251,13 @@ class Qwen3Model(Qwen2Model):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         super().__init__(
             config=config,
             quant_config=quant_config,
             prefix=prefix,
             decoder_layer_type=Qwen3DecoderLayer,
+            alt_stream=alt_stream,
         )
......
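
The alternate stream is created once per model (`torch.cuda.Stream() if _is_cuda else None`) and passed through the decoder layers into each attention module; the overlapped branch is only taken when `get_is_capture_mode()` reports that the CUDA graph runner is capturing, so the fork/join onto `alt_stream` is recorded into the captured graph and re-runs on every replayed decode step. A hedged sketch of how such a two-stream region ends up inside a CUDA graph (plain PyTorch with illustrative names; not SGLang's actual `cuda_graph_runner`):

```python
import torch

alt_stream = torch.cuda.Stream()

def two_stream_norms(q: torch.Tensor, k: torch.Tensor):
    # Same fork/join shape as _apply_qk_norm above, with weightless RMSNorm
    # stand-ins for q_norm / k_norm.
    cur = torch.cuda.current_stream()
    alt_stream.wait_stream(cur)
    q = q * torch.rsqrt(q.pow(2).mean(-1, keepdim=True) + 1e-6)
    with torch.cuda.stream(alt_stream):
        k = k * torch.rsqrt(k.pow(2).mean(-1, keepdim=True) + 1e-6)
    cur.wait_stream(alt_stream)
    return q, k

# Static input buffers; a captured graph replays fixed memory addresses.
static_q = torch.randn(16, 128, device="cuda")
static_k = torch.randn(16, 128, device="cuda")

# Warm-up pass on a side stream before capture, as is commonly done.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    two_stream_norms(static_q, static_k)
torch.cuda.current_stream().wait_stream(s)

# Capture: the wait_stream calls become cross-stream dependencies in the
# graph, so replays keep the q/k norms on separate streams with no extra
# Python-side launch work per step.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out_q, out_k = two_stream_norms(static_q, static_k)

static_q.copy_(torch.randn_like(static_q))  # update inputs in place ...
static_k.copy_(torch.randn_like(static_k))
g.replay()                                  # ... and replay the captured work
```

Outside capture, and on non-CUDA builds where `_is_cuda` is false and `alt_stream` is `None`, the code keeps the original sequential q-norm/k-norm path.
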
python/sglang/srt/models/qwen3_moe.py

@@ -67,6 +67,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
@@ -76,11 +77,12 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
 from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
-from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty
+from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_non_idle_and_non_empty

 Qwen3MoeConfig = None

 logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()


 class Qwen3MoeSparseMoeBlock(nn.Module):
@@ -352,6 +354,7 @@ class Qwen3MoeAttention(nn.Module):
         attention_bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -421,15 +424,27 @@ class Qwen3MoeAttention(nn.Module):
         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.alt_stream = alt_stream

     def _apply_qk_norm(
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        q_by_head = q.reshape(-1, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
+        # overlap qk norm
+        if self.alt_stream is not None and get_is_capture_mode():
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            with torch.cuda.stream(self.alt_stream):
+                k_by_head = k.reshape(-1, self.head_dim)
+                k_by_head = self.k_norm(k_by_head)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            k_by_head = k.reshape(-1, self.head_dim)
+            k_by_head = self.k_norm(k_by_head)
         q = q_by_head.view(q.shape)
-        k_by_head = k.reshape(-1, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         return q, k
@@ -489,6 +504,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -514,6 +530,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             attention_bias=attention_bias,
             quant_config=quant_config,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
         self.layer_id = layer_id
@@ -657,11 +674,13 @@ class Qwen3MoeModel(Qwen2MoeModel):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         super().__init__(
             config=config,
             quant_config=quant_config,
             prefix=prefix,
             decoder_layer_type=Qwen3MoeDecoderLayer,
+            alt_stream=alt_stream,
        )
......
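
Both branches of `_apply_qk_norm` issue the same normalization kernels on the same data; the alternate-stream path only changes how they are scheduled, so its outputs should match the sequential path. A quick, illustrative sanity check one could run (not part of the PR; the weightless `norm` helper below stands in for the model's RMSNorm):

```python
import torch

head_dim = 128
w = torch.ones(head_dim, device="cuda")

def norm(x: torch.Tensor) -> torch.Tensor:
    # Per-head RMSNorm stand-in (unit weights).
    xh = x.reshape(-1, head_dim)
    xh = xh * torch.rsqrt(xh.pow(2).mean(-1, keepdim=True) + 1e-6) * w
    return xh.view(x.shape)

def sequential(q, k):
    return norm(q), norm(k)

def overlapped(q, k, alt_stream):
    cur = torch.cuda.current_stream()
    alt_stream.wait_stream(cur)
    q_out = norm(q)
    with torch.cuda.stream(alt_stream):
        k_out = norm(k)
    cur.wait_stream(alt_stream)
    return q_out, k_out

q = torch.randn(4, 16 * head_dim, device="cuda")
k = torch.randn(4, 4 * head_dim, device="cuda")
q1, k1 = sequential(q, k)
q2, k2 = overlapped(q, k, torch.cuda.Stream())
torch.cuda.synchronize()
# Same kernels, same inputs -- only the stream schedule differs.
assert torch.allclose(q1, q2) and torch.allclose(k1, k2)
```
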