Unverified Commit a9d18b51 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Bugfix] Fix gpt_oss packed_modules_mapping (#28536)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent edb59a94
......@@ -92,7 +92,7 @@ class OAIAttention(nn.Module):
self.scaling = self.head_dim**-0.5
self.rope_theta = config.rope_theta
self.qkv = QKVParallelLinear(
self.qkv_proj = QKVParallelLinear(
hidden_size=self.hidden_size,
head_size=self.head_dim,
total_num_heads=self.num_attention_heads,
......@@ -129,7 +129,7 @@ class OAIAttention(nn.Module):
def forward(
self, hidden_states: torch.Tensor, positions: torch.Tensor
) -> torch.Tensor:
qkv, _ = self.qkv(hidden_states)
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
v = v.contiguous()
......@@ -606,9 +606,9 @@ class GptOssModel(nn.Module):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv", ".q_proj", "q"),
(".qkv", ".k_proj", "k"),
(".qkv", ".v_proj", "v"),
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
]
tp_rank = get_tensor_model_parallel_rank()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment