Unverified commit 0b6f535f, authored by Yuan Luo and committed by GitHub
Browse files

[Reland] perf: optimize qwen-vl with symm mem allreduce (#11457)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent c5fe3c0b
...@@ -3,13 +3,13 @@ MiB = 1024 * 1024 ...@@ -3,13 +3,13 @@ MiB = 1024 * 1024
SYMM_MEM_ALL_REDUCE_MAX_SIZES = { SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
9: { 9: {
2: 64 * MiB, # 64 MB 2: 64 * MiB, # 64 MB
4: 32 * MiB, # 32 MB 4: 64 * MiB, # 64 MB
6: 64 * MiB, # 64 MB 6: 128 * MiB, # 128 MB
8: 64 * MiB, # 64 MB 8: 128 * MiB, # 128 MB
}, },
10: { 10: {
2: 64 * MiB, # 64 MB 2: 64 * MiB, # 64 MB
4: 32 * MiB, # 32 MB 4: 64 * MiB, # 64 MB
6: 128 * MiB, # 128 MB 6: 128 * MiB, # 128 MB
8: 128 * MiB, # 128 MB 8: 128 * MiB, # 128 MB
}, },
......
...@@ -615,8 +615,11 @@ class GroupCoordinator: ...@@ -615,8 +615,11 @@ class GroupCoordinator:
def _all_reduce_in_place(self, input_: torch.Tensor) -> None: def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
pynccl_comm = self.pynccl_comm pynccl_comm = self.pynccl_comm
symm_mem_comm = self.symm_mem_comm
if pynccl_comm is not None and not pynccl_comm.disabled: if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.all_reduce(input_) pynccl_comm.all_reduce(input_)
elif symm_mem_comm is not None and not symm_mem_comm.disabled:
symm_mem_comm.all_reduce(input_)
else: else:
torch.distributed.all_reduce(input_, group=self.device_group) torch.distributed.all_reduce(input_, group=self.device_group)
......
...@@ -1008,6 +1008,17 @@ class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding): ...@@ -1008,6 +1008,17 @@ class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
return cache return cache
def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
    """Apply interleaved MRoPE to 3D rotary embeddings.

    Reorganizes the frequency layout from chunked [TTT...HHH...WWW] to
    interleaved [THWTHW...TT], where T, H and W label the values taken from
    ``x[0]``, ``x[1]`` and ``x[2]`` respectively.

    Args:
        x: Tensor indexed along its first dimension by 0, 1 and 2; the last
            dimension is the one being interleaved.
        mrope_section: Section sizes. Entries 1 and 2 bound how far the
            values from ``x[1]`` and ``x[2]`` are written (up to, but not
            including, position ``3 * mrope_section[i]``).

    Returns:
        A new tensor shaped like ``x[0]``. ``x`` itself is not modified.
    """
    # Work on a copy of plane 0 so the caller's tensor stays untouched.
    interleaved = x[0].clone()
    # Plane 1 supplies positions 1, 4, 7, ... and plane 2 supplies positions
    # 2, 5, 8, ...; every position not overwritten keeps its plane-0 value.
    for plane in (1, 2):
        positions = slice(plane, mrope_section[plane] * 3, 3)
        interleaved[..., positions] = x[plane, ..., positions]
    return interleaved
class MRotaryEmbedding(RotaryEmbedding): class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections.""" """Rotary Embedding with Multimodal Sections."""
...@@ -1020,12 +1031,14 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1020,12 +1031,14 @@ class MRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype, dtype: torch.dtype,
mrope_section: Optional[List[int]] = None, mrope_section: Optional[List[int]] = None,
mrope_interleaved: bool = False,
) -> None: ) -> None:
super().__init__( super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
) )
self.mrope_section = mrope_section self.mrope_section = mrope_section
self.mrope_interleaved = mrope_interleaved
if self.mrope_section: if self.mrope_section:
expected_sum = rotary_dim // 2 expected_sum = rotary_dim // 2
actual_sum = sum(self.mrope_section) actual_sum = sum(self.mrope_section)
...@@ -1086,15 +1099,18 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1086,15 +1099,18 @@ class MRotaryEmbedding(RotaryEmbedding):
cos, sin = cos_sin.chunk(2, dim=-1) cos, sin = cos_sin.chunk(2, dim=-1)
if positions.ndim == 2: if positions.ndim == 2:
assert self.mrope_section assert self.mrope_section
if self.mrope_interleaved:
cos = torch.cat( cos = apply_interleaved_rope(cos, self.mrope_section)
[m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], sin = apply_interleaved_rope(sin, self.mrope_section)
dim=-1, else:
) cos = torch.cat(
sin = torch.cat( [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
[m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], dim=-1,
dim=-1, )
) sin = torch.cat(
[m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
dim=-1,
)
query_shape = query.shape query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size) query = query.view(num_tokens, -1, self.head_size)
...@@ -1768,6 +1784,7 @@ def get_rope( ...@@ -1768,6 +1784,7 @@ def get_rope(
is_neox_style, is_neox_style,
dtype, dtype,
mrope_section=rope_scaling["mrope_section"], mrope_section=rope_scaling["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
) )
else: else:
rotary_emb = RotaryEmbedding( rotary_emb = RotaryEmbedding(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment