Unverified Commit de533ab2 authored by Lukas Geiger's avatar Lukas Geiger Committed by GitHub
Browse files

[Models] Improve iteration over layers (#19497)


Signed-off-by: default avatarLukas Geiger <lukas.geiger94@gmail.com>
parent 235c9db8
......@@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -398,7 +399,7 @@ class Gemma3Model(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
......@@ -24,6 +24,7 @@
"""Inference-only GLM-4.5 model compatible with HuggingFace weights."""
import typing
from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -440,8 +441,7 @@ class Glm4MoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
......
......@@ -20,6 +20,7 @@
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -228,7 +229,7 @@ class GPT2Model(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for layer in self.h[self.start_layer:self.end_layer]:
for layer in islice(self.h, self.start_layer, self.end_layer):
hidden_states = layer(hidden_states)
if not get_pp_group().is_last_rank:
......
......@@ -21,6 +21,7 @@
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -246,7 +247,7 @@ class GPTBigCodeModel(nn.Module):
else:
hidden_states = intermediate_tensors["hidden_states"]
for layer in self.h[self.start_layer:self.end_layer]:
for layer in islice(self.h, self.start_layer, self.end_layer):
hidden_states = layer(hidden_states)
if not get_pp_group().is_last_rank:
......
......@@ -19,6 +19,7 @@
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -223,7 +224,7 @@ class GPTJModel(nn.Module):
hidden_states = self.get_input_embeddings(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for layer in self.h[self.start_layer:self.end_layer]:
for layer in islice(self.h, self.start_layer, self.end_layer):
hidden_states = layer(position_ids, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
......@@ -336,4 +337,4 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
\ No newline at end of file
return loader.load_weights(weights)
......@@ -19,6 +19,7 @@
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -235,7 +236,7 @@ class GPTNeoXModel(nn.Module):
hidden_states = self.get_input_embeddings(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(position_ids, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
......
......@@ -24,6 +24,7 @@
# limitations under the License.
"""Inference-only IBM Granite model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -316,7 +317,7 @@ class GraniteModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank:
......
......@@ -24,6 +24,7 @@
# limitations under the License.
"""Inference-only GraniteMoe model."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional
import torch
......@@ -303,7 +304,7 @@ class GraniteMoeModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......
......@@ -397,8 +397,7 @@ class GraniteMoeHybridModel(nn.Module):
residual = intermediate_tensors["residual"]
num_attn = 0
for i in range(len(self.layers)):
layer = self.layers[i]
for i, layer in enumerate(self.layers):
if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer):
num_attn += 1
......
......@@ -6,6 +6,7 @@ The architecture is the same as granitemoe but with the addition of shared
experts.
"""
from collections.abc import Iterable
from itertools import islice
from typing import Optional
import torch
......@@ -200,8 +201,7 @@ class GraniteMoeSharedModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......
......@@ -23,6 +23,7 @@
# limitations under the License.
"""Inference-only Grok1 model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -347,8 +348,7 @@ class Grok1Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
......
......@@ -3,6 +3,7 @@
from collections.abc import Iterable
from functools import partial
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -297,7 +298,7 @@ class InternLM2Model(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from itertools import islice
from typing import Optional, Union
import torch
......@@ -123,7 +124,7 @@ class InternLM2VEModel(InternLM2Model):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
......@@ -23,6 +23,7 @@
import math
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -276,7 +277,7 @@ class JAISModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for layer in self.h[self.start_layer:self.end_layer]:
for layer in islice(self.h, self.start_layer, self.end_layer):
hidden_states = layer(hidden_states)
if not get_pp_group().is_last_rank:
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Jamba model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional
import torch
......@@ -350,7 +351,7 @@ class JambaModel(nn.Module):
kv_cache_index = 0
mamba_cache_index = 0
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
layer_mamba_cache_params = None
if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache_index += 1
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional
import torch
......@@ -374,7 +375,7 @@ class Lfm2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
......@@ -554,4 +555,4 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
\ No newline at end of file
return loader.load_weights(weights)
......@@ -24,6 +24,7 @@
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -383,7 +384,7 @@ class LlamaModel(nn.Module):
aux_hidden_states = []
for idx, layer in enumerate(
self.layers[self.start_layer:self.end_layer]):
islice(self.layers, self.start_layer, self.end_layer)):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual)
......
......@@ -164,9 +164,7 @@ class Mamba2Model(nn.Module):
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
for i in range(len(self.layers)):
layer = self.layers[i]
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
......
......@@ -26,6 +26,7 @@
# limitations under the License.
"""Inference-only MiMo model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -74,7 +75,7 @@ class MiMoModel(Qwen2Model):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
......@@ -25,6 +25,7 @@
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -414,7 +415,7 @@ class MiniCPMModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment