Unverified commit de533ab2, authored by Lukas Geiger and committed by GitHub
Browse files

[Models] Improve iteration over layers (#19497)


Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
parent 235c9db8
......@@ -3,6 +3,7 @@
"""Inference-only MiniMaxText01 model."""
import math
from collections.abc import Iterable
from itertools import islice
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
......@@ -1019,8 +1020,7 @@ class MiniMaxText01Model(nn.Module):
minimax_cache_index = 0
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in islice(self.layers, self.start_layer, self.end_layer):
_caches = None
if not envs.VLLM_USE_V1 and isinstance(
layer.self_attn, MiniMaxText01LinearAttention):
......
......@@ -24,6 +24,7 @@
# limitations under the License.
"""Inference-only Mixtral model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -307,7 +308,7 @@ class MixtralModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......
......@@ -24,6 +24,7 @@
# limitations under the License.
"""Inference-only Mixtral model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import numpy as np
......@@ -346,7 +347,7 @@ class MixtralModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......
......@@ -5,6 +5,7 @@ import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property, partial
from itertools import islice
from typing import Annotated, Optional, Union
import numpy as np
......@@ -842,7 +843,7 @@ class MolmoModel(nn.Module, SupportsQuant):
residual = intermediate_tensors["residual"]
# Apply blocks one-by-one.
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
......@@ -4,6 +4,7 @@
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -260,7 +261,7 @@ class MPTModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for block in self.blocks[self.start_layer:self.end_layer]:
for block in islice(self.blocks, self.start_layer, self.end_layer):
hidden_states = block(position_ids, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
......
......@@ -24,6 +24,7 @@
# limitations under the License.
"""Inference-only Nemotron model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -353,7 +354,7 @@ class NemotronModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
......
......@@ -399,8 +399,7 @@ class NemotronHModel(nn.Module):
residual = None
num_non_mamba_layers = 0
for i in range(len(self.layers)):
layer = self.layers[i]
for i, layer in enumerate(self.layers):
layer_mamba_cache_params = None
if isinstance(layer,
NemotronHMambaDecoderLayer) and mamba_cache_params:
......
......@@ -24,6 +24,7 @@
# limitations under the License.
"""Inference-only deci model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -287,8 +288,7 @@ class DeciModel(nn.Module):
residual = intermediate_tensors["residual"]
kv_cache_index = 0
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in islice(self.layers, self.start_layer, self.end_layer):
if not layer._is_no_op_attention:
hidden_states, residual = layer(positions, hidden_states,
residual)
......
......@@ -24,6 +24,7 @@
# limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -280,7 +281,7 @@ class OlmoModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
# Apply blocks one-by-one.
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
# shape: (batch_size, seq_len, d_model)
hidden_states = layer(positions, hidden_states)
......
......@@ -26,6 +26,7 @@
from collections.abc import Iterable
from functools import partial
from itertools import islice
from typing import Optional, Union
import torch
......@@ -305,7 +306,7 @@ class Olmo2Model(nn.Module):
assert isinstance(hidden_states, torch.Tensor)
# Apply blocks one-by-one.
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
# shape: (batch_size, seq_len, d_model)
hidden_states = layer(positions, hidden_states)
......
......@@ -15,6 +15,7 @@
"""Inference-only OLMoE model compatible with HuggingFace weights."""
from collections.abc import Iterable
from functools import partial
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -314,7 +315,7 @@ class OlmoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
......@@ -20,6 +20,7 @@
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -269,7 +270,7 @@ class OPTDecoder(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(hidden_states)
if not get_pp_group().is_last_rank:
......
......@@ -7,6 +7,7 @@
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -252,7 +253,7 @@ class OrionModel(nn.Module):
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......
......@@ -23,6 +23,7 @@
# limitations under the License.
"""Inference-only persimmon model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -255,7 +256,7 @@ class PersimmonModel(nn.Module):
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
......
......@@ -38,6 +38,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -240,7 +241,7 @@ class PhiModel(nn.Module):
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank:
......
......@@ -24,6 +24,7 @@
# limitations under the License.
"""Inference-only PhiMoE model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
import torch
......@@ -500,7 +501,7 @@ class PhiMoEModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only PLaMo2 model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional
import torch
......@@ -614,7 +615,7 @@ class Plamo2Decoder(torch.nn.Module):
mamba2_metadata: Mamba2Metadata,
) -> torch.Tensor:
mamba_cache_index = 0
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
layer_mamba_cache_params = None
if layer.is_mamba:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
......
......@@ -8,6 +8,7 @@
"""Inference-only QWen model compatible with HuggingFace weights."""
import json
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -234,7 +235,7 @@ class QWenModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.h[self.start_layer:self.end_layer]:
for layer in islice(self.h, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,
......
......@@ -25,6 +25,7 @@
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -358,7 +359,7 @@ class Qwen2Model(nn.Module):
aux_hidden_states = []
for idx, layer in enumerate(
self.layers[self.start_layer:self.end_layer]):
islice(self.layers, self.start_layer, self.end_layer)):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual)
......
......@@ -25,6 +25,7 @@
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
......@@ -381,7 +382,7 @@ class Qwen2MoeModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment