Commit 470dc415 authored by zhuwenwen
Browse files

Merge branch 'v0.11.0-dev-wm-0119' into 'v0.11.0-dev'

[fix]解决gpt oss nn moe权重加载出错 (fix: resolve the weight-loading error for the GPT-OSS NN MoE path)

See merge request dcutoolkit/deeplearing/vllm!372
parents 6216b12d 4d70732e
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Iterable from collections.abc import Iterable
from typing import Optional from typing import Optional
...@@ -253,6 +254,7 @@ class GptOssModel(nn.Module): ...@@ -253,6 +254,7 @@ class GptOssModel(nn.Module):
make_empty_intermediate_tensors_factory( make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], self.config.hidden_size)) ["hidden_states", "residual"], self.config.hidden_size))
self.aux_hidden_state_layers = tuple[int, ...]() self.aux_hidden_state_layers = tuple[int, ...]()
self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embedding(input_ids) return self.embedding(input_ids)
...@@ -524,7 +526,9 @@ class GptOssModel(nn.Module): ...@@ -524,7 +526,9 @@ class GptOssModel(nn.Module):
narrow_weight = weight[:, :, narrow_weight = weight[:, :,
2 * tp_rank_start:2 * tp_rank_end] 2 * tp_rank_start:2 * tp_rank_end]
if not self.use_nn_moe:
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[name] param = params_dict[name]
param.copy_(narrow_weight) param.copy_(narrow_weight)
...@@ -536,7 +540,10 @@ class GptOssModel(nn.Module): ...@@ -536,7 +540,10 @@ class GptOssModel(nn.Module):
narrow_weight = weight[ep_rank_start:ep_rank_end, ...] narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else: else:
narrow_weight = weight[:, tp_rank_start:tp_rank_end, :] narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
if not self.use_nn_moe:
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[name] param = params_dict[name]
param.copy_(narrow_weight) param.copy_(narrow_weight)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment