Unverified Commit 853d4eb9 authored by Nicolas Patry

Hotfixing after refactor.

parent fb2f74e2
@@ -355,7 +355,7 @@ class Block(nn.Module):
         self.ln_2 = FastLayerNorm.load(
             prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
         )
-        self.attn = FlashMQAttention(
+        self.self_attn = FlashMQAttention(
             prefix=f"{prefix}.attn",
             config=config,
             weights=weights,
@@ -378,7 +378,7 @@ class Block(nn.Module):
         max_s,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
-        hidden_states = self.attn(
+        hidden_states = self.self_attn(
             hidden_states,
             cu_seqlen_prefill,
             kv_cache,
@@ -412,7 +412,7 @@ class FlashSantacoderModel(nn.Module):
             reduce=False,
         )
 
-        self.h = nn.ModuleList(
+        self.layers = nn.ModuleList(
             [
                 Block(
                     layer_id,
@@ -426,8 +426,8 @@ class FlashSantacoderModel(nn.Module):
             prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
         )
 
-        self.head_size = self.h[0].attn.head_size
-        self.num_heads = self.h[0].attn.num_heads
+        self.head_size = self.layers[0].self_attn.head_size
+        self.num_heads = self.layers[0].self_attn.num_heads
 
     def forward(
         self,
@@ -446,7 +446,7 @@ class FlashSantacoderModel(nn.Module):
             torch.distributed.all_reduce(hidden_states, group=self.process_group)
 
         residual = None
-        for i, layer in enumerate(self.h):
+        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
@@ -467,7 +467,7 @@ class FlashSantacoderForCausalLM(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.transpose = config.architectures[0].startswith("GPT2")
-        self.transformer = FlashSantacoderModel(config, weights)
+        self.model = FlashSantacoderModel(config, weights)
         self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )
@@ -486,7 +486,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(
+        hidden_states = self.model(
             input_ids,
             position_ids,
             cu_seqlen_prefill,
...
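
The renames in this file (attn → self_attn, h → layers, transformer → model) appear to bring FlashSantacoder onto the model.model.layers[i].self_attn attribute path that generic layer-walking code, such as adapter target mapping, conventionally relies on. Below is a minimal sketch of that kind of traversal; collect_attention_layers is an illustrative helper name, not a function in this repository.

import torch.nn as nn


def collect_attention_layers(causal_lm: nn.Module) -> dict:
    # Walk causal_lm.model.layers[i].self_attn, the attribute path that the
    # renamed FlashSantacoder modules now expose. Illustrative helper only.
    mapping = {}
    for i, layer in enumerate(causal_lm.model.layers):
        mapping[f"model.layers.{i}.self_attn"] = layer.self_attn
    return mapping
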
@@ -60,7 +60,7 @@ class Model(ABC):
         self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
             LayerAdapterWeights
         )
-        self.target_to_layer = self.adapter_target_to_layer()
+        self.target_to_layer = None
         self.loaded_adapters = set()
         self.static_adapter_id = adapter_id
@@ -187,6 +187,8 @@ class Model(ABC):
         into model. Otherwise, the adapter weights are applied during the forward
         pass and stored separately from the base model parameters.
         """
+        if self.target_to_layer is None:
+            self.target_to_layer = self.adapter_target_to_layer()
         if adapter_index in self.loaded_adapters:
             # Adapter already loaded
             return
...
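
The model.py hunks defer building target_to_layer from __init__ to the first load_adapter call, a plain lazy-initialization pattern. A self-contained sketch of that pattern follows, under illustrative names (AdapterHost and its methods are not the repository's classes):

from typing import Dict, Optional


class AdapterHost:
    def __init__(self):
        # Building the target->layer map can be costly and may depend on state
        # that is only complete after construction, so defer it.
        self.target_to_layer: Optional[Dict[str, object]] = None
        self.loaded_adapters = set()

    def adapter_target_to_layer(self) -> Dict[str, object]:
        # Placeholder for the real traversal of the model's layers.
        return {}

    def load_adapter(self, adapter_index: int):
        # Compute the mapping lazily, on the first adapter load only.
        if self.target_to_layer is None:
            self.target_to_layer = self.adapter_target_to_layer()
        if adapter_index in self.loaded_adapters:
            return  # adapter already loaded
        # ... actual loading would consult self.target_to_layer here ...
        self.loaded_adapters.add(adapter_index)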