Unverified Commit 52fcbbb8 authored by Cheng Wan, committed by GitHub

Revert "perf: optimize qwen-vl with symm mem allreduce" (#11436)

parent af96ca11
......@@ -3,13 +3,13 @@ MiB = 1024 * 1024
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
9: {
2: 64 * MiB, # 64 MB
4: 64 * MiB, # 64 MB
6: 128 * MiB, # 128 MB
8: 128 * MiB, # 128 MB
4: 32 * MiB, # 32 MB
6: 64 * MiB, # 64 MB
8: 64 * MiB, # 64 MB
},
10: {
2: 64 * MiB, # 64 MB
4: 64 * MiB, # 64 MB
4: 32 * MiB, # 32 MB
6: 128 * MiB, # 128 MB
8: 128 * MiB, # 128 MB
},
......
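For orientation: the table above caps the per-message size (in bytes) below which symmetric-memory all-reduce is considered, keyed by what appears to be the CUDA compute-capability major version (9 and 10) and the tensor-parallel world size; this hunk restores the smaller pre-optimization limits. A minimal sketch of how such a threshold could be consulted — the helper name and lookup logic are illustrative, not the repository's API:

import torch

MiB = 1024 * 1024

# Values as restored by this revert: capability major -> world size -> max bytes.
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    9: {2: 64 * MiB, 4: 32 * MiB, 6: 64 * MiB, 8: 64 * MiB},
    10: {2: 64 * MiB, 4: 32 * MiB, 6: 128 * MiB, 8: 128 * MiB},
}

def symm_mem_all_reduce_fits(tensor: torch.Tensor, world_size: int) -> bool:
    # Hypothetical helper: a tensor qualifies only if its byte size is within
    # the limit configured for this GPU generation and world size.
    major, _minor = torch.cuda.get_device_capability()
    limit = SYMM_MEM_ALL_REDUCE_MAX_SIZES.get(major, {}).get(world_size)
    return limit is not None and tensor.numel() * tensor.element_size() <= limit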
......@@ -603,11 +603,8 @@ class GroupCoordinator:
def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
pynccl_comm = self.pynccl_comm
symm_mem_comm = self.symm_mem_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.all_reduce(input_)
elif symm_mem_comm is not None and not symm_mem_comm.disabled:
symm_mem_comm.all_reduce(input_)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
......
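After this revert, _all_reduce_in_place falls straight from the PyNCCL communicator to torch.distributed, with no symmetric-memory branch in between. A self-contained sketch of the resulting fallback chain (the class context is simplified away):

import torch
import torch.distributed as dist

def all_reduce_in_place(input_: torch.Tensor, pynccl_comm, device_group) -> None:
    # Post-revert behavior: prefer the PyNCCL communicator when it is enabled,
    # otherwise fall back to the generic torch.distributed collective.
    if pynccl_comm is not None and not pynccl_comm.disabled:
        pynccl_comm.all_reduce(input_)
    else:
        dist.all_reduce(input_, group=device_group)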
......@@ -1008,17 +1008,6 @@ class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
return cache
def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
interleaved [THTHWHTHW...TT], preserving frequency continuity.
"""
x_t = x[0].clone()
x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
return x_t
class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
......@@ -1031,14 +1020,12 @@ class MRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[List[int]] = None,
mrope_interleaved: bool = False,
) -> None:
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
self.mrope_section = mrope_section
self.mrope_interleaved = mrope_interleaved
if self.mrope_section:
expected_sum = rotary_dim // 2
actual_sum = sum(self.mrope_section)
......@@ -1099,18 +1086,15 @@ class MRotaryEmbedding(RotaryEmbedding):
cos, sin = cos_sin.chunk(2, dim=-1)
if positions.ndim == 2:
assert self.mrope_section
if self.mrope_interleaved:
cos = apply_interleaved_rope(cos, self.mrope_section)
sin = apply_interleaved_rope(sin, self.mrope_section)
else:
cos = torch.cat(
[m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
dim=-1,
)
sin = torch.cat(
[m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
dim=-1,
)
cos = torch.cat(
[m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
dim=-1,
)
sin = torch.cat(
[m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
dim=-1,
)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
......@@ -1789,7 +1773,6 @@ def get_rope(
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
)
else:
rotary_emb = RotaryEmbedding(
......
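The removed apply_interleaved_rope selects temporal/height/width frequencies in a strided, interleaved pattern, whereas the retained path concatenates one contiguous chunk per section. A small toy comparison of the two layouts — the mrope_section values and tensor shapes below are illustrative only:

import torch

def apply_interleaved_rope(x: torch.Tensor, mrope_section: list) -> torch.Tensor:
    # Same logic as the function removed in the hunk above.
    x_t = x[0].clone()
    x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
    x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
    return x_t

# Toy setup: three position streams (temporal=0, height=1, width=2), eight
# frequency slots, sections [4, 2, 2]; each stream is filled with its own index
# so the output shows which stream every slot was taken from.
mrope_section = [4, 2, 2]
dim = sum(mrope_section)
x = torch.stack([torch.full((1, dim), float(i)) for i in range(3)])

interleaved = apply_interleaved_rope(x, mrope_section)
chunked = torch.cat(
    [m[i] for i, m in enumerate(x.split(mrope_section, dim=-1))], dim=-1
)
print(interleaved)  # tensor([[0., 1., 2., 0., 1., 2., 0., 0.]])
print(chunked)      # tensor([[0., 0., 0., 0., 1., 1., 2., 2.]])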
......@@ -1766,11 +1766,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
self.out_cache_loc = None
if isinstance(self.seq_lens_cpu, torch.Tensor):
# CPU tensor
self.seq_lens_sum = int(self.seq_lens_cpu.sum().item())
else:
self.seq_lens_sum = int(np.asarray(self.seq_lens_cpu).sum())
self.seq_lens_sum = self.seq_lens.sum().item()
self.output_ids = self.output_ids[keep_indices_device]
self.return_logprob = any(req.return_logprob for req in self.reqs)
if self.return_logprob:
......
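The removed branch tolerated seq_lens_cpu being either a CPU tensor or a plain array-like, while the restored line sums seq_lens directly and synchronizes through .item(). A sketch of that type guard in isolation, with hypothetical inputs:

from typing import Sequence, Union

import numpy as np
import torch

def sum_seq_lens(seq_lens_cpu: Union[torch.Tensor, Sequence[int]]) -> int:
    # Guard removed by this revert: accept a CPU tensor or any array-like of
    # sequence lengths and return a Python int total.
    if isinstance(seq_lens_cpu, torch.Tensor):
        return int(seq_lens_cpu.sum().item())
    return int(np.asarray(seq_lens_cpu).sum())

# e.g. sum_seq_lens(torch.tensor([3, 5, 7])) == sum_seq_lens([3, 5, 7]) == 15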
......@@ -27,7 +27,6 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
......@@ -89,17 +88,10 @@ class Qwen2MLP(nn.Module):
)
self.act_fn = SiluAndMul()
def forward(
self,
x,
should_allreduce_fusion: bool = False,
):
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(
x,
skip_all_reduce=should_allreduce_fusion,
)
x, _ = self.down_proj(x)
return x
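The should_allreduce_fusion / skip_all_reduce plumbing removed here exists because a row-parallel projection produces only a partial sum per tensor-parallel rank; the final all-reduce can either run immediately or be deferred so it can be fused with a later collective. A conceptual sketch of that trade-off (not the repository's RowParallelLinear):

import torch
import torch.distributed as dist

def row_parallel_matmul(x_shard: torch.Tensor, w_shard: torch.Tensor,
                        group=None, skip_all_reduce: bool = False) -> torch.Tensor:
    # Each tensor-parallel rank multiplies its slice of the input features by its
    # slice of the weight, yielding a partial result that normally needs an
    # all-reduce. Setting skip_all_reduce defers that collective so a caller can
    # fuse it with the next layer's communication.
    partial = x_shard @ w_shard
    if not skip_all_reduce:
        dist.all_reduce(partial, group=group)
    return partial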
......@@ -117,11 +109,9 @@ class Qwen2Attention(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
......@@ -153,8 +143,6 @@ class Qwen2Attention(nn.Module):
self.total_num_kv_heads,
bias=True,
quant_config=quant_config,
tp_rank=tp_rank,
tp_size=tp_size,
prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
......@@ -162,8 +150,6 @@ class Qwen2Attention(nn.Module):
hidden_size,
bias=False,
quant_config=quant_config,
tp_rank=tp_rank,
tp_size=tp_size,
prefix=add_prefix("o_proj", prefix),
)
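The reverted version also passed tp_rank and tp_size explicitly into QKVParallelLinear and RowParallelLinear rather than letting the layers infer them. A hedged sketch of the usual fallback when such arguments are omitted (hypothetical helper, not the repository's implementation):

import torch.distributed as dist

def resolve_tp_coords(tp_rank=None, tp_size=None, group=None):
    # If explicit tensor-parallel coordinates are not supplied, fall back to the
    # rank and world size of the (default or given) process group.
    if tp_rank is None:
        tp_rank = dist.get_rank(group)
    if tp_size is None:
        tp_size = dist.get_world_size(group)
    return tp_rank, tp_size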
......@@ -209,7 +195,6 @@ class Qwen2DecoderLayer(nn.Module):
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
......@@ -231,18 +216,6 @@ class Qwen2DecoderLayer(nn.Module):
dual_chunk_attention_config=dual_chunk_attention_config,
prefix=add_prefix("self_attn", prefix),
)
self.layer_id = layer_id
self.is_layer_sparse = False
is_previous_layer_sparse = False
self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=config.num_hidden_layers,
is_layer_sparse=self.is_layer_sparse,
is_previous_layer_sparse=is_previous_layer_sparse,
)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
......@@ -255,14 +228,6 @@ class Qwen2DecoderLayer(nn.Module):
config.hidden_size, eps=config.rms_norm_eps
)
self.layer_communicator = LayerCommunicator(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
allow_reduce_scatter=True,
is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1),
)
def forward(
self,
positions: torch.Tensor,
......@@ -284,13 +249,7 @@ class Qwen2DecoderLayer(nn.Module):
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
should_allreduce_fusion = (
self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
forward_batch
)
)
hidden_states = self.mlp(hidden_states, should_allreduce_fusion)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
......
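Pulling the reverted decoder-layer pieces together: a LayerCommunicator decided, per forward batch, whether the MLP's all-reduce could be fused with the next layer's collective, and that flag was threaded into the MLP shown earlier. A condensed sketch of that control flow under the assumptions above (norm/residual handling simplified):

def decoder_layer_forward(layer, positions, hidden_states, forward_batch, residual):
    # Condensed sketch of the reverted Qwen2DecoderLayer.forward wiring.
    hidden_states, residual = layer.input_layernorm(hidden_states, residual)
    hidden_states = layer.self_attn(
        positions=positions, hidden_states=hidden_states, forward_batch=forward_batch
    )
    hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
    # Per-batch decision: can this MLP's all-reduce be folded into the next layer's collective?
    should_allreduce_fusion = (
        layer.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(forward_batch)
    )
    hidden_states = layer.mlp(hidden_states, should_allreduce_fusion)
    return hidden_states, residual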