Commit 93f2c0aa (unverified), authored by Lukas Geiger, committed by GitHub
Browse files

[Models] Improve iteration over layers (#26425)


Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
parent 4ebc9108
......@@ -26,6 +26,7 @@
"""Inference-only Apertus model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -412,7 +413,9 @@ class ApertusModel(nn.Module):
residual = intermediate_tensors["residual"]
aux_hidden_states = []
for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual)
......
......@@ -3,6 +3,7 @@
"""Inference-only FalconH1 model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional
import torch
......@@ -480,8 +481,7 @@ class FalconH1Model(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(
positions=positions,
hidden_states=hidden_states,
......
......@@ -26,6 +26,7 @@
import typing
from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any, Optional, Union
import regex as re
......@@ -672,8 +673,9 @@ class HunYuanModel(nn.Module):
cla_factor = _get_cla_factor(self.config)
prev_kv_states = None
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for i, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
hidden_states, residual, kv_states = layer(
positions,
hidden_states,
......@@ -681,10 +683,7 @@ class HunYuanModel(nn.Module):
prev_kv_states,
)
if (
getattr(self.config, "use_cla", False)
and (i - self.start_layer) % cla_factor == 0
):
if getattr(self.config, "use_cla", False) and i % cla_factor == 0:
prev_kv_states = kv_states
else:
prev_kv_states = None
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional
import torch
......@@ -492,7 +493,7 @@ class Lfm2MoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer : self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
......
......@@ -35,6 +35,7 @@
import typing
from collections.abc import Callable, Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -519,8 +520,7 @@ class FlashModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
......@@ -3,6 +3,7 @@
"""PyTorch MAMBA model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional
import torch
......@@ -162,8 +163,7 @@ class MambaModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions=positions, hidden_states=hidden_states, residual=residual
)
......
......@@ -26,6 +26,7 @@
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from itertools import islice
from typing import Any, Callable, Optional, Union
import numpy as np
......@@ -1106,11 +1107,9 @@ class Qwen3LLMModel(Qwen3Model):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer_idx, layer in enumerate(
self.layers[self.start_layer : self.end_layer]
for layer_idx, layer in islice(
enumerate(self.layers), self.start_layer, self.end_layer
):
layer_idx = layer_idx + self.start_layer
hidden_states, residual = layer(
positions,
hidden_states,
......
......@@ -26,6 +26,7 @@
import typing
from collections.abc import Iterable
from itertools import islice
from typing import Callable, Optional, Union
import torch
......@@ -103,11 +104,9 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer_idx, layer in enumerate(
self.layers[self.start_layer : self.end_layer]
for layer_idx, layer in islice(
enumerate(self.layers), self.start_layer, self.end_layer
):
layer_idx = layer_idx + self.start_layer
hidden_states, residual = layer(
positions,
hidden_states,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment