Unverified Commit 4d542402 authored by Shawn Tan's avatar Shawn Tan Committed by GitHub
Browse files

[Feature]:Allow for Granite MoE Hybrid models with _only_ shared experts. (#19652)


Signed-off-by: default avatarShawn Tan <shawntan@ibm.com>
parent 3e750697
......@@ -67,6 +67,8 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
activation=config.hidden_act,
quant_config=quant_config)
self.block_sparse_moe = None
if getattr(config, "num_local_experts", 0) > 0:
self.block_sparse_moe = GraniteMoeMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
......@@ -105,13 +107,19 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
if self.shared_mlp is None:
if self.block_sparse_moe is not None:
hidden_states = self.block_sparse_moe(hidden_states)
# else: skip
else:
# create a copy since block_sparse_moe modifies in-place
if self.block_sparse_moe is not None:
moe_hidden_states = hidden_states.clone()
moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
hidden_states = moe_hidden_states + self.shared_mlp(
hidden_states)
del moe_hidden_states
else:
hidden_states = self.shared_mlp(hidden_states)
hidden_states = residual + hidden_states * self.residual_multiplier
return hidden_states, residual
......@@ -137,6 +145,8 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.block_sparse_moe = None
if getattr(config, "num_local_experts", 0) > 0:
self.block_sparse_moe = GraniteMoeMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
......@@ -178,13 +188,19 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
if self.shared_mlp is None:
if self.block_sparse_moe is not None:
hidden_states = self.block_sparse_moe(hidden_states)
# else: skip
else:
# create a copy since block_sparse_moe modifies in-place
if self.block_sparse_moe is not None:
moe_hidden_states = hidden_states.clone()
moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
hidden_states = moe_hidden_states + self.shared_mlp(
hidden_states)
del moe_hidden_states
else:
hidden_states = self.shared_mlp(hidden_states)
hidden_states = residual + hidden_states * self.residual_multiplier
return hidden_states, residual
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment