Unverified Commit de533ab2 authored by Lukas Geiger's avatar Lukas Geiger Committed by GitHub
Browse files

[Models] Improve iteration over layers (#19497)


Signed-off-by: default avatarLukas Geiger <lukas.geiger94@gmail.com>
parent 235c9db8
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
# activation. # activation.
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -243,7 +244,7 @@ class ArceeModel(nn.Module): ...@@ -243,7 +244,7 @@ class ArceeModel(nn.Module):
aux_hidden_states: list[torch.Tensor] = [] aux_hidden_states: list[torch.Tensor] = []
for idx, layer in enumerate( for idx, layer in enumerate(
self.layers[self.start_layer:self.end_layer]): islice(self.layers, self.start_layer, self.end_layer)):
if idx in self.aux_hidden_state_layers: if idx in self.aux_hidden_state_layers:
aux_hidden_states.append( aux_hidden_states.append(
hidden_states + hidden_states +
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Snowflake Arctic model.""" """Inference-only Snowflake Arctic model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -403,7 +404,7 @@ class ArcticModel(nn.Module): ...@@ -403,7 +404,7 @@ class ArcticModel(nn.Module):
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states) hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
"""Inference-only BaiChuan model compatible with HuggingFace weights.""" """Inference-only BaiChuan model compatible with HuggingFace weights."""
import math import math
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -309,7 +310,7 @@ class BaiChuanModel(nn.Module): ...@@ -309,7 +310,7 @@ class BaiChuanModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only BailingMoE model compatible with HuggingFace weights.""" """Inference-only BailingMoE model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -359,8 +360,7 @@ class BailingMoeModel(nn.Module): ...@@ -359,8 +360,7 @@ class BailingMoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
position_ids, position_ids,
......
...@@ -345,8 +345,7 @@ class BambaModel(nn.Module): ...@@ -345,8 +345,7 @@ class BambaModel(nn.Module):
residual = None residual = None
num_attn = 0 num_attn = 0
for i in range(len(self.layers)): for i, layer in enumerate(self.layers):
layer = self.layers[i]
if isinstance(layer, BambaAttentionDecoderLayer): if isinstance(layer, BambaAttentionDecoderLayer):
num_attn += 1 num_attn += 1
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
"""Inference-only BLOOM model compatible with HuggingFace weights.""" """Inference-only BLOOM model compatible with HuggingFace weights."""
import math import math
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -273,7 +274,7 @@ class BloomModel(nn.Module): ...@@ -273,7 +274,7 @@ class BloomModel(nn.Module):
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in self.h[self.start_layer:self.end_layer]: for layer in islice(self.h, self.start_layer, self.end_layer):
hidden_states = layer(position_ids, hidden_states) hidden_states = layer(position_ids, hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from itertools import islice
from typing import Annotated, Any, Literal, Optional, Union from typing import Annotated, Any, Literal, Optional, Union
import torch import torch
...@@ -914,7 +915,7 @@ class ChameleonModel(nn.Module): ...@@ -914,7 +915,7 @@ class ChameleonModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
"""Inference-only ChatGLM model compatible with THUDM weights.""" """Inference-only ChatGLM model compatible with THUDM weights."""
import json import json
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -281,7 +282,7 @@ class GLMTransformer(nn.Module): ...@@ -281,7 +282,7 @@ class GLMTransformer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(hidden_states=hidden_states, hidden_states = layer(hidden_states=hidden_states,
position_ids=position_ids) position_ids=position_ids)
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
# This file is based on the LLama model definition file in transformers # This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model.""" """PyTorch Cohere model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -322,7 +323,7 @@ class CohereModel(nn.Module): ...@@ -322,7 +323,7 @@ class CohereModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -359,7 +360,7 @@ class DbrxModel(nn.Module): ...@@ -359,7 +360,7 @@ class DbrxModel(nn.Module):
else: else:
assert intermediate_tensors assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for block in self.blocks[self.start_layer:self.end_layer]: for block in islice(self.blocks, self.start_layer, self.end_layer):
hidden_states = block(position_ids, hidden_states) hidden_states = block(position_ids, hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Deepseek model.""" """Inference-only Deepseek model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -377,7 +378,7 @@ class DeepseekModel(nn.Module): ...@@ -377,7 +378,7 @@ class DeepseekModel(nn.Module):
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
...@@ -483,4 +484,4 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): ...@@ -483,4 +484,4 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
"""Inference-only DeepseekV2/DeepseekV3 model.""" """Inference-only DeepseekV2/DeepseekV3 model."""
import typing import typing
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -712,7 +713,7 @@ class DeepseekV2Model(nn.Module): ...@@ -712,7 +713,7 @@ class DeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only dots1 model.""" """Inference-only dots1 model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -391,7 +392,7 @@ class Dots1Model(nn.Module): ...@@ -391,7 +392,7 @@ class Dots1Model(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only ErineMoE model compatible with HuggingFace weights.""" """Inference-only ErineMoE model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -419,8 +420,7 @@ class Ernie4_5_MoeModel(nn.Module): ...@@ -419,8 +420,7 @@ class Ernie4_5_MoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Erine VL model compatible with HuggingFace weights.""" """Inference-only Erine VL model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -508,8 +509,7 @@ class Ernie4_5_VLMoeModel(nn.Module): ...@@ -508,8 +509,7 @@ class Ernie4_5_VLMoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual, hidden_states, residual = layer(positions, hidden_states, residual,
visual_token_mask, **kwargs) visual_token_mask, **kwargs)
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
"""Inference-only Exaone model compatible with HuggingFace weights.""" """Inference-only Exaone model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -371,7 +372,7 @@ class ExaoneModel(nn.Module): ...@@ -371,7 +372,7 @@ class ExaoneModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.h[self.start_layer:self.end_layer]: for layer in islice(self.h, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
"""Inference-only Exaone model compatible with HuggingFace weights.""" """Inference-only Exaone model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -354,7 +355,7 @@ class Exaone4Model(nn.Module): ...@@ -354,7 +355,7 @@ class Exaone4Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
import math import math
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -389,7 +390,7 @@ class FalconModel(nn.Module): ...@@ -389,7 +390,7 @@ class FalconModel(nn.Module):
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.get_input_embeddings(input_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in self.h[self.start_layer:self.end_layer]: for layer in islice(self.h, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states) hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
"""Inference-only Gemma model compatible with HuggingFace weights.""" """Inference-only Gemma model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from functools import cache from functools import cache
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -308,7 +309,7 @@ class GemmaModel(nn.Module): ...@@ -308,7 +309,7 @@ class GemmaModel(nn.Module):
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -292,7 +293,7 @@ class Gemma2Model(nn.Module): ...@@ -292,7 +293,7 @@ class Gemma2Model(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment