"examples/vscode:/vscode.git/clone" did not exist on "6c4dbe23eb85e5d1da00ccaf4923a275d8769a7f"
Unverified Commit 93f2c0aa authored by Lukas Geiger, committed by GitHub
Browse files

[Models] Improve iteration over layers (#26425)


Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
parent 4ebc9108
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
"""Inference-only Apertus model compatible with HuggingFace weights.""" """Inference-only Apertus model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -412,7 +413,9 @@ class ApertusModel(nn.Module): ...@@ -412,7 +413,9 @@ class ApertusModel(nn.Module):
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states = [] aux_hidden_states = []
for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]): for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
if idx in self.aux_hidden_state_layers: if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual) aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
"""Inference-only FalconH1 model.""" """Inference-only FalconH1 model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional from typing import Optional
import torch import torch
...@@ -480,8 +481,7 @@ class FalconH1Model(nn.Module): ...@@ -480,8 +481,7 @@ class FalconH1Model(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer( hidden_states = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
import typing import typing
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import regex as re import regex as re
...@@ -672,8 +673,9 @@ class HunYuanModel(nn.Module): ...@@ -672,8 +673,9 @@ class HunYuanModel(nn.Module):
cla_factor = _get_cla_factor(self.config) cla_factor = _get_cla_factor(self.config)
prev_kv_states = None prev_kv_states = None
for i in range(self.start_layer, self.end_layer): for i, layer in enumerate(
layer = self.layers[i] islice(self.layers, self.start_layer, self.end_layer)
):
hidden_states, residual, kv_states = layer( hidden_states, residual, kv_states = layer(
positions, positions,
hidden_states, hidden_states,
...@@ -681,10 +683,7 @@ class HunYuanModel(nn.Module): ...@@ -681,10 +683,7 @@ class HunYuanModel(nn.Module):
prev_kv_states, prev_kv_states,
) )
if ( if getattr(self.config, "use_cla", False) and i % cla_factor == 0:
getattr(self.config, "use_cla", False)
and (i - self.start_layer) % cla_factor == 0
):
prev_kv_states = kv_states prev_kv_states = kv_states
else: else:
prev_kv_states = None prev_kv_states = None
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional from typing import Any, Optional
import torch import torch
...@@ -492,7 +493,7 @@ class Lfm2MoeModel(nn.Module): ...@@ -492,7 +493,7 @@ class Lfm2MoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer : self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
import typing import typing
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -519,8 +520,7 @@ class FlashModel(nn.Module): ...@@ -519,8 +520,7 @@ class FlashModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
"""PyTorch MAMBA model.""" """PyTorch MAMBA model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional from typing import Optional
import torch import torch
...@@ -162,8 +163,7 @@ class MambaModel(nn.Module): ...@@ -162,8 +163,7 @@ class MambaModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, hidden_states=hidden_states, residual=residual positions=positions, hidden_states=hidden_states, residual=residual
) )
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import partial from functools import partial
from itertools import islice
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
import numpy as np import numpy as np
...@@ -1106,11 +1107,9 @@ class Qwen3LLMModel(Qwen3Model): ...@@ -1106,11 +1107,9 @@ class Qwen3LLMModel(Qwen3Model):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer_idx, layer in enumerate( for layer_idx, layer in islice(
self.layers[self.start_layer : self.end_layer] enumerate(self.layers), self.start_layer, self.end_layer
): ):
layer_idx = layer_idx + self.start_layer
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
import typing import typing
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
import torch import torch
...@@ -103,11 +104,9 @@ class Qwen3MoeLLMModel(Qwen3MoeModel): ...@@ -103,11 +104,9 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer_idx, layer in enumerate( for layer_idx, layer in islice(
self.layers[self.start_layer : self.end_layer] enumerate(self.layers), self.start_layer, self.end_layer
): ):
layer_idx = layer_idx + self.start_layer
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment