Unverified Commit de533ab2 authored by Lukas Geiger's avatar Lukas Geiger Committed by GitHub
Browse files

[Models] Improve iteration over layers (#19497)


Signed-off-by: default avatarLukas Geiger <lukas.geiger94@gmail.com>
parent 235c9db8
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -398,7 +399,7 @@ class Gemma3Model(nn.Module): ...@@ -398,7 +399,7 @@ class Gemma3Model(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
"""Inference-only GLM-4.5 model compatible with HuggingFace weights.""" """Inference-only GLM-4.5 model compatible with HuggingFace weights."""
import typing import typing
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -440,8 +441,7 @@ class Glm4MoeModel(nn.Module): ...@@ -440,8 +441,7 @@ class Glm4MoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights.""" """Inference-only GPT-2 model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -228,7 +229,7 @@ class GPT2Model(nn.Module): ...@@ -228,7 +229,7 @@ class GPT2Model(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in self.h[self.start_layer:self.end_layer]: for layer in islice(self.h, self.start_layer, self.end_layer):
hidden_states = layer(hidden_states) hidden_states = layer(hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" """Inference-only GPTBigCode model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -246,7 +247,7 @@ class GPTBigCodeModel(nn.Module): ...@@ -246,7 +247,7 @@ class GPTBigCodeModel(nn.Module):
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in self.h[self.start_layer:self.end_layer]: for layer in islice(self.h, self.start_layer, self.end_layer):
hidden_states = layer(hidden_states) hidden_states = layer(hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights.""" """Inference-only GPT-J model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -223,7 +224,7 @@ class GPTJModel(nn.Module): ...@@ -223,7 +224,7 @@ class GPTJModel(nn.Module):
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.get_input_embeddings(input_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in self.h[self.start_layer:self.end_layer]: for layer in islice(self.h, self.start_layer, self.end_layer):
hidden_states = layer(position_ids, hidden_states) hidden_states = layer(position_ids, hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -235,7 +236,7 @@ class GPTNeoXModel(nn.Module): ...@@ -235,7 +236,7 @@ class GPTNeoXModel(nn.Module):
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.get_input_embeddings(input_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(position_ids, hidden_states) hidden_states = layer(position_ids, hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only IBM Granite model compatible with HuggingFace weights.""" """Inference-only IBM Granite model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -316,7 +317,7 @@ class GraniteModel(nn.Module): ...@@ -316,7 +317,7 @@ class GraniteModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states) hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only GraniteMoe model.""" """Inference-only GraniteMoe model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional from typing import Any, Optional
import torch import torch
...@@ -303,7 +304,7 @@ class GraniteMoeModel(nn.Module): ...@@ -303,7 +304,7 @@ class GraniteMoeModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states) hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
......
...@@ -397,8 +397,7 @@ class GraniteMoeHybridModel(nn.Module): ...@@ -397,8 +397,7 @@ class GraniteMoeHybridModel(nn.Module):
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
num_attn = 0 num_attn = 0
for i in range(len(self.layers)): for i, layer in enumerate(self.layers):
layer = self.layers[i]
if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer): if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer):
num_attn += 1 num_attn += 1
......
...@@ -6,6 +6,7 @@ The architecture is the same as granitemoe but with the addition of shared ...@@ -6,6 +6,7 @@ The architecture is the same as granitemoe but with the addition of shared
experts. experts.
""" """
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional from typing import Optional
import torch import torch
...@@ -200,8 +201,7 @@ class GraniteMoeSharedModel(nn.Module): ...@@ -200,8 +201,7 @@ class GraniteMoeSharedModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(positions, hidden_states) hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Grok1 model.""" """Inference-only Grok1 model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -347,8 +348,7 @@ class Grok1Model(nn.Module): ...@@ -347,8 +348,7 @@ class Grok1Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from collections.abc import Iterable from collections.abc import Iterable
from functools import partial from functools import partial
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -297,7 +298,7 @@ class InternLM2Model(nn.Module): ...@@ -297,7 +298,7 @@ class InternLM2Model(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -123,7 +124,7 @@ class InternLM2VEModel(InternLM2Model): ...@@ -123,7 +124,7 @@ class InternLM2VEModel(InternLM2Model):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
import math import math
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -276,7 +277,7 @@ class JAISModel(nn.Module): ...@@ -276,7 +277,7 @@ class JAISModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in self.h[self.start_layer:self.end_layer]: for layer in islice(self.h, self.start_layer, self.end_layer):
hidden_states = layer(hidden_states) hidden_states = layer(hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Jamba model.""" """Inference-only Jamba model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional from typing import Optional
import torch import torch
...@@ -350,7 +351,7 @@ class JambaModel(nn.Module): ...@@ -350,7 +351,7 @@ class JambaModel(nn.Module):
kv_cache_index = 0 kv_cache_index = 0
mamba_cache_index = 0 mamba_cache_index = 0
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
layer_mamba_cache_params = None layer_mamba_cache_params = None
if isinstance(layer, JambaAttentionDecoderLayer): if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache_index += 1 kv_cache_index += 1
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional from typing import Any, Optional
import torch import torch
...@@ -374,7 +375,7 @@ class Lfm2Model(nn.Module): ...@@ -374,7 +375,7 @@ class Lfm2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -383,7 +384,7 @@ class LlamaModel(nn.Module): ...@@ -383,7 +384,7 @@ class LlamaModel(nn.Module):
aux_hidden_states = [] aux_hidden_states = []
for idx, layer in enumerate( for idx, layer in enumerate(
self.layers[self.start_layer:self.end_layer]): islice(self.layers, self.start_layer, self.end_layer)):
if idx in self.aux_hidden_state_layers: if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual) aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
......
...@@ -164,9 +164,7 @@ class Mamba2Model(nn.Module): ...@@ -164,9 +164,7 @@ class Mamba2Model(nn.Module):
# v1 get mamba2_metadata from forward_context # v1 get mamba2_metadata from forward_context
mamba2_metadata = None mamba2_metadata = None
for i in range(len(self.layers)): for i, layer in enumerate(self.layers):
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiMo model compatible with HuggingFace weights.""" """Inference-only MiMo model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -74,7 +75,7 @@ class MiMoModel(Qwen2Model): ...@@ -74,7 +75,7 @@ class MiMoModel(Qwen2Model):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
"""Inference-only MiniCPM model compatible with HuggingFace weights.""" """Inference-only MiniCPM model compatible with HuggingFace weights."""
import math import math
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -414,7 +415,7 @@ class MiniCPMModel(nn.Module): ...@@ -414,7 +415,7 @@ class MiniCPMModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment