Unverified Commit c68b5c63 authored by rongfu.leng's avatar rongfu.leng Committed by GitHub
Browse files

[Misc] Fix OLMoE model layers that can't load when TP > 1 (#18828)


Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
parent fced7569
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights.""" """Inference-only OLMoE model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from functools import partial
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -22,7 +23,10 @@ from transformers import PretrainedConfig ...@@ -22,7 +23,10 @@ from transformers import PretrainedConfig
from vllm.attention import Attention from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.distributed.utils import split_tensor_along_last_dim
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -140,8 +144,11 @@ class OlmoeAttention(nn.Module): ...@@ -140,8 +144,11 @@ class OlmoeAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
) )
self.q_norm = RMSNorm(hidden_size, eps=1e-5) self.tp_size = tp_size
self.k_norm = RMSNorm(hidden_size, eps=1e-5) self.tp_rank = get_tensor_model_parallel_rank()
self.q_norm = RMSNorm(self.total_num_heads * self.head_dim, eps=1e-5)
self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
eps=1e-5)
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
...@@ -165,6 +172,20 @@ class OlmoeAttention(nn.Module): ...@@ -165,6 +172,20 @@ class OlmoeAttention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn") prefix=f"{prefix}.attn")
def _apply_qk_norm(self, q: torch.Tensor,
                   k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply the query/key norms and return the normalized ``(q, k)`` pair.

    With a single tensor-parallel rank the norms are applied directly.
    With more than one rank, ``q`` and ``k`` are first all-gathered across
    the tensor-parallel group, normalized, and then split back so that
    only the partition at index ``self.tp_rank`` is returned.

    Args:
        q: Query projection held by this rank.
        k: Key projection held by this rank.

    Returns:
        The normalized ``(q, k)`` tuple for this rank.
    """
    # Single-rank case: nothing to gather or re-split.
    if self.tp_size <= 1:
        return self.q_norm(q), self.k_norm(k)

    # NOTE(review): the gather presumably concatenates along the last
    # dimension, mirroring the last-dim split below -- confirm against
    # tensor_model_parallel_all_gather's default.
    gathered_q = tensor_model_parallel_all_gather(q.contiguous())
    gathered_k = tensor_model_parallel_all_gather(k.contiguous())

    normed_q = self.q_norm(gathered_q)
    normed_k = self.k_norm(gathered_k)

    # Cut the normalized tensors back into tp_size pieces along the last
    # dimension and keep only this rank's piece.
    q_parts = split_tensor_along_last_dim(normed_q,
                                          num_partitions=self.tp_size)
    local_q = q_parts[self.tp_rank]
    k_parts = split_tensor_along_last_dim(normed_k,
                                          num_partitions=self.tp_size)
    local_k = k_parts[self.tp_rank]
    return local_q, local_k
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -172,7 +193,7 @@ class OlmoeAttention(nn.Module): ...@@ -172,7 +193,7 @@ class OlmoeAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous()) q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment