Unverified Commit 66318ffe authored by Jani Monoses, committed by GitHub

Rename layer_idx to layer_id for consistency (#2078)

parent 76619261
@@ -97,7 +97,7 @@ class Gemma2MLP(nn.Module):
 class Gemma2Attention(nn.Module):
     def __init__(
         self,
-        layer_idx: int,
+        layer_id: int,
         config: PretrainedConfig,
         hidden_size: int,
         num_heads: int,
@@ -109,7 +109,7 @@ class Gemma2Attention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.layer_idx = layer_idx
+        self.layer_id = layer_id
         self.config = config
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -156,13 +156,13 @@ class Gemma2Attention(nn.Module):
             dtype=torch.get_default_dtype(),
         )
-        use_sliding_window = layer_idx % 2 == 0 and hasattr(config, "sliding_window")
+        use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            layer_id=layer_idx,
+            layer_id=layer_id,
             logit_cap=self.config.attn_logit_softcapping,
             sliding_window_size=(
                 get_attention_sliding_window_size(config)
@@ -188,7 +188,7 @@ class Gemma2Attention(nn.Module):
 class Gemma2DecoderLayer(nn.Module):
     def __init__(
         self,
-        layer_idx: int,
+        layer_id: int,
         config: PretrainedConfig,
         cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
@@ -196,7 +196,7 @@ class Gemma2DecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = Gemma2Attention(
-            layer_idx=layer_idx,
+            layer_id=layer_id,
             config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -269,8 +269,8 @@ class Gemma2Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
-                for layer_idx in range(config.num_hidden_layers)
+                Gemma2DecoderLayer(layer_id, config, cache_config, quant_config)
+                for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...
@@ -223,8 +223,8 @@ class OlmoModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                OlmoDecoderLayer(config, layer_idx, quant_config)
-                for layer_idx in range(config.num_hidden_layers)
+                OlmoDecoderLayer(config, layer_id, quant_config)
+                for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = nn.LayerNorm(
@@ -250,7 +250,7 @@ class OlmoModel(nn.Module):
         hidden_states = input_embeds
         # Apply blocks one-by-one.
-        for layer_idx, decoder_layer in enumerate(self.layers):
+        for layer_id, decoder_layer in enumerate(self.layers):
             # shape: (batch_size, seq_len, d_model)
             hidden_states = decoder_layer(
                 positions,
...
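
For context on why the rename improves consistency: RadixAttention already receives the index through a layer_id keyword, so keeping the name layer_idx in the callers forced the mismatched layer_id=layer_idx hand-off seen above. The sketch below is not part of the patch; it is a minimal, self-contained illustration using stub classes (stand-ins for sglang's RadixAttention and decoder layers, with simplified signatures) showing the index carrying the same name, layer_id, from the layer list comprehension down to the attention module.

# Minimal sketch, NOT the actual sglang code: stub classes with simplified
# signatures, used only to show the end-to-end "layer_id" naming.
from typing import List


class StubRadixAttention:
    """Stand-in for sglang's RadixAttention, which takes a layer_id kwarg."""

    def __init__(self, num_heads: int, head_dim: int, *, layer_id: int):
        self.layer_id = layer_id


class StubDecoderLayer:
    """Stand-in decoder layer: the index is called layer_id all the way down."""

    def __init__(self, layer_id: int, num_heads: int = 8, head_dim: int = 64):
        self.layer_id = layer_id
        # Even/odd alternation mirrors the sliding-window choice in Gemma2Attention.
        self.use_sliding_window = layer_id % 2 == 0
        # With a single name, the hand-off is simply layer_id=layer_id.
        self.attn = StubRadixAttention(num_heads, head_dim, layer_id=layer_id)


def build_layers(num_hidden_layers: int) -> List[StubDecoderLayer]:
    # Same shape as the nn.ModuleList comprehensions in the diff.
    return [StubDecoderLayer(layer_id) for layer_id in range(num_hidden_layers)]


if __name__ == "__main__":
    layers = build_layers(4)
    assert all(layer.attn.layer_id == i for i, layer in enumerate(layers))
    print([(layer.layer_id, layer.use_sliding_window) for layer in layers])

The behavior is unchanged by the rename; only the variable and parameter names in the model definitions are touched.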