Unverified Commit de533ab2 authored by Lukas Geiger's avatar Lukas Geiger Committed by GitHub
Browse files

[Models] Improve iteration over layers (#19497)


Signed-off-by: default avatarLukas Geiger <lukas.geiger94@gmail.com>
parent 235c9db8
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
"""Inference-only MiniMaxText01 model.""" """Inference-only MiniMaxText01 model."""
import math import math
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -1019,8 +1020,7 @@ class MiniMaxText01Model(nn.Module): ...@@ -1019,8 +1020,7 @@ class MiniMaxText01Model(nn.Module):
minimax_cache_index = 0 minimax_cache_index = 0
for i in range(self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer = self.layers[i]
_caches = None _caches = None
if not envs.VLLM_USE_V1 and isinstance( if not envs.VLLM_USE_V1 and isinstance(
layer.self_attn, MiniMaxText01LinearAttention): layer.self_attn, MiniMaxText01LinearAttention):
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -307,7 +308,7 @@ class MixtralModel(nn.Module): ...@@ -307,7 +308,7 @@ class MixtralModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import numpy as np import numpy as np
...@@ -346,7 +347,7 @@ class MixtralModel(nn.Module): ...@@ -346,7 +347,7 @@ class MixtralModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
......
...@@ -5,6 +5,7 @@ import math ...@@ -5,6 +5,7 @@ import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property, partial from functools import cached_property, partial
from itertools import islice
from typing import Annotated, Optional, Union from typing import Annotated, Optional, Union
import numpy as np import numpy as np
...@@ -842,7 +843,7 @@ class MolmoModel(nn.Module, SupportsQuant): ...@@ -842,7 +843,7 @@ class MolmoModel(nn.Module, SupportsQuant):
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
# Apply blocks one-by-one. # Apply blocks one-by-one.
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math import math
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -260,7 +261,7 @@ class MPTModel(nn.Module): ...@@ -260,7 +261,7 @@ class MPTModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for block in self.blocks[self.start_layer:self.end_layer]: for block in islice(self.blocks, self.start_layer, self.end_layer):
hidden_states = block(position_ids, hidden_states) hidden_states = block(position_ids, hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Nemotron model compatible with HuggingFace weights.""" """Inference-only Nemotron model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -353,7 +354,7 @@ class NemotronModel(nn.Module): ...@@ -353,7 +354,7 @@ class NemotronModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
......
...@@ -399,8 +399,7 @@ class NemotronHModel(nn.Module): ...@@ -399,8 +399,7 @@ class NemotronHModel(nn.Module):
residual = None residual = None
num_non_mamba_layers = 0 num_non_mamba_layers = 0
for i in range(len(self.layers)): for i, layer in enumerate(self.layers):
layer = self.layers[i]
layer_mamba_cache_params = None layer_mamba_cache_params = None
if isinstance(layer, if isinstance(layer,
NemotronHMambaDecoderLayer) and mamba_cache_params: NemotronHMambaDecoderLayer) and mamba_cache_params:
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only deci model compatible with HuggingFace weights.""" """Inference-only deci model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -287,8 +288,7 @@ class DeciModel(nn.Module): ...@@ -287,8 +288,7 @@ class DeciModel(nn.Module):
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
kv_cache_index = 0 kv_cache_index = 0
for i in range(self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer = self.layers[i]
if not layer._is_no_op_attention: if not layer._is_no_op_attention:
hidden_states, residual = layer(positions, hidden_states, hidden_states, residual = layer(positions, hidden_states,
residual) residual)
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights.""" """Inference-only OLMo model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -280,7 +281,7 @@ class OlmoModel(nn.Module): ...@@ -280,7 +281,7 @@ class OlmoModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
# Apply blocks one-by-one. # Apply blocks one-by-one.
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
# shape: (batch_size, seq_len, d_model) # shape: (batch_size, seq_len, d_model)
hidden_states = layer(positions, hidden_states) hidden_states = layer(positions, hidden_states)
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
from collections.abc import Iterable from collections.abc import Iterable
from functools import partial from functools import partial
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -305,7 +306,7 @@ class Olmo2Model(nn.Module): ...@@ -305,7 +306,7 @@ class Olmo2Model(nn.Module):
assert isinstance(hidden_states, torch.Tensor) assert isinstance(hidden_states, torch.Tensor)
# Apply blocks one-by-one. # Apply blocks one-by-one.
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
# shape: (batch_size, seq_len, d_model) # shape: (batch_size, seq_len, d_model)
hidden_states = layer(positions, hidden_states) hidden_states = layer(positions, hidden_states)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"""Inference-only OLMoE model compatible with HuggingFace weights.""" """Inference-only OLMoE model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from functools import partial from functools import partial
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -314,7 +315,7 @@ class OlmoeModel(nn.Module): ...@@ -314,7 +315,7 @@ class OlmoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights.""" """Inference-only OPT model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -269,7 +270,7 @@ class OPTDecoder(nn.Module): ...@@ -269,7 +270,7 @@ class OPTDecoder(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(hidden_states) hidden_states = layer(hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights.""" """Inference-only Orion-14B model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -252,7 +253,7 @@ class OrionModel(nn.Module): ...@@ -252,7 +253,7 @@ class OrionModel(nn.Module):
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states) hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only persimmon model compatible with HuggingFace weights.""" """Inference-only persimmon model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -255,7 +256,7 @@ class PersimmonModel(nn.Module): ...@@ -255,7 +256,7 @@ class PersimmonModel(nn.Module):
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states) hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
......
...@@ -38,6 +38,7 @@ ...@@ -38,6 +38,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights.""" """Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -240,7 +241,7 @@ class PhiModel(nn.Module): ...@@ -240,7 +241,7 @@ class PhiModel(nn.Module):
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states) hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only PhiMoE model.""" """Inference-only PhiMoE model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -500,7 +501,7 @@ class PhiMoEModel(nn.Module): ...@@ -500,7 +501,7 @@ class PhiMoEModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only PLaMo2 model.""" """Inference-only PLaMo2 model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional from typing import Optional
import torch import torch
...@@ -614,7 +615,7 @@ class Plamo2Decoder(torch.nn.Module): ...@@ -614,7 +615,7 @@ class Plamo2Decoder(torch.nn.Module):
mamba2_metadata: Mamba2Metadata, mamba2_metadata: Mamba2Metadata,
) -> torch.Tensor: ) -> torch.Tensor:
mamba_cache_index = 0 mamba_cache_index = 0
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
layer_mamba_cache_params = None layer_mamba_cache_params = None
if layer.is_mamba: if layer.is_mamba:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx( layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
"""Inference-only QWen model compatible with HuggingFace weights.""" """Inference-only QWen model compatible with HuggingFace weights."""
import json import json
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -234,7 +235,7 @@ class QWenModel(nn.Module): ...@@ -234,7 +235,7 @@ class QWenModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.h[self.start_layer:self.end_layer]: for layer in islice(self.h, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" """Inference-only Qwen2 model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -358,7 +359,7 @@ class Qwen2Model(nn.Module): ...@@ -358,7 +359,7 @@ class Qwen2Model(nn.Module):
aux_hidden_states = [] aux_hidden_states = []
for idx, layer in enumerate( for idx, layer in enumerate(
self.layers[self.start_layer:self.end_layer]): islice(self.layers, self.start_layer, self.end_layer)):
if idx in self.aux_hidden_state_layers: if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual) aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -381,7 +382,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -381,7 +382,7 @@ class Qwen2MoeModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment