"tests/vscode:/vscode.git/clone" did not exist on "b8ff05361a2ab91e6be33601d4f564408e10eb24"
Unverified Commit de533ab2 authored by Lukas Geiger's avatar Lukas Geiger Committed by GitHub
Browse files

[Models] Improve iteration over layers (#19497)


Signed-off-by: default avatarLukas Geiger <lukas.geiger94@gmail.com>
parent 235c9db8
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
"""Inference-only Qwen3MoE model compatible with HuggingFace weights.""" """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
import typing import typing
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
...@@ -420,8 +421,7 @@ class Qwen3MoeModel(nn.Module): ...@@ -420,8 +421,7 @@ class Qwen3MoeModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only SeedOss model compatible with HuggingFace weights.""" """Inference-only SeedOss model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -340,7 +341,7 @@ class SeedOssModel(nn.Module): ...@@ -340,7 +341,7 @@ class SeedOssModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) """Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights.""" model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -247,7 +248,7 @@ class StableLMEpochModel(nn.Module): ...@@ -247,7 +248,7 @@ class StableLMEpochModel(nn.Module):
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states) hidden_states, residual = layer(positions, hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
# limitations under the License. # limitations under the License.
""" PyTorch Starcoder2 model.""" """ PyTorch Starcoder2 model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -250,7 +251,7 @@ class Starcoder2Model(nn.Module): ...@@ -250,7 +251,7 @@ class Starcoder2Model(nn.Module):
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states) hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Jurassic model.""" """Inference-only Jurassic model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional from typing import Any, Optional
import torch import torch
...@@ -346,8 +347,7 @@ class Step3TextModel(nn.Module): ...@@ -346,8 +347,7 @@ class Step3TextModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment