Unverified commit de533ab2, authored by Lukas Geiger, committed by GitHub
Browse files

[Models] Improve iteration over layers (#19497)


Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
parent 235c9db8
......@@ -9,6 +9,7 @@
# activation.
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -243,7 +244,7 @@ class ArceeModel(nn.Module):
aux_hidden_states: list[torch.Tensor] = []
for idx, layer in enumerate(
self.layers[self.start_layer:self.end_layer]):
islice(self.layers, self.start_layer, self.end_layer)):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(
hidden_states +
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Snowflake Arctic model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -403,7 +404,7 @@ class ArcticModel(nn.Module):
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
......
......@@ -22,6 +22,7 @@
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -309,7 +310,7 @@ class BaiChuanModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
......@@ -24,6 +24,7 @@
# limitations under the License.
"""Inference-only BailingMoE model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -359,8 +360,7 @@ class BailingMoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
hidden_states,
position_ids,
......
......@@ -345,8 +345,7 @@ class BambaModel(nn.Module):
residual = None
num_attn = 0
for i in range(len(self.layers)):
layer = self.layers[i]
for i, layer in enumerate(self.layers):
if isinstance(layer, BambaAttentionDecoderLayer):
num_attn += 1
......
......@@ -20,6 +20,7 @@
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -273,7 +274,7 @@ class BloomModel(nn.Module):
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for layer in self.h[self.start_layer:self.end_layer]:
for layer in islice(self.h, self.start_layer, self.end_layer):
hidden_states = layer(position_ids, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
......
......@@ -3,6 +3,7 @@
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from itertools import islice
from typing import Annotated, Any, Literal, Optional, Union
import torch
......@@ -914,7 +915,7 @@ class ChameleonModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
......@@ -5,6 +5,7 @@
"""Inference-only ChatGLM model compatible with THUDM weights."""
import json
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -281,7 +282,7 @@ class GLMTransformer(nn.Module):
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
) -> Union[torch.Tensor, IntermediateTensors]:
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(hidden_states=hidden_states,
position_ids=position_ids)
......
......@@ -23,6 +23,7 @@
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -322,7 +323,7 @@ class CohereModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -359,7 +360,7 @@ class DbrxModel(nn.Module):
else:
assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"]
for block in self.blocks[self.start_layer:self.end_layer]:
for block in islice(self.blocks, self.start_layer, self.end_layer):
hidden_states = block(position_ids, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
......
......@@ -24,6 +24,7 @@
# limitations under the License.
"""Inference-only Deepseek model."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -377,7 +378,7 @@ class DeepseekModel(nn.Module):
else:
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......@@ -483,4 +484,4 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
\ No newline at end of file
return loader.load_weights(weights)
......@@ -25,6 +25,7 @@
"""Inference-only DeepseekV2/DeepseekV3 model."""
import typing
from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -712,7 +713,7 @@ class DeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
......
......@@ -25,6 +25,7 @@
# limitations under the License.
"""Inference-only dots1 model."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -391,7 +392,7 @@ class Dots1Model(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
......@@ -23,6 +23,7 @@
# limitations under the License.
"""Inference-only ErineMoE model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -419,8 +420,7 @@ class Ernie4_5_MoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
......
......@@ -23,6 +23,7 @@
# limitations under the License.
"""Inference-only Erine VL model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -508,8 +509,7 @@ class Ernie4_5_VLMoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual,
visual_token_mask, **kwargs)
......
......@@ -26,6 +26,7 @@
"""Inference-only Exaone model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -371,7 +372,7 @@ class ExaoneModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.h[self.start_layer:self.end_layer]:
for layer in islice(self.h, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
......@@ -22,6 +22,7 @@
"""Inference-only Exaone model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -354,7 +355,7 @@ class Exaone4Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
......@@ -22,6 +22,7 @@
import math
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -389,7 +390,7 @@ class FalconModel(nn.Module):
hidden_states = self.get_input_embeddings(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for layer in self.h[self.start_layer:self.end_layer]:
for layer in islice(self.h, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
......
......@@ -18,6 +18,7 @@
"""Inference-only Gemma model compatible with HuggingFace weights."""
from collections.abc import Iterable
from functools import cache
from itertools import islice
from typing import Optional, Union
import torch
......@@ -308,7 +309,7 @@ class GemmaModel(nn.Module):
else:
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
......@@ -17,6 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -292,7 +293,7 @@ class Gemma2Model(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment