Unverified Commit e694e985 authored by xkszltl's avatar xkszltl Committed by GitHub
Browse files

Fix typo of `Block`. (#28727)

parent 9e8f35fa
......@@ -787,7 +787,7 @@ MIXTRAL_ATTENTION_CLASSES = {
}
class MixtralBLockSparseTop2MLP(nn.Module):
class MixtralBlockSparseTop2MLP(nn.Module):
def __init__(self, config: MixtralConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
......@@ -805,6 +805,14 @@ class MixtralBLockSparseTop2MLP(nn.Module):
return current_hidden_states
class MixtralBLockSparseTop2MLP(MixtralBlockSparseTop2MLP):
    """Deprecated alias of `MixtralBlockSparseTop2MLP`.

    Kept only for backward compatibility with code that imported the old,
    misspelled class name ("BLock"); scheduled for removal in v4.40.
    """

    def __init__(self, *args, **kwargs):
        # Warn once per process about the deprecated name, then defer
        # entirely to the correctly spelled parent class.
        logger.warning_once(
            "MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40."
        )
        super().__init__(*args, **kwargs)
class MixtralSparseMoeBlock(nn.Module):
"""
This implementation is
......@@ -827,7 +835,7 @@ class MixtralSparseMoeBlock(nn.Module):
# gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)])
self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment