Unverified commit 4d542402, authored by Shawn Tan, committed by GitHub
Browse files

[Feature]: Allow for Granite MoE Hybrid models with _only_ shared experts. (#19652)


Signed-off-by: Shawn Tan <shawntan@ibm.com>
parent 3e750697
...@@ -67,13 +67,15 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): ...@@ -67,13 +67,15 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
activation=config.hidden_act, activation=config.hidden_act,
quant_config=quant_config) quant_config=quant_config)
self.block_sparse_moe = GraniteMoeMoE( self.block_sparse_moe = None
num_experts=config.num_local_experts, if getattr(config, "num_local_experts", 0) > 0:
top_k=config.num_experts_per_tok, self.block_sparse_moe = GraniteMoeMoE(
hidden_size=config.hidden_size, num_experts=config.num_local_experts,
intermediate_size=config.intermediate_size, top_k=config.num_experts_per_tok,
quant_config=quant_config, hidden_size=config.hidden_size,
prefix=f"{prefix}.block_sparse_moe") intermediate_size=config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.block_sparse_moe")
self.shared_mlp = None if \ self.shared_mlp = None if \
getattr(config, 'shared_intermediate_size', 0) == 0 \ getattr(config, 'shared_intermediate_size', 0) == 0 \
...@@ -105,13 +107,19 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): ...@@ -105,13 +107,19 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.post_attention_layernorm(hidden_states)
if self.shared_mlp is None: if self.shared_mlp is None:
hidden_states = self.block_sparse_moe(hidden_states) if self.block_sparse_moe is not None:
hidden_states = self.block_sparse_moe(hidden_states)
# else: skip
else: else:
# create a copy since block_sparse_moe modifies in-place # create a copy since block_sparse_moe modifies in-place
moe_hidden_states = hidden_states.clone() if self.block_sparse_moe is not None:
moe_hidden_states = self.block_sparse_moe(moe_hidden_states) moe_hidden_states = hidden_states.clone()
hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
del moe_hidden_states hidden_states = moe_hidden_states + self.shared_mlp(
hidden_states)
del moe_hidden_states
else:
hidden_states = self.shared_mlp(hidden_states)
hidden_states = residual + hidden_states * self.residual_multiplier hidden_states = residual + hidden_states * self.residual_multiplier
return hidden_states, residual return hidden_states, residual
...@@ -137,13 +145,15 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module): ...@@ -137,13 +145,15 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn") prefix=f"{prefix}.self_attn")
self.block_sparse_moe = GraniteMoeMoE( self.block_sparse_moe = None
num_experts=config.num_local_experts, if getattr(config, "num_local_experts", 0) > 0:
top_k=config.num_experts_per_tok, self.block_sparse_moe = GraniteMoeMoE(
hidden_size=config.hidden_size, num_experts=config.num_local_experts,
intermediate_size=config.intermediate_size, top_k=config.num_experts_per_tok,
quant_config=quant_config, hidden_size=config.hidden_size,
prefix=f"{prefix}.block_sparse_moe") intermediate_size=config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.block_sparse_moe")
self.shared_mlp = None if \ self.shared_mlp = None if \
getattr(config, 'shared_intermediate_size', 0) == 0 \ getattr(config, 'shared_intermediate_size', 0) == 0 \
...@@ -178,13 +188,19 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module): ...@@ -178,13 +188,19 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.post_attention_layernorm(hidden_states)
if self.shared_mlp is None: if self.shared_mlp is None:
hidden_states = self.block_sparse_moe(hidden_states) if self.block_sparse_moe is not None:
hidden_states = self.block_sparse_moe(hidden_states)
# else: skip
else: else:
# create a copy since block_sparse_moe modifies in-place # create a copy since block_sparse_moe modifies in-place
moe_hidden_states = hidden_states.clone() if self.block_sparse_moe is not None:
moe_hidden_states = self.block_sparse_moe(moe_hidden_states) moe_hidden_states = hidden_states.clone()
hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
del moe_hidden_states hidden_states = moe_hidden_states + self.shared_mlp(
hidden_states)
del moe_hidden_states
else:
hidden_states = self.shared_mlp(hidden_states)
hidden_states = residual + hidden_states * self.residual_multiplier hidden_states = residual + hidden_states * self.residual_multiplier
return hidden_states, residual return hidden_states, residual
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment