Commit 30620a9a authored by OlivierDehaene's avatar OlivierDehaene

hotfix: mixtral

parent ad9d6288
@@ -464,9 +464,9 @@ class DenseMoE(nn.Module):
 class MixtralLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix, layer_id, config, weights):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
+        prefix = f"{prefix}.layers.{layer_id}"
         self.self_attn = MixtralAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
@@ -525,16 +525,20 @@ class MixtralLayer(nn.Module):
 class MixtralModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
+            prefix=(
+                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
+            ),
+            weights=weights,
         )
         self.layers = nn.ModuleList(
             [
                 MixtralLayer(
+                    "model" if not prefix else f"{prefix}.model",
                     layer_id,
                     config,
                     weights,
@@ -543,7 +547,9 @@ class MixtralModel(torch.nn.Module):
             ]
         )
         self.norm = FastRMSNorm.load(
-            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
+            prefix="model.norm" if not prefix else f"{prefix}.model.norm",
+            weights=weights,
+            eps=config.rms_norm_eps,
         )
         self.head_size = self.layers[0].self_attn.head_size
@@ -593,13 +599,13 @@ class MixtralModel(torch.nn.Module):
 class FlashMixtralForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
-        self.model = MixtralModel(config, weights)
+        self.model = MixtralModel(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head",
+            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
             weights=weights,
         )
         self.max_past = config.sliding_window
...
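The change threads an explicit prefix argument through MixtralLayer, MixtralModel and FlashMixtralForCausalLM, so weight names resolve to the bare "model.*" / "lm_head" keys when no prefix is given, and to "{prefix}.model.*" / "{prefix}.lm_head" when the Mixtral stack is nested under a parent module. Below is a minimal, self-contained sketch of that prefix-threading pattern under the same convention; ToyLayer, ToyModel and the weight dictionaries are hypothetical stand-ins, not the actual text-generation-inference classes.

import torch
from torch import nn


class ToyLayer(nn.Module):
    def __init__(self, prefix: str, layer_id: int, weights: dict):
        super().__init__()
        # The layer only sees the prefix handed down by its parent,
        # e.g. "model.layers.0" or "mixtral.model.layers.0".
        prefix = f"{prefix}.layers.{layer_id}"
        self.w = nn.Parameter(weights[f"{prefix}.w"])


class ToyModel(nn.Module):
    def __init__(self, prefix: str, weights: dict, num_layers: int = 2):
        super().__init__()
        # Same fallback as the hotfix: an empty prefix means the plain
        # "model" namespace, otherwise nest under "{prefix}.model".
        base = "model" if not prefix else f"{prefix}.model"
        self.layers = nn.ModuleList(
            ToyLayer(base, i, weights) for i in range(num_layers)
        )


# Flat checkpoint with the usual top-level names ...
flat = {f"model.layers.{i}.w": torch.zeros(4) for i in range(2)}
ToyModel(prefix="", weights=flat)

# ... and the same weights nested under a parent module's namespace.
nested = {f"mixtral.model.layers.{i}.w": torch.zeros(4) for i in range(2)}
ToyModel(prefix="mixtral", weights=nested)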