Unverified Commit 4d542402 authored by Shawn Tan's avatar Shawn Tan Committed by GitHub
Browse files

[Feature]:Allow for Granite MoE Hybrid models with _only_ shared experts. (#19652)


Signed-off-by: default avatarShawn Tan <shawntan@ibm.com>
parent 3e750697
...@@ -67,6 +67,8 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): ...@@ -67,6 +67,8 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
activation=config.hidden_act, activation=config.hidden_act,
quant_config=quant_config) quant_config=quant_config)
self.block_sparse_moe = None
if getattr(config, "num_local_experts", 0) > 0:
self.block_sparse_moe = GraniteMoeMoE( self.block_sparse_moe = GraniteMoeMoE(
num_experts=config.num_local_experts, num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
...@@ -105,13 +107,19 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): ...@@ -105,13 +107,19 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.post_attention_layernorm(hidden_states)
if self.shared_mlp is None: if self.shared_mlp is None:
if self.block_sparse_moe is not None:
hidden_states = self.block_sparse_moe(hidden_states) hidden_states = self.block_sparse_moe(hidden_states)
# else: skip
else: else:
# create a copy since block_sparse_moe modifies in-place # create a copy since block_sparse_moe modifies in-place
if self.block_sparse_moe is not None:
moe_hidden_states = hidden_states.clone() moe_hidden_states = hidden_states.clone()
moe_hidden_states = self.block_sparse_moe(moe_hidden_states) moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) hidden_states = moe_hidden_states + self.shared_mlp(
hidden_states)
del moe_hidden_states del moe_hidden_states
else:
hidden_states = self.shared_mlp(hidden_states)
hidden_states = residual + hidden_states * self.residual_multiplier hidden_states = residual + hidden_states * self.residual_multiplier
return hidden_states, residual return hidden_states, residual
...@@ -137,6 +145,8 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module): ...@@ -137,6 +145,8 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn") prefix=f"{prefix}.self_attn")
self.block_sparse_moe = None
if getattr(config, "num_local_experts", 0) > 0:
self.block_sparse_moe = GraniteMoeMoE( self.block_sparse_moe = GraniteMoeMoE(
num_experts=config.num_local_experts, num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
...@@ -178,13 +188,19 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module): ...@@ -178,13 +188,19 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.post_attention_layernorm(hidden_states)
if self.shared_mlp is None: if self.shared_mlp is None:
if self.block_sparse_moe is not None:
hidden_states = self.block_sparse_moe(hidden_states) hidden_states = self.block_sparse_moe(hidden_states)
# else: skip
else: else:
# create a copy since block_sparse_moe modifies in-place # create a copy since block_sparse_moe modifies in-place
if self.block_sparse_moe is not None:
moe_hidden_states = hidden_states.clone() moe_hidden_states = hidden_states.clone()
moe_hidden_states = self.block_sparse_moe(moe_hidden_states) moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) hidden_states = moe_hidden_states + self.shared_mlp(
hidden_states)
del moe_hidden_states del moe_hidden_states
else:
hidden_states = self.shared_mlp(hidden_states)
hidden_states = residual + hidden_states * self.residual_multiplier hidden_states = residual + hidden_states * self.residual_multiplier
return hidden_states, residual return hidden_states, residual
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment