"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "a47e6ffe9366516ea5ca28e27fc87367a869e854"
Unverified commit a9d18b51, authored by Jee Jee Li, committed by GitHub
Browse files

[Bugfix] Fix gpt_oss packed_modules_mapping (#28536)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent edb59a94
...@@ -92,7 +92,7 @@ class OAIAttention(nn.Module): ...@@ -92,7 +92,7 @@ class OAIAttention(nn.Module):
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = config.rope_theta self.rope_theta = config.rope_theta
self.qkv = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
head_size=self.head_dim, head_size=self.head_dim,
total_num_heads=self.num_attention_heads, total_num_heads=self.num_attention_heads,
...@@ -129,7 +129,7 @@ class OAIAttention(nn.Module): ...@@ -129,7 +129,7 @@ class OAIAttention(nn.Module):
def forward( def forward(
self, hidden_states: torch.Tensor, positions: torch.Tensor self, hidden_states: torch.Tensor, positions: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
v = v.contiguous() v = v.contiguous()
...@@ -606,9 +606,9 @@ class GptOssModel(nn.Module): ...@@ -606,9 +606,9 @@ class GptOssModel(nn.Module):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
(".qkv", ".k_proj", "k"), (".qkv_proj", ".k_proj", "k"),
(".qkv", ".v_proj", "v"), (".qkv_proj", ".v_proj", "v"),
] ]
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment