Commit 6d2051cc authored by zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Deepseek model.""" """Inference-only Deepseek model."""
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -29,7 +29,7 @@ from transformers import PretrainedConfig ...@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -40,8 +40,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -40,8 +40,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -50,6 +49,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -50,6 +49,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class DeepseekMLP(nn.Module): class DeepseekMLP(nn.Module):
...@@ -329,6 +332,7 @@ class DeepseekModel(nn.Module): ...@@ -329,6 +332,7 @@ class DeepseekModel(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -338,14 +342,17 @@ class DeepseekModel(nn.Module): ...@@ -338,14 +342,17 @@ class DeepseekModel(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
DeepseekDecoderLayer(config, config.num_hidden_layers,
layer_idx, lambda prefix: DeepseekDecoderLayer(config,
cache_config, int(prefix.split(".")[-1]),
quant_config=quant_config) cache_config,
for layer_idx in range(config.num_hidden_layers) quant_config=quant_config),
]) prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
...@@ -353,19 +360,29 @@ class DeepseekModel(nn.Module): ...@@ -353,19 +360,29 @@ class DeepseekModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_tokens(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
residual = None if get_pp_group().is_first_rank:
for i in range(len(self.layers)): hidden_states = self.embed_tokens(input_ids)
residual = None
else:
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], attn_metadata, kv_caches[i - self.start_layer],
residual) attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class DeepseekForCausalLM(nn.Module): class DeepseekForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -384,6 +401,8 @@ class DeepseekForCausalLM(nn.Module): ...@@ -384,6 +401,8 @@ class DeepseekForCausalLM(nn.Module):
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -392,9 +411,9 @@ class DeepseekForCausalLM(nn.Module): ...@@ -392,9 +411,9 @@ class DeepseekForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -439,6 +458,8 @@ class DeepseekForCausalLM(nn.Module): ...@@ -439,6 +458,8 @@ class DeepseekForCausalLM(nn.Module):
if (("mlp.experts." in name or "mlp.shared_experts." in name) if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict): and name not in params_dict):
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -451,6 +472,8 @@ class DeepseekForCausalLM(nn.Module): ...@@ -451,6 +472,8 @@ class DeepseekForCausalLM(nn.Module):
if (("mlp.experts." in name or "mlp.shared_experts." in name) if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict): and name not in params_dict):
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only DeepseekV2 model.""" """Inference-only DeepseekV2 model."""
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -40,8 +40,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -40,8 +40,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -50,7 +49,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -50,7 +49,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
...@@ -241,7 +242,7 @@ class DeepseekV2Attention(nn.Module): ...@@ -241,7 +242,7 @@ class DeepseekV2Attention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj") prefix=f"{prefix}.o_proj")
rope_scaling['type'] = 'deepseek_yarn' rope_scaling["rope_type"] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim, self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim, rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
...@@ -439,6 +440,9 @@ class DeepseekV2Model(nn.Module): ...@@ -439,6 +440,9 @@ class DeepseekV2Model(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
...@@ -447,7 +451,7 @@ class DeepseekV2Model(nn.Module): ...@@ -447,7 +451,7 @@ class DeepseekV2Model(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
...@@ -472,7 +476,7 @@ class DeepseekV2Model(nn.Module): ...@@ -472,7 +476,7 @@ class DeepseekV2Model(nn.Module):
return hidden_states return hidden_states
class DeepseekV2ForCausalLM(nn.Module): class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -492,6 +496,8 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -492,6 +496,8 @@ class DeepseekV2ForCausalLM(nn.Module):
quant_config=quant_config) quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -500,7 +506,7 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -500,7 +506,7 @@ class DeepseekV2ForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
......
...@@ -38,8 +38,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -38,8 +38,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale) get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
...@@ -53,8 +52,9 @@ from vllm.sequence import IntermediateTensors ...@@ -53,8 +52,9 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.exaone import ExaoneConfig from vllm.transformers_utils.configs.exaone import ExaoneConfig
from vllm.utils import is_hip from vllm.utils import is_hip
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class ExaoneGatedMLP(nn.Module): class ExaoneGatedMLP(nn.Module):
...@@ -354,6 +354,10 @@ class ExaoneModel(nn.Module): ...@@ -354,6 +354,10 @@ class ExaoneModel(nn.Module):
else: else:
self.ln_f = PPMissingLayer() self.ln_f = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids) return self.wte(input_ids)
...@@ -397,7 +401,7 @@ class ExaoneModel(nn.Module): ...@@ -397,7 +401,7 @@ class ExaoneModel(nn.Module):
return hidden_states return hidden_states
class ExaoneForCausalLM(nn.Module, SupportsLoRA): class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -477,6 +481,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA): ...@@ -477,6 +481,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA):
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -506,24 +513,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA): ...@@ -506,24 +513,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype,
        device: torch.device) -> IntermediateTensors:
    """Return zero-filled placeholder intermediate tensors.

    Args:
        batch_size: Size of the first dimension of each tensor.
        dtype: Element type of the allocated tensors.
        device: Device on which the tensors are allocated.

    Returns:
        An ``IntermediateTensors`` holding ``"hidden_states"`` and
        ``"residual"``, each a separate zero tensor of shape
        ``(batch_size, self.config.hidden_size)``.
    """
    shape = (batch_size, self.config.hidden_size)
    # Each key gets its own freshly allocated tensor (no sharing).
    return IntermediateTensors({
        key: torch.zeros(shape, dtype=dtype, device=device)
        for key in ("hidden_states", "residual")
    })
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
......
...@@ -28,7 +28,7 @@ from transformers import FalconConfig as HF_FalconConfig ...@@ -28,7 +28,7 @@ from transformers import FalconConfig as HF_FalconConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -36,8 +36,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -36,8 +36,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -47,6 +46,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -47,6 +46,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import RWConfig from vllm.transformers_utils.configs import RWConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
FalconConfig = Union[HF_FalconConfig, RWConfig] FalconConfig = Union[HF_FalconConfig, RWConfig]
...@@ -333,6 +336,7 @@ class FalconModel(nn.Module): ...@@ -333,6 +336,7 @@ class FalconModel(nn.Module):
config: FalconConfig, config: FalconConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -347,35 +351,56 @@ class FalconModel(nn.Module): ...@@ -347,35 +351,56 @@ class FalconModel(nn.Module):
) )
# Transformer blocks # Transformer blocks
self.h = nn.ModuleList([ self.start_layer, self.end_layer, self.h = make_layers(
FalconDecoderLayer(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: FalconDecoderLayer(config, cache_config,
]) quant_config),
prefix=f"{prefix}.h")
# Final Layer Norm # Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward( def forward(
self, self,
input_ids: torch.LongTensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.word_embeddings(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
for i in range(len(self.h)): if get_pp_group().is_first_rank:
hidden_states = self.word_embeddings(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i] layer = self.h[i]
hidden_states = layer( hidden_states = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
class FalconForCausalLM(nn.Module): class FalconForCausalLM(nn.Module, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {}
default_bitsandbytes_target_modules = [
".query_key_value.",
".dense.",
".dense_h_to_4h.",
".dense_4h_to_h.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".dense_4h_to_h.", ".dense."]
def __init__( def __init__(
self, self,
...@@ -403,6 +428,8 @@ class FalconForCausalLM(nn.Module): ...@@ -403,6 +428,8 @@ class FalconForCausalLM(nn.Module):
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -412,12 +439,8 @@ class FalconForCausalLM(nn.Module): ...@@ -412,12 +439,8 @@ class FalconForCausalLM(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer( hidden_states = self.transformer(input_ids, positions, kv_caches,
input_ids, attn_metadata, intermediate_tensors)
positions,
kv_caches,
attn_metadata,
)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -454,6 +477,8 @@ class FalconForCausalLM(nn.Module): ...@@ -454,6 +477,8 @@ class FalconForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
if "query_key_value" in name: if "query_key_value" in name:
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
......
...@@ -27,11 +27,11 @@ from transformers import FuyuConfig, FuyuImageProcessor ...@@ -27,11 +27,11 @@ from transformers import FuyuConfig, FuyuImageProcessor
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -41,8 +41,8 @@ from vllm.multimodal.utils import cached_get_tokenizer ...@@ -41,8 +41,8 @@ from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData) SequenceData)
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .utils import merge_multimodal_embeddings from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
# Cannot find the following 2 numbers from hf config. # Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011 _IMAGE_TOKEN_ID = 71011
...@@ -150,10 +150,10 @@ def _fuyu_image_preprocess(image_processor: FuyuImageProcessor, ...@@ -150,10 +150,10 @@ def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
return model_image_input return model_image_input
def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
model_config = ctx.model_config model_config = ctx.model_config
image_data = multi_modal_data["image"] image_data = multi_modal_data["image"]
...@@ -165,7 +165,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -165,7 +165,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
model_config.model) model_config.model)
model_image_input = _fuyu_image_preprocess(image_processor, image_data) model_image_input = _fuyu_image_preprocess(image_processor, image_data)
image_patches = torch.stack([ image_patches = torch.cat([
image_patch[0] image_patch[0]
for image_patch in model_image_input["image_patches"] for image_patch in model_image_input["image_patches"]
]) ])
...@@ -177,8 +177,8 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -177,8 +177,8 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
# process prompts # process prompts
prompt = llm_inputs.get("prompt") prompt = inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"] prompt_token_ids = inputs["prompt_token_ids"]
tokenizer = cached_get_tokenizer(model_config.model) tokenizer = cached_get_tokenizer(model_config.model)
# dim0 is batch_size, dim1 is subseq_size which will always be 1 # dim0 is batch_size, dim1 is subseq_size which will always be 1
image_input_ids: List[List[ image_input_ids: List[List[
...@@ -191,9 +191,9 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -191,9 +191,9 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[ new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
1:] + boa_token 1:] + boa_token
return LLMInputs(prompt=new_prompt, return token_inputs(prompt=new_prompt,
prompt_token_ids=new_prompt_token_ids, prompt_token_ids=new_prompt_token_ids,
multi_modal_data=new_multi_modal_data) multi_modal_data=new_multi_modal_data)
def input_mapper_for_fuyu(ctx: InputContext, data: object): def input_mapper_for_fuyu(ctx: InputContext, data: object):
...@@ -210,14 +210,14 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object): ...@@ -210,14 +210,14 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
]) ])
# image has been processed with prompt in input processor # image has been processed with prompt in input processor
return MultiModalInputs({"image_patches": data}) return MultiModalInputs({"pixel_values": data})
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu) @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu) @INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsMultiModal): class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(self,
config: FuyuConfig, config: FuyuConfig,
...@@ -237,28 +237,54 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal): ...@@ -237,28 +237,54 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
self.image_feature_size, self.image_feature_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
gather_output=True,
) )
self.language_model = PersimmonForCausalLM(config, self.language_model = PersimmonForCausalLM(config.text_config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@property
def sampler(self):
    """Expose the sampler owned by the wrapped ``language_model``."""
    language_model = self.language_model
    return language_model.sampler
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.patch_size
num_channels = self.config.num_channels
expected_dims = num_channels * h * w
def _validate_shape(d: torch.Tensor):
actual_dims = d.size(-1)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per image per batch "
f" per patch is {expected_expr}. "
f"You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data.to(self.vision_embed_tokens.weight.dtype)
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePixelInputs]: self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
image_patches = kwargs.pop("image_patches", None) pixel_values = kwargs.pop("pixel_values", None)
if isinstance(image_patches, torch.Tensor): if pixel_values is not None:
# Remove the N dimension until multiple images are supported. if not isinstance(pixel_values, (torch.Tensor, list)):
image_patches = image_patches.squeeze(1) raise ValueError("Incorrect type of image patches. "
f"Got type: {type(pixel_values)}")
return FuyuImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
)
expected_feature_size = self.image_feature_size
if image_patches.size(-1) != expected_feature_size:
raise ValueError(
f"Expected image patches to have the last dimension of "
f"{expected_feature_size}, got {image_patches.size(-1)}")
image_patches = image_patches.to(
self.vision_embed_tokens.weight.dtype)
return FuyuImagePixelInputs(type="pixel_values",
data=image_patches)
return None return None
def _process_image_input( def _process_image_input(
...@@ -277,23 +303,29 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal): ...@@ -277,23 +303,29 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
): ):
image_input = self._parse_and_validate_image_input(**kwargs) if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None: if image_input is not None:
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(input_ids) inputs_embeds = self.language_model.model.embed_tokens(
inputs_embeds = merge_multimodal_embeddings( input_ids)
input_ids, inputs_embeds, vision_embeddings, inputs_embeds = merge_multimodal_embeddings(
self.image_token_id) input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
else: else:
inputs_embeds = None inputs_embeds = None
hidden_states = self.language_model( hidden_states = self.language_model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
return hidden_states return hidden_states
...@@ -316,34 +348,5 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal): ...@@ -316,34 +348,5 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False)) loader = AutoWeightsLoader(self)
for name, loaded_weight in weights: loader.load_weights(weights)
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
param = params_dict[name]
if "query_key_value" in name:
# copy from vllm/model_executor/models/bloom.py
# NOTE: Fuyu's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
output_dim = getattr(param, "output_dim", None)
num_heads = self.config.num_attention_heads
if output_dim is not None:
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(
output_dim, output_dim + 1)
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights.""" """Inference-only Gemma model compatible with HuggingFace weights."""
from functools import lru_cache from functools import lru_cache
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -23,7 +23,7 @@ from transformers import GemmaConfig ...@@ -23,7 +23,7 @@ from transformers import GemmaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.layernorm import GemmaRMSNorm
...@@ -31,8 +31,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -31,8 +31,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -245,6 +246,7 @@ class GemmaModel(nn.Module): ...@@ -245,6 +246,7 @@ class GemmaModel(nn.Module):
config: GemmaConfig, config: GemmaConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -253,10 +255,11 @@ class GemmaModel(nn.Module): ...@@ -253,10 +255,11 @@ class GemmaModel(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
GemmaDecoderLayer(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: GemmaDecoderLayer(config, cache_config, quant_config
]) ),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Normalize the embedding by sqrt(hidden_size) # Normalize the embedding by sqrt(hidden_size)
...@@ -265,6 +268,9 @@ class GemmaModel(nn.Module): ...@@ -265,6 +268,9 @@ class GemmaModel(nn.Module):
# See https://github.com/huggingface/transformers/pull/29402 # See https://github.com/huggingface/transformers/pull/29402
normalizer = self.config.hidden_size**0.5 normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer)) self.register_buffer("normalizer", torch.tensor(normalizer))
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -275,29 +281,38 @@ class GemmaModel(nn.Module): ...@@ -275,29 +281,38 @@ class GemmaModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is not None: if get_pp_group().is_first_rank:
hidden_states = inputs_embeds if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states *= self.normalizer
residual = None
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = intermediate_tensors["hidden_states"]
hidden_states *= self.normalizer residual = intermediate_tensors["residual"]
residual = None for i in range(self.start_layer, self.end_layer):
for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class GemmaForCausalLM(nn.Module, SupportsLoRA): class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -317,6 +332,28 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA): ...@@ -317,6 +332,28 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
"gate_up_proj", "gate_up_proj",
"down_proj", "down_proj",
] ]
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
# Gemma does not apply LoRA to the embedding layer. # Gemma does not apply LoRA to the embedding layer.
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
...@@ -339,6 +376,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA): ...@@ -339,6 +376,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
self.model = GemmaModel(config, cache_config, quant_config) self.model = GemmaModel(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -347,9 +386,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA): ...@@ -347,9 +386,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -388,6 +427,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA): ...@@ -388,6 +427,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -400,6 +441,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA): ...@@ -400,6 +441,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -14,15 +14,16 @@ ...@@ -14,15 +14,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import Gemma2Config from transformers import Gemma2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.layernorm import GemmaRMSNorm
...@@ -30,17 +31,20 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -30,17 +31,20 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.pooler import Pooler, PoolingType
QuantizationConfig) from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -237,6 +241,13 @@ class Gemma2DecoderLayer(nn.Module): ...@@ -237,6 +241,13 @@ class Gemma2DecoderLayer(nn.Module):
return hidden_states, residual return hidden_states, residual
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": 0,
"inputs_embeds": 0,
"intermediate_tensors": 0,
})
class Gemma2Model(nn.Module): class Gemma2Model(nn.Module):
def __init__( def __init__(
...@@ -244,6 +255,7 @@ class Gemma2Model(nn.Module): ...@@ -244,6 +255,7 @@ class Gemma2Model(nn.Module):
config: Gemma2Config, config: Gemma2Config,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -252,10 +264,11 @@ class Gemma2Model(nn.Module): ...@@ -252,10 +264,11 @@ class Gemma2Model(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config) config.num_hidden_layers,
for layer_idx in range(config.num_hidden_layers) lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[
]) -1]), config, cache_config, quant_config),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Normalize the embedding by sqrt(hidden_size) # Normalize the embedding by sqrt(hidden_size)
...@@ -264,32 +277,92 @@ class Gemma2Model(nn.Module): ...@@ -264,32 +277,92 @@ class Gemma2Model(nn.Module):
# See https://github.com/huggingface/transformers/pull/29402 # See https://github.com/huggingface/transformers/pull/29402
normalizer = self.config.hidden_size**0.5 normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer)) self.register_buffer("normalizer", torch.tensor(normalizer))
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_tokens(input_ids) inputs_embeds: Optional[torch.Tensor] = None,
hidden_states *= self.normalizer ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
residual = None if inputs_embeds is not None:
for i in range(len(self.layers)): hidden_states = inputs_embeds
else:
hidden_states = self.embed_tokens(input_ids)
hidden_states *= self.normalizer
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
class Gemma2ForCausalLM(nn.Module, SupportsLoRA): unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
logger.warning(
"Some weights are not initialized from checkpoints: %s",
unloaded_params)
class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -312,6 +385,19 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -312,6 +385,19 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
# Gemma does not apply LoRA to the embedding layer. # Gemma does not apply LoRA to the embedding layer.
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = { bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index # shard_name, weight_name, index
"q_proj": ("qkv_proj", 0), "q_proj": ("qkv_proj", 0),
...@@ -338,6 +424,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -338,6 +424,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
config.vocab_size, soft_cap=config.final_logit_softcapping) config.vocab_size, soft_cap=config.final_logit_softcapping)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -346,9 +434,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -346,9 +434,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -369,44 +457,56 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -369,44 +457,56 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ loader = AutoWeightsLoader(
# (param_name, shard_name, shard_id) self,
("qkv_proj", "q_proj", "q"), skip_prefixes=(["lm_head."]
("qkv_proj", "k_proj", "k"), if self.config.tie_word_embeddings else None),
("qkv_proj", "v_proj", "v"), )
("gate_up_proj", "gate_proj", 0), loader.load_weights(weights)
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params: class Gemma2EmbeddingModel(nn.Module, SupportsPP):
logger.warning( """
"Some weights are not initialized from checkpoints: %s", A model that uses Gemma2 with additional embedding functionalities.
unloaded_params)
This class encapsulates the Gemma2Model and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of Gemma2Model used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(
self,
**kwargs,
) -> None:
super().__init__()
self.model = Gemma2Model(**kwargs)
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
return self.model(input_ids, positions, kv_caches, attn_metadata,
intermediate_tensors, inputs_embeds)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.model.load_weights(weights)
# coding=utf-8
# Adapted from
# https://github.com/THUDM/GLM-4
"""Inference-only GLM-4v model visual encoder compatible with THUDM weights."""
from argparse import Namespace
from typing import Optional
import torch
from torch import nn
from torch.nn import LayerNorm
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch embeddings.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, a learned class embedding is
    prepended, and a learned position embedding is added to every slot.

    ``config`` must provide ``in_channels``, ``hidden_size``,
    ``patch_size`` and ``num_positions``.
    """

    def __init__(self, config):
        super().__init__()
        # kernel_size == stride, so each patch is projected exactly once.
        self.proj = nn.Conv2d(config.in_channels,
                              config.hidden_size,
                              kernel_size=config.patch_size,
                              stride=config.patch_size)
        self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
        self.position_embedding = nn.Embedding(config.num_positions,
                                               config.hidden_size)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Input image tensor with shape (B, C, H, W)

        Returns:
        torch.Tensor
            Transformed tensor with shape (B, L, D)
        """
        # Match both the device AND the dtype of the projection weights.
        # Moving only the device makes the convolution fail when the
        # pixels arrive in a different dtype than the model runs in
        # (e.g. float32 pixels with half-precision weights).
        images = images.to(device=self.proj.weight.device,
                           dtype=self.proj.weight.dtype)
        x = self.proj(images)
        # (B, D, H', W') -> (B, H' * W', D)
        x = x.flatten(2).transpose(1, 2)
        # (1, D) -> (B, 1, D): expand adds the leading batch dimension.
        cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        # The whole position table is added, so the sequence length
        # (patches + 1) has to equal config.num_positions.
        x += self.position_embedding.weight.unsqueeze(0)
        return x
class Attention(nn.Module):
    """Non-causal multi-head self-attention for the visual encoder.

    The fused QKV projection and the output projection are built from the
    parallel linear layers, so each rank handles
    ``config.num_heads // tp_size`` heads.

    ``config`` must provide ``hidden_size``, ``num_heads`` and
    ``dropout_prob``.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        total_heads = config.num_heads

        self.hidden_size = hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_rank = total_heads // self.tp_size
        self.head_dim = hidden_size // total_heads
        # Kept as an attribute; forward() does not pass it on and relies
        # on the default scaling of scaled_dot_product_attention.
        self.scale = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            total_heads,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            hidden_size,
            hidden_size,
            quant_config=quant_config,
        )
        self.output_dropout = torch.nn.Dropout(config.dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the sequence dimension of ``x``.

        Parameters:
        x : torch.Tensor
            Three-dimensional input, unpacked as (B, L, hidden).

        Returns:
        torch.Tensor
            Output of the dense projection followed by dropout.
        """
        batch, seq_len, _ = x.shape
        fused_qkv, _ = self.query_key_value(x)  # B, L, 3 * H * D

        # Split into q / k / v and move the head axis ahead of the
        # sequence axis: (B, L, H * D) -> (B, H, L, D).
        query, key, value = (
            part.reshape(batch, seq_len, self.num_heads_per_rank,
                         self.head_dim).permute(0, 2, 1, 3)
            for part in fused_qkv.chunk(3, dim=-1))

        attended = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.,
            is_causal=False)

        # (B, H, L, D) -> (B, L, H * D); reshape copes with the
        # non-contiguous result of transpose.
        merged_heads = attended.transpose(1, 2).reshape(batch, seq_len, -1)
        projected, _ = self.dense(merged_heads)
        return self.output_dropout(projected)
class MLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2.

    ``config`` must provide ``hidden_size``, ``intermediate_size`` and
    ``hidden_act`` (the name handed to ``get_act_fn``).
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        model_dim = config.hidden_size
        inner_dim = config.intermediate_size
        self.fc1 = ColumnParallelLinear(
            model_dim,
            inner_dim,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            inner_dim,
            model_dim,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward block to ``x`` and return the result."""
        # Each linear layer returns a pair; only the first item is used.
        expanded, _ = self.fc1(x)
        contracted, _ = self.fc2(self.activation_fn(expanded))
        return contracted
class TransformerLayer(nn.Module):
    """One encoder layer: attention and MLP, each with a residual add.

    In both halves the LayerNorm is applied to the sub-layer's *output*
    before it is added back onto the sub-layer's input.

    ``config`` must provide ``hidden_size`` and ``layer_norm_eps`` in
    addition to whatever ``Attention`` and ``MLP`` read from it.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_eps

        self.input_layernorm = LayerNorm(width, eps=eps)
        self.attention = Attention(config, quant_config=quant_config)
        self.mlp = MLP(config, quant_config=quant_config)
        self.post_attention_layernorm = LayerNorm(width, eps=eps)

    def forward(self, hidden_states):
        """Run the attention half and then the MLP half on the input."""
        # Attention half: x + LN(attention(x))
        normed_attention = self.input_layernorm(self.attention(hidden_states))
        after_attention = hidden_states + normed_attention

        # MLP half: y + LN(mlp(y))
        normed_mlp = self.post_attention_layernorm(self.mlp(after_attention))
        return after_attention + normed_mlp
class Transformer(nn.Module):
    """A stack of ``config.num_hidden_layers`` ``TransformerLayer`` blocks."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        depth = config.num_hidden_layers
        self.layers = nn.ModuleList(
            TransformerLayer(config, quant_config=quant_config)
            for _ in range(depth))

    def forward(self, hidden_states):
        """Feed ``hidden_states`` through every layer in order."""
        output = hidden_states
        for block in self.layers:
            output = block(output)
        return output
class GLU(nn.Module):
    """Gated projection head applied to the visual features.

    Pipeline: ``linear_proj`` -> LayerNorm -> GELU -> ``merged_proj`` ->
    ``SiluAndMul`` -> ``dense_4h_to_h``.

    The original implementation is the same as:

    ```python
    self.dense_h_to_4h = ColumnParallelLinear(
        config.hidden_size,
        config.ffn_hidden_size,
        bias=False,
        quant_config=quant_config
    )
    self.gate_proj = ColumnParallelLinear(
        config.hidden_size,
        config.ffn_hidden_size,
        bias=False,
        quant_config=quant_config
    )
    ```
    ```
    gate_proj_output, _ = self.gate_proj(x)
    dense_h_to_4h_output, _ = self.dense_h_to_4h(x)
    x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1)
    ```

    We merge two ColumnParallelLinear into one MergedColumnParallelLinear:

    ```
    self.merged_proj = MergedColumnParallelLinear(
        config.hidden_size,
        [config.ffn_hidden_size] * 2,
        bias=False,
        quant_config=quant_config
    )
    ```
    ```
    x, _ = self.merged_proj(x)
    ```
    """

    def __init__(
        self,
        config,
        in_features,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the projection stack.

        ``config`` must provide ``hidden_size`` and ``ffn_hidden_size``;
        ``in_features`` is the input width of the first projection.
        """
        super().__init__()
        model_dim = config.hidden_size
        ffn_dim = config.ffn_hidden_size

        self.linear_proj = ReplicatedLinear(in_features,
                                            model_dim,
                                            bias=False,
                                            quant_config=quant_config)
        self.norm1 = nn.LayerNorm(model_dim)
        self.act1 = nn.GELU()
        self.act2 = SiluAndMul()

        # Two equally sized output shards in a single merged layer.
        self.merged_proj = MergedColumnParallelLinear(
            model_dim, [ffn_dim] * 2,
            bias=False,
            quant_config=quant_config)

        self.dense_4h_to_h = RowParallelLinear(ffn_dim,
                                               model_dim,
                                               bias=False,
                                               quant_config=quant_config)

    def forward(self, x):
        """Apply the projection stack to ``x`` and return the result."""
        # Each linear layer returns a pair; only the first item is used.
        projected, _ = self.linear_proj(x)
        activated = self.act1(self.norm1(projected))
        merged, _ = self.merged_proj(activated)
        gated = self.act2(merged)
        output, _ = self.dense_4h_to_h(gated)
        return output
class EVA2CLIPModel(nn.Module):
    """Vision encoder: patch embedding, transformer, 2x2 conv and GLU head.

    The output sequence is wrapped in learned begin-of-image / end-of-image
    embeddings and divided by ``vision_config.scaling_factor``.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the encoder from ``config`` and ``config.vision_config``.

        Args:
            config: Model config. ``config.vision_config`` must be a mapping;
                it is unpacked into a ``Namespace`` for attribute access.
                ``config.hidden_size`` sets the output width.
            quant_config: Optional quantization settings forwarded to the
                transformer and the projection head.
        """
        super().__init__()
        vision_config = Namespace(**config.vision_config)
        out_width = config.hidden_size

        self.patch_embedding = PatchEmbedding(vision_config)
        self.transformer = Transformer(vision_config,
                                       quant_config=quant_config)
        self.linear_proj = GLU(config,
                               in_features=out_width,
                               quant_config=quant_config)
        # kernel 2 / stride 2: halves each side of the patch grid while
        # mapping the vision width to the model width.
        self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
                              out_channels=out_width,
                              kernel_size=2,
                              stride=2)
        self.boi = nn.Parameter(torch.zeros(1, 1, out_width))
        self.eoi = nn.Parameter(torch.zeros(1, 1, out_width))
        self.scaling_factor = vision_config.scaling_factor

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Input image tensor with shape (B, C, H, W)
        Returns:
        torch.Tensor
            Transformed tensor with shape (B, L, D)
        """
        feats = self.patch_embedding(images)
        feats = self.transformer(feats)
        # Drop the first position along dim 1 (presumably a class token;
        # confirm against PatchEmbedding).
        feats = feats[:, 1:]

        # Fold the remaining sequence into a square grid, channels first,
        # so it can go through the conv.
        batch, seq_len, width = feats.shape
        side = int(seq_len**0.5)
        feats = feats.view(batch, side, side, width)
        feats = feats.permute(0, 3, 1, 2)
        feats = self.conv(feats)

        # Back to (batch, sequence, channels) for the projection head.
        feats = feats.flatten(2).transpose(1, 2)
        feats = self.linear_proj(feats)

        # Surround every sequence with the learned boi / eoi embeddings.
        n_seqs = feats.shape[0]
        boi = self.boi.expand(n_seqs, -1, -1)
        eoi = self.eoi.expand(n_seqs, -1, -1)
        feats = torch.cat((boi, feats, eoi), dim=1)
        return feats / self.scaling_factor
...@@ -32,8 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -32,8 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
...@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .utils import is_pp_missing_parameter, make_layers from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPT2Attention(nn.Module): class GPT2Attention(nn.Module):
...@@ -204,6 +205,9 @@ class GPT2Model(nn.Module): ...@@ -204,6 +205,9 @@ class GPT2Model(nn.Module):
config, cache_config, quant_config, prefix=prefix), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h") prefix=f"{prefix}.h")
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward( def forward(
self, self,
...@@ -234,7 +238,7 @@ class GPT2Model(nn.Module): ...@@ -234,7 +238,7 @@ class GPT2Model(nn.Module):
return hidden_states return hidden_states
class GPT2LMHeadModel(nn.Module): class GPT2LMHeadModel(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -256,6 +260,8 @@ class GPT2LMHeadModel(nn.Module): ...@@ -256,6 +260,8 @@ class GPT2LMHeadModel(nn.Module):
self.config.hidden_size) self.config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -264,7 +270,7 @@ class GPT2LMHeadModel(nn.Module): ...@@ -264,7 +270,7 @@ class GPT2LMHeadModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
...@@ -286,16 +292,6 @@ class GPT2LMHeadModel(nn.Module): ...@@ -286,16 +292,6 @@ class GPT2LMHeadModel(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights: for name, loaded_weight in weights:
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" """Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -26,14 +26,13 @@ from transformers import GPTBigCodeConfig ...@@ -26,14 +26,13 @@ from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
...@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPTBigCodeAttention(nn.Module): class GPTBigCodeAttention(nn.Module):
...@@ -194,6 +195,7 @@ class GPTBigCodeModel(nn.Module): ...@@ -194,6 +195,7 @@ class GPTBigCodeModel(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -207,11 +209,15 @@ class GPTBigCodeModel(nn.Module): ...@@ -207,11 +209,15 @@ class GPTBigCodeModel(nn.Module):
self.embed_dim, self.embed_dim,
org_num_embeddings=config.vocab_size) org_num_embeddings=config.vocab_size)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([ self.start_layer, self.end_layer, self.h = make_layers(
GPTBigCodeBlock(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config),
]) prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward( def forward(
self, self,
...@@ -219,20 +225,28 @@ class GPTBigCodeModel(nn.Module): ...@@ -219,20 +225,28 @@ class GPTBigCodeModel(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds = self.wte(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
position_embeds = self.wpe(position_ids) if get_pp_group().is_first_rank:
hidden_states = inputs_embeds + position_embeds inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(len(self.h)): for i in range(self.start_layer, self.end_layer):
layer = self.h[i] layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {"c_attn": ["c_attn"]} packed_modules_mapping = {"c_attn": ["c_attn"]}
supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"] supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]
...@@ -272,6 +286,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): ...@@ -272,6 +286,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -280,9 +296,9 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): ...@@ -280,9 +296,9 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -311,6 +327,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): ...@@ -311,6 +327,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
# Skip attention mask. # Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped. # NOTE: "c_attn.bias" should not be skipped.
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights.""" """Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -24,14 +24,13 @@ from transformers import GPTJConfig ...@@ -24,14 +24,13 @@ from transformers import GPTJConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -40,6 +39,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -40,6 +39,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPTJAttention(nn.Module): class GPTJAttention(nn.Module):
...@@ -178,6 +181,7 @@ class GPTJModel(nn.Module): ...@@ -178,6 +181,7 @@ class GPTJModel(nn.Module):
config: GPTJConfig, config: GPTJConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -186,11 +190,15 @@ class GPTJModel(nn.Module): ...@@ -186,11 +190,15 @@ class GPTJModel(nn.Module):
config.vocab_size, config.vocab_size,
self.embed_dim, self.embed_dim,
) )
self.h = nn.ModuleList([ self.start_layer, self.end_layer, self.h = make_layers(
GPTJBlock(config, cache_config, quant_config) config.n_layer,
for _ in range(config.n_layer) lambda prefix: GPTJBlock(config, cache_config, quant_config),
]) prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward( def forward(
self, self,
...@@ -198,21 +206,27 @@ class GPTJModel(nn.Module): ...@@ -198,21 +206,27 @@ class GPTJModel(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.wte(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
for i in range(len(self.h)): if get_pp_group().is_first_rank:
hidden_states = self.wte(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i] layer = self.h[i]
hidden_states = layer( hidden_states = layer(
position_ids, position_ids,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
class GPTJForCausalLM(nn.Module): class GPTJForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -233,6 +247,8 @@ class GPTJForCausalLM(nn.Module): ...@@ -233,6 +247,8 @@ class GPTJForCausalLM(nn.Module):
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -241,9 +257,9 @@ class GPTJForCausalLM(nn.Module): ...@@ -241,9 +257,9 @@ class GPTJForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -283,6 +299,8 @@ class GPTJForCausalLM(nn.Module): ...@@ -283,6 +299,8 @@ class GPTJForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -291,6 +309,8 @@ class GPTJForCausalLM(nn.Module): ...@@ -291,6 +309,8 @@ class GPTJForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -24,14 +24,13 @@ from transformers import GPTNeoXConfig ...@@ -24,14 +24,13 @@ from transformers import GPTNeoXConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -40,6 +39,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -40,6 +39,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPTNeoXAttention(nn.Module): class GPTNeoXAttention(nn.Module):
...@@ -191,6 +194,7 @@ class GPTNeoXModel(nn.Module): ...@@ -191,6 +194,7 @@ class GPTNeoXModel(nn.Module):
config: GPTNeoXConfig, config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -199,12 +203,16 @@ class GPTNeoXModel(nn.Module): ...@@ -199,12 +203,16 @@ class GPTNeoXModel(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
GPTNeoXLayer(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: GPTNeoXLayer(config, cache_config, quant_config),
]) prefix=f"{prefix}.layers",
)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, self.final_layer_norm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward( def forward(
self, self,
...@@ -212,21 +220,27 @@ class GPTNeoXModel(nn.Module): ...@@ -212,21 +220,27 @@ class GPTNeoXModel(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_in(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
for i in range(len(self.layers)): if get_pp_group().is_first_rank:
hidden_states = self.embed_in(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer( hidden_states = layer(
position_ids, position_ids,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
return hidden_states return hidden_states
class GPTNeoXForCausalLM(nn.Module): class GPTNeoXForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -247,6 +261,8 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -247,6 +261,8 @@ class GPTNeoXForCausalLM(nn.Module):
self.embed_out.weight = self.gpt_neox.embed_in.weight self.embed_out.weight = self.gpt_neox.embed_in.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.gpt_neox.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -255,9 +271,9 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -255,9 +271,9 @@ class GPTNeoXForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches, hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -288,6 +304,8 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -288,6 +304,8 @@ class GPTNeoXForCausalLM(nn.Module):
# Models trained using OpenRLHF may include # Models trained using OpenRLHF may include
# these tensors in the checkpoint. Skip them. # these tensors in the checkpoint. Skip them.
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
if "query_key_value" in name: if "query_key_value" in name:
......
...@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip from vllm.utils import is_hip
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
...@@ -311,13 +311,13 @@ class GraniteModel(nn.Module): ...@@ -311,13 +311,13 @@ class GraniteModel(nn.Module):
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.get_input_embeddings(input_ids)
residual = None residual = None
hidden_states *= self.config.embedding_multiplier
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
hidden_states *= self.config.embedding_multiplier
for i in range(self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer( hidden_states = layer(
...@@ -337,7 +337,7 @@ class GraniteModel(nn.Module): ...@@ -337,7 +337,7 @@ class GraniteModel(nn.Module):
return hidden_states return hidden_states
class GraniteForCausalLM(nn.Module, SupportsLoRA): class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -404,9 +404,12 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA): ...@@ -404,9 +404,12 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA):
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
if hasattr(config, "logits_scaling"):
logit_scale /= config.logits_scaling
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, config.vocab_size,
logit_scale) scale=logit_scale)
self.sampler = Sampler() self.sampler = Sampler()
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
...@@ -428,8 +431,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA): ...@@ -428,8 +431,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
if logits is not None:
logits /= self.config.logits_scaling
return logits return logits
def sample( def sample(
......
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GraniteMoe model."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers.models.granitemoe import GraniteMoeConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from . import mixtral
from .interfaces import SupportsLoRA, SupportsPP
from .utils import make_layers
class GraniteMoeMoE(nn.Module):
    """Tensor-parallel mixture-of-experts block for GraniteMoe.

    A replicated linear router produces per-token expert logits, and a
    ``FusedMoE`` layer evaluates the selected experts. The experts are
    built with ``reduce_results=True`` so that their outputs are reduced
    across ranks.
    """

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 tp_size: Optional[int] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = hidden_size

        # The router is never quantized: it always runs at half / full
        # precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            tp_size=tp_size,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route the tokens through the experts, preserving the input shape.

        The input may be 1D or 2D; it is flattened to
        ``(num_tokens, hidden_size)`` for routing and restored afterwards.
        """
        input_shape = hidden_states.shape
        flat_tokens = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(flat_tokens)
        expert_output = self.experts(flat_tokens, router_logits)
        return expert_output.view(input_shape)
class GraniteMoeAttention(nn.Module):
    """Self-attention block for GraniteMoe.

    A fused QKV projection feeds rotary-embedded queries and keys into the
    attention backend, and the result goes through a row-parallel output
    projection. Heads are divided across the tensor-parallel world size.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        attention_multiplier: Optional[float] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            # Fewer KV heads than TP ranks: the KV heads are replicated
            # across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as TP ranks: the KV heads are
            # partitioned across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # An explicit multiplier from the config takes precedence over the
        # head-dimension based default.
        if attention_multiplier is None:
            self.scaling = self.head_dim**-1
        else:
            self.scaling = attention_multiplier
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply rotary self-attention and return the projected output."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.o_proj(context)
        return projected
class GraniteMoeDecoderLayer(nn.Module):
    """One GraniteMoe transformer layer: attention followed by an MoE block.

    Both sub-blocks are pre-normalized with RMSNorm, and their outputs are
    scaled by ``config.residual_multiplier`` before being added back to the
    residual stream.
    """

    def __init__(
        self,
        config: GraniteMoeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = GraniteMoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            # Requires transformers > 4.32.0
            rope_theta=getattr(config, "rope_theta", 10000),
            cache_config=cache_config,
            quant_config=quant_config,
            attention_multiplier=config.attention_multiplier,
            prefix=f"{prefix}.self_attn",
        )
        self.block_sparse_moe = GraniteMoeMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe",
        )

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

        self.residual_multiplier = config.residual_multiplier

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the attention and MoE sub-blocks with scaled residuals."""
        # Self Attention
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + attn_out * self.residual_multiplier

        # Mixture of experts
        moe_out = self.block_sparse_moe(
            self.post_attention_layernorm(hidden_states))
        return hidden_states + moe_out * self.residual_multiplier
class GraniteMoeModel(nn.Module):
    """Token embedding, the decoder layers owned by this rank and the final
    RMSNorm for GraniteMoe."""

    def __init__(
        self,
        config: GraniteMoeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id

        # Extra vocabulary slots reserved for LoRA adapters, if any.
        if lora_config:
            lora_vocab = (lora_config.lora_extra_vocab_size *
                          (lora_config.max_loras or 1))
        else:
            lora_vocab = 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.embedding_multiplier = config.embedding_multiplier

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GraniteMoeDecoderLayer(config,
                                                  cache_config,
                                                  quant_config=quant_config,
                                                  prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> torch.Tensor:
        """Run this rank's slice of the decoder stack.

        The first pipeline rank embeds ``input_ids`` and scales the result
        by the embedding multiplier; other ranks read their input from
        ``intermediate_tensors``. Non-last ranks return an
        ``IntermediateTensors``; the last rank returns the normalized
        hidden states.
        """
        pp_group = get_pp_group()

        if pp_group.is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            hidden_states *= self.embedding_multiplier
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer_idx in range(self.start_layer, self.end_layer):
            decoder_layer = self.layers[layer_idx]
            # kv_caches only holds entries for the layers of this rank.
            hidden_states = decoder_layer(
                positions,
                hidden_states,
                kv_caches[layer_idx - self.start_layer],
                attn_metadata,
            )

        if pp_group.is_last_rank:
            return self.norm(hidden_states)

        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })
class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GraniteMoe decoder with a language-model head, logits processor and
    sampler."""

    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    # LoRA specific attributes
    supported_lora_modules = ["qkv_proj", "o_proj", "embed_tokens", "lm_head"]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: GraniteMoeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.lora_config = lora_config

        self.model = GraniteMoeModel(config,
                                     cache_config,
                                     quant_config,
                                     lora_config=lora_config,
                                     prefix="model")

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            # We need bigger padding if using lora for kernel
            # compatibility
            padding_size = lora_config.lora_vocab_padding_size
        else:
            padding_size = DEFAULT_VOCAB_PADDING_SIZE

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=padding_size,
            quant_config=quant_config,
        )
        if config.tie_word_embeddings:
            # Share the embedding matrix with the output projection.
            self.lm_head.weight = self.model.embed_tokens.weight

        # Logits are divided by `logits_scaling` inside the processor.
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            config.vocab_size,
            scale=1 / self.config.logits_scaling,
        )
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Delegate to the inner ``GraniteMoeModel`` and return its output."""
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors)

    def compute_logits(
            self, hidden_states: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
        """Project hidden states to logits through the logits processor."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Build zero-filled ``hidden_states`` and ``residual`` tensors of
        shape ``(batch_size, hidden_size)``."""
        shape = (batch_size, self.config.hidden_size)
        return IntermediateTensors({
            key: torch.zeros(shape, dtype=dtype, device=device)
            for key in ("hidden_states", "residual")
        })

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample the next tokens from ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Rename GraniteMoe checkpoint tensors to the Mixtral layout and
        hand them to the Mixtral weight loader.

        Stacked ``input_linear`` weights are split per expert into ``w1``
        and ``w3`` halves, stacked ``output_linear`` weights become per
        expert ``w2``, and the router weight becomes ``gate``.
        ``lm_head.weight`` is dropped when the embeddings are tied.
        """
        input_suffix = '.block_sparse_moe.input_linear.weight'
        output_suffix = '.block_sparse_moe.output_linear.weight'
        router_suffix = '.block_sparse_moe.router.layer.weight'

        remapped = {}
        for name, tensor in weights:
            if name.endswith(input_suffix):
                for expert_idx in range(tensor.size(0)):
                    w1_name = name.replace(
                        input_suffix,
                        f".block_sparse_moe.experts.{expert_idx}.w1.weight")
                    w3_name = name.replace(
                        input_suffix,
                        f".block_sparse_moe.experts.{expert_idx}.w3.weight")
                    # The fused input projection holds w1 then w3 along
                    # dim 0.
                    w1_param, w3_param = tensor[expert_idx].chunk(2, dim=0)
                    assert w1_name not in remapped
                    assert w3_name not in remapped
                    remapped[w1_name] = w1_param
                    remapped[w3_name] = w3_param
                continue

            if name.endswith(output_suffix):
                for expert_idx in range(tensor.size(0)):
                    w2_name = name.replace(
                        output_suffix,
                        f".block_sparse_moe.experts.{expert_idx}.w2.weight")
                    assert w2_name not in remapped
                    remapped[w2_name] = tensor[expert_idx]
                continue

            if name.endswith(router_suffix):
                gate_name = name.replace(router_suffix,
                                         ".block_sparse_moe.gate.weight")
                assert gate_name not in remapped
                remapped[gate_name] = tensor
                continue

            if name == 'lm_head.weight' and self.config.tie_word_embeddings:
                # The head shares the embedding weight; nothing to load.
                continue

            remapped[name] = tensor

        mixtral.MixtralForCausalLM.load_weights(self, remapped.items())
...@@ -65,11 +65,10 @@ class Idefics2VisionEmbeddings(nn.Module): ...@@ -65,11 +65,10 @@ class Idefics2VisionEmbeddings(nn.Module):
self.position_embedding = nn.Embedding(self.num_positions, self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim) self.embed_dim)
def forward( def forward(self,
self, pixel_values: torch.FloatTensor,
pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor,
patch_attention_mask: torch.BoolTensor, tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
) -> torch.Tensor:
batch_size, _, max_im_h, max_im_w = pixel_values.shape batch_size, _, max_im_h, max_im_w = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values) patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2) embeddings = patch_embeds.flatten(2).transpose(1, 2)
...@@ -84,8 +83,13 @@ class Idefics2VisionEmbeddings(nn.Module): ...@@ -84,8 +83,13 @@ class Idefics2VisionEmbeddings(nn.Module):
fill_value=0) fill_value=0)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask): for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum() if tgt_sizes is not None:
nb_patches_h = tgt_sizes[batch_idx][0]
nb_patches_w = tgt_sizes[batch_idx][1]
else:
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h, bucket_coords_h = torch.bucketize(fractional_coords_h,
...@@ -287,10 +291,12 @@ class Idefics2VisionTransformer(nn.Module): ...@@ -287,10 +291,12 @@ class Idefics2VisionTransformer(nn.Module):
self, self,
pixel_values, pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None, patch_attention_mask: Optional[torch.BoolTensor] = None,
) -> torch.tensor: tgt_sizes: Optional[torch.IntTensor] = None,
) -> torch.Tensor:
hidden_states = self.embeddings( hidden_states = self.embeddings(
pixel_values=pixel_values, pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask) patch_attention_mask=patch_attention_mask,
tgt_sizes=tgt_sizes)
encoder_outputs = self.encoder(hidden_states) encoder_outputs = self.encoder(hidden_states)
last_hidden_state = self.post_layernorm(encoder_outputs) last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state return last_hidden_state
from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type, from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
Union, overload, runtime_checkable) Protocol, Type, Union, overload, runtime_checkable)
import torch
from typing_extensions import TypeIs from typing_extensions import TypeIs
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import supports_kw
if TYPE_CHECKING:
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -22,7 +27,7 @@ class SupportsMultiModal(Protocol): ...@@ -22,7 +27,7 @@ class SupportsMultiModal(Protocol):
MRO of your model class. MRO of your model class.
""" """
def __init__(self, *, multimodal_config: MultiModalConfig) -> None: def __init__(self, *, multimodal_config: "MultiModalConfig") -> None:
... ...
...@@ -32,7 +37,7 @@ class SupportsMultiModal(Protocol): ...@@ -32,7 +37,7 @@ class SupportsMultiModal(Protocol):
class _SupportsMultiModalType(Protocol): class _SupportsMultiModalType(Protocol):
supports_multimodal: Literal[True] supports_multimodal: Literal[True]
def __call__(self, *, multimodal_config: MultiModalConfig) -> None: def __call__(self, *, multimodal_config: "MultiModalConfig") -> None:
... ...
...@@ -75,7 +80,7 @@ class SupportsLoRA(Protocol): ...@@ -75,7 +80,7 @@ class SupportsLoRA(Protocol):
embedding_padding_modules: ClassVar[List[str]] embedding_padding_modules: ClassVar[List[str]]
# lora_config is None when LoRA is not enabled # lora_config is None when LoRA is not enabled
def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: def __init__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
... ...
...@@ -90,7 +95,7 @@ class _SupportsLoRAType(Protocol): ...@@ -90,7 +95,7 @@ class _SupportsLoRAType(Protocol):
embedding_modules: Dict[str, str] embedding_modules: Dict[str, str]
embedding_padding_modules: List[str] embedding_padding_modules: List[str]
def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: def __call__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
... ...
...@@ -136,15 +141,128 @@ def supports_lora( ...@@ -136,15 +141,128 @@ def supports_lora(
return result return result
def _supports_lora( def _supports_lora(model: Union[Type[object], object]) -> bool:
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _SupportsLoRAType) return isinstance(model, _SupportsLoRAType)
return isinstance(model, SupportsLoRA) return isinstance(model, SupportsLoRA)
@runtime_checkable
class SupportsPP(Protocol):
    """The interface required for all models that support pipeline parallel."""

    supports_pp: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pipeline parallel.
    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        """Called when PP rank > 0 for profiling purposes."""
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
        """
        Accept :class:`IntermediateTensors` when PP rank > 0.

        Return :class:`IntermediateTensors` on every PP rank except the
        last one; the last PP rank returns the final :class:`torch.Tensor`.
        """
        ...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsPPType(Protocol):
    """Class-level counterpart of ``SupportsPP``.

    Declares ``supports_pp`` as a plain attribute (not a ``ClassVar``) so
    that a model *class* can be tested with ``isinstance``.
    """

    supports_pp: Literal[True]

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
        ...
@overload
def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]:
    ...


@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
    ...


def supports_pp(
    model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
    """Return whether ``model`` (a class or an instance) supports PP.

    The model must satisfy the ``SupportsPP`` protocol *and* its
    ``forward`` must accept ``intermediate_tensors``. Warnings are logged
    for inconsistent declarations: the flag set without the keyword, the
    flag set without the PP-specific attributes, or the attributes present
    without the flag.
    """
    has_pp_interface = _supports_pp_attributes(model)
    accepts_intermediate = _supports_pp_inspect(model)

    if has_pp_interface and not accepts_intermediate:
        logger.warning(
            "The model (%s) sets `supports_pp=True`, but does not accept "
            "`intermediate_tensors` in its `forward` method", model)

    if not has_pp_interface:
        required_attrs = ("make_empty_intermediate_tensors", )
        missing_attrs = tuple(name for name in required_attrs
                              if not hasattr(model, name))
        flag_is_set = getattr(model, "supports_pp", False)

        if flag_is_set and missing_attrs:
            logger.warning(
                "The model (%s) sets `supports_pp=True`, "
                "but is missing PP-specific attributes: %s",
                model,
                missing_attrs,
            )
        elif not flag_is_set and not missing_attrs:
            logger.warning(
                "The model (%s) contains all PP-specific attributes, "
                "but does not set `supports_pp=True`.", model)

    return has_pp_interface and accepts_intermediate
def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
    """Check ``model`` against the PP protocol matching its kind.

    Classes are tested against ``_SupportsPPType`` and instances against
    ``SupportsPP``.
    """
    protocol = _SupportsPPType if isinstance(model, type) else SupportsPP
    return isinstance(model, protocol)
def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
    """Return whether ``model.forward`` accepts ``intermediate_tensors``.

    A missing or non-callable ``forward`` counts as unsupported.
    """
    forward_fn = getattr(model, "forward", None)
    if callable(forward_fn):
        return supports_kw(forward_fn, "intermediate_tensors")
    return False
@runtime_checkable @runtime_checkable
class HasInnerState(Protocol): class HasInnerState(Protocol):
"""The interface required for all models that has inner state.""" """The interface required for all models that has inner state."""
...@@ -153,12 +271,12 @@ class HasInnerState(Protocol): ...@@ -153,12 +271,12 @@ class HasInnerState(Protocol):
""" """
A flag that indicates this model has inner state. A flag that indicates this model has inner state.
Models that has inner state usually need access to the scheduler_config Models that has inner state usually need access to the scheduler_config
for max_num_seqs ,etc... (Currently only used by Jamba) for max_num_seqs, etc. True for e.g. both Mamba and Jamba.
""" """
def __init__(self, def __init__(self,
*, *,
scheduler_config: Optional[SchedulerConfig] = None) -> None: scheduler_config: Optional["SchedulerConfig"] = None) -> None:
... ...
...@@ -168,7 +286,7 @@ class _HasInnerStateType(Protocol): ...@@ -168,7 +286,7 @@ class _HasInnerStateType(Protocol):
def __init__(self, def __init__(self,
*, *,
scheduler_config: Optional[SchedulerConfig] = None) -> None: scheduler_config: Optional["SchedulerConfig"] = None) -> None:
... ...
...@@ -189,3 +307,46 @@ def has_inner_state( ...@@ -189,3 +307,46 @@ def has_inner_state(
return isinstance(model, _HasInnerStateType) return isinstance(model, _HasInnerStateType)
return isinstance(model, HasInnerState) return isinstance(model, HasInnerState)
@runtime_checkable
class IsAttentionFree(Protocol):
    """The interface required for all models like Mamba that lack attention,
    but do have state whose size is constant with respect to the number of
    tokens."""

    is_attention_free: ClassVar[Literal[True]] = True
    """
        A flag that indicates this model has no attention.
        Used for block manager and attention backend selection.
        True for Mamba but not Jamba.
    """

    def __init__(self) -> None:
        ...
@runtime_checkable
class _IsAttentionFreeType(Protocol):
    """Protocol that ``is_attention_free`` tests model *classes* against."""

    is_attention_free: ClassVar[Literal[True]]

    def __init__(self) -> None:
        ...
# NOTE: The `Type[object]` overload must come first. Overloads are matched in
# order and every class is also an `object`, so the broader overload would
# otherwise always win and narrow classes to the instance protocol.
@overload
def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]:
    ...


@overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
    ...


def is_attention_free(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
    """Return whether ``model`` (a class or an instance) is attention-free.

    Classes are tested against ``_IsAttentionFreeType`` and instances
    against ``IsAttentionFree``.
    """
    if isinstance(model, type):
        return isinstance(model, _IsAttentionFreeType)

    return isinstance(model, IsAttentionFree)
from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union,
overload, runtime_checkable)
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger
from vllm.utils import supports_kw
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.pooler import PoolerOutput
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
logger = init_logger(__name__)

# The type of HF config
C_co = TypeVar("C_co", bound=PretrainedConfig, covariant=True)

# The type of hidden states
# Currently, T = torch.Tensor for all models except for Medusa
# which has T = List[torch.Tensor]
T = TypeVar("T", default=torch.Tensor)
# Covariant variant of the hidden-state type, with the same default.
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags
# for the base interfaces to avoid breaking OOT registration for existing models
# that don't inherit from the base interface classes
@runtime_checkable
class VllmModel(Protocol[C_co, T_co]):
    """Structural interface of a vLLM model.

    The initializer takes the HF config plus the ``cache_config`` and
    ``quant_config`` keywords; ``forward`` takes ``input_ids``,
    ``positions``, ``kv_caches`` and ``attn_metadata``.
    """

    def __init__(
        self,
        config: C_co,
        *,
        cache_config: Optional["CacheConfig"],
        quant_config: Optional["QuantizationConfig"],
    ) -> None:
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: "AttentionMetadata",
    ) -> T_co:
        ...
def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
    """Return whether ``model.__init__`` accepts the vLLM init keywords.

    The keywords checked are ``cache_config`` and ``quant_config``. When
    some are missing, a warning is logged only if ``model`` is a class
    deriving from ``nn.Module``.
    """
    init_fn = model.__init__
    required_kws = ("cache_config", "quant_config")
    missing_kws = tuple(name for name in required_kws
                        if not supports_kw(init_fn, name))

    if not missing_kws:
        return True

    if isinstance(model, type) and issubclass(model, nn.Module):
        logger.warning(
            "The model (%s) is missing "
            "vLLM-specific keywords from its initializer: %s",
            model,
            missing_kws,
        )
    return False
def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
    """Return whether ``model.forward`` accepts the vLLM forward keywords.

    The keywords checked are ``input_ids``, ``positions``, ``kv_caches``
    and ``attn_metadata``. A missing or non-callable ``forward`` yields
    ``False`` without logging. When some keywords are missing, a warning
    is logged only if ``model`` is a class deriving from ``nn.Module``.
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata")
    missing_kws = tuple(kw for kw in vllm_kws
                        if not supports_kw(model_forward, kw))

    if missing_kws and (isinstance(model, type)
                        and issubclass(model, nn.Module)):
        # The keywords were looked up on `forward`, so the message must name
        # `forward` and not the initializer.
        logger.warning(
            "The model (%s) is missing "
            "vLLM-specific keywords from its `forward` method: %s",
            model,
            missing_kws,
        )

    return len(missing_kws) == 0
@overload
def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]:
    ...


@overload
def is_vllm_model(model: object) -> TypeIs[VllmModel]:
    ...


def is_vllm_model(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]:
    """Return whether ``model`` has a vLLM-compatible initializer and
    ``forward``.

    The ``forward`` check is skipped when the initializer check fails.
    """
    if not _check_vllm_model_init(model):
        return False
    return _check_vllm_model_forward(model)
@runtime_checkable
class VllmModelForTextGeneration(VllmModel[C_co, T], Protocol[C_co, T]):
    """A ``VllmModel`` that additionally provides ``compute_logits`` and
    ``sample``."""

    def compute_logits(
        self,
        hidden_states: T,
        sampling_metadata: "SamplingMetadata",
    ) -> Optional[T]:
        """Return `None` if TP rank > 0."""
        ...

    def sample(
        self,
        logits: T,
        sampling_metadata: "SamplingMetadata",
    ) -> "SamplerOutput":
        """Only called on TP rank 0."""
        ...
@overload
def is_text_generation_model(
        model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]:
    ...


@overload
def is_text_generation_model(
        model: object) -> TypeIs[VllmModelForTextGeneration]:
    ...


def is_text_generation_model(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForTextGeneration]],
           TypeIs[VllmModelForTextGeneration]]:
    """Return whether ``model`` (a class or an instance) is a vLLM model
    that also satisfies ``VllmModelForTextGeneration``."""
    # Classes and instances go through the same runtime protocol check.
    return (is_vllm_model(model)
            and isinstance(model, VllmModelForTextGeneration))
@runtime_checkable
class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):
    """A ``VllmModel`` that additionally provides ``pooler``."""

    def pooler(
        self,
        hidden_states: T,
        pooling_metadata: "PoolingMetadata",
    ) -> "PoolerOutput":
        """Only called on TP rank 0."""
        ...
@overload
def is_embedding_model(
        model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]:
    ...


@overload
def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]:
    ...


def is_embedding_model(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]:
    """Return whether ``model`` (a class or an instance) is a vLLM model
    that also satisfies ``VllmModelForEmbedding``."""
    # Classes and instances go through the same runtime protocol check.
    return is_vllm_model(model) and isinstance(model, VllmModelForEmbedding)
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Copyright (c) 2023 OpenGVLab # Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from functools import partial
from typing import Iterable, Optional, Tuple from typing import Iterable, Optional, Tuple
import torch import torch
...@@ -11,7 +12,10 @@ import torch.nn as nn ...@@ -11,7 +12,10 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -54,7 +58,7 @@ class InternVisionEmbeddings(nn.Module): ...@@ -54,7 +58,7 @@ class InternVisionEmbeddings(nn.Module):
self.position_embedding = nn.Parameter( self.position_embedding = nn.Parameter(
torch.randn(1, self.num_positions, self.embed_dim)) torch.randn(1, self.num_positions, self.embed_dim))
def _get_pos_embed(self, pos_embed, H, W): def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int):
target_dtype = pos_embed.dtype target_dtype = pos_embed.dtype
pos_embed = pos_embed.float().reshape( pos_embed = pos_embed.float().reshape(
1, self.image_size // self.patch_size, 1, self.image_size // self.patch_size,
...@@ -63,9 +67,21 @@ class InternVisionEmbeddings(nn.Module): ...@@ -63,9 +67,21 @@ class InternVisionEmbeddings(nn.Module):
size=(H, W), size=(H, W),
mode='bicubic', mode='bicubic',
align_corners=False) align_corners=False)
pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2, return pos_embed.reshape(1, -1, H * W).permute(0, 2,
1).to(target_dtype) 1).to(target_dtype)
return pos_embed
def _get_position_embedding(self, H: int, W: int) -> torch.Tensor:
position_embedding = self.position_embedding
if self.num_patches == H * W:
return position_embedding
return torch.cat(
[
position_embedding[:, :1, :],
self._get_pos_embed(position_embedding[:, 1:, :], H, W),
],
dim=1,
)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
target_dtype = self.patch_embedding.weight.dtype target_dtype = self.patch_embedding.weight.dtype
...@@ -76,12 +92,7 @@ class InternVisionEmbeddings(nn.Module): ...@@ -76,12 +92,7 @@ class InternVisionEmbeddings(nn.Module):
class_embeds = self.class_embedding.expand(batch_size, 1, class_embeds = self.class_embedding.expand(batch_size, 1,
-1).to(target_dtype) -1).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
position_embedding = torch.cat([ position_embedding = self._get_position_embedding(height, width)
self.position_embedding[:, :1, :],
self._get_pos_embed(self.position_embedding[:, 1:, :], height,
width)
],
dim=1)
embeddings = embeddings + position_embedding.to(target_dtype) embeddings = embeddings + position_embedding.to(target_dtype)
return embeddings return embeddings
...@@ -93,8 +104,11 @@ class InternParallelAttention(nn.Module): ...@@ -93,8 +104,11 @@ class InternParallelAttention(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): *,
num_dummy_heads: int = 0,
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
...@@ -105,11 +119,19 @@ class InternParallelAttention(nn.Module): ...@@ -105,11 +119,19 @@ class InternParallelAttention(nn.Module):
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).') f' {self.num_heads}).')
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads,
self.tp_size)
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.qkv = QKVParallelLinear( self.qkv = QKVParallelLinear(
self.embed_dim, self.embed_dim,
self.head_dim, self.head_dim,
self.num_heads, num_dummy_heads + self.num_heads,
bias=config.qkv_bias, bias=config.qkv_bias,
quant_config=quant_config, quant_config=quant_config,
) )
...@@ -117,34 +139,44 @@ class InternParallelAttention(nn.Module): ...@@ -117,34 +139,44 @@ class InternParallelAttention(nn.Module):
self.qk_normalization = config.qk_normalization self.qk_normalization = config.qk_normalization
if self.qk_normalization: if self.qk_normalization:
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps) self.q_norm = RMSNorm(self.dummy_dim,
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps) eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.k_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.proj = RowParallelLinear( self.proj = RowParallelLinear(
self.embed_dim, self.dummy_dim,
self.embed_dim, self.embed_dim,
quant_config=quant_config, quant_config=quant_config,
) )
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
    """Apply ``q_norm`` / ``k_norm`` to the query and key projections.

    ``q_norm`` and ``k_norm`` are constructed over the full
    ``self.dummy_dim`` width, so when the projections are sharded
    (``self.tp_size > 1``) the shards are gathered before normalising and
    the result is split again so that each rank keeps only its own slice.

    Args:
        q: Query projection for this rank.
        k: Key projection for this rank.

    Returns:
        A ``(q, k)`` tuple of normalised tensors.
    """
    if self.tp_size > 1:
        # NOTE(review): presumably gathers along the last dimension, which
        # is what the last-dim split below undoes - confirm against
        # ``tensor_model_parallel_all_gather``.
        q = tensor_model_parallel_all_gather(q.contiguous())
        k = tensor_model_parallel_all_gather(k.contiguous())

    q = self.q_norm.forward_native(q)
    k = self.k_norm.forward_native(k)

    if self.tp_size > 1:
        splitter = partial(split_tensor_along_last_dim,
                           num_partitions=self.tp_size)
        # Keep only the partition that belongs to this rank.
        q = splitter(q)[self.tp_rank]
        k = splitter(k)[self.tp_rank]

    return q, k
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, _ = x.shape
qkv, _ = self.qkv(x) qkv, _ = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1) q, k, v = qkv.chunk(3, dim=-1)
if self.qk_normalization:
q, k = self._apply_qk_norm(q, k)
q = q.view(B, N, self.num_heads_per_partition, self.head_dim) q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
k = k.view(B, N, self.num_heads_per_partition, self.head_dim) k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
v = v.view(B, N, self.num_heads_per_partition, self.head_dim) v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
if self.qk_normalization:
B_, N_, H_, D_ = q.shape
q = self.q_norm.forward_native(q.flatten(-2,
-1)).view(B_, N_, H_, D_)
k = self.k_norm.forward_native(k.flatten(-2,
-1)).view(B_, N_, H_, D_)
x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale) x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
x = x.view(B, N, -1) x = x.view(B, N, -1)
...@@ -155,8 +187,14 @@ class InternParallelAttention(nn.Module): ...@@ -155,8 +187,14 @@ class InternParallelAttention(nn.Module):
class InternSdpaAttention(nn.Module): class InternSdpaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: PretrainedConfig): def __init__(
self,
config: PretrainedConfig,
*,
num_dummy_heads: int = 0,
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
...@@ -167,20 +205,27 @@ class InternSdpaAttention(nn.Module): ...@@ -167,20 +205,27 @@ class InternSdpaAttention(nn.Module):
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).') f' {self.num_heads}).')
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(self.embed_dim, self.qkv = nn.Linear(self.embed_dim,
3 * self.embed_dim, 3 * self.dummy_dim,
bias=config.qkv_bias) bias=config.qkv_bias)
self.qk_normalization = config.qk_normalization self.qk_normalization = config.qk_normalization
if self.qk_normalization: if self.qk_normalization:
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps) self.q_norm = RMSNorm(self.dummy_dim,
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps) eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.k_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.proj = nn.Linear(self.embed_dim, self.embed_dim) self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape B, N, C = x.shape
qkv = self.qkv(x) qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1) q, k, v = qkv.chunk(3, dim=-1)
...@@ -233,22 +278,23 @@ class InternMLP(nn.Module): ...@@ -233,22 +278,23 @@ class InternMLP(nn.Module):
class InternVisionEncoderLayer(nn.Module): class InternVisionEncoderLayer(nn.Module):
def __init__(self, def __init__(
config: PretrainedConfig, self,
quant_config: Optional[QuantizationConfig] = None): config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_dummy_heads: int = 0,
) -> None:
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.intermediate_size = config.intermediate_size self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type self.norm_type = config.norm_type
# fallback to sdpa attention if tp unavailable self.attn = self._init_attn(config,
tp_size = get_tensor_model_parallel_world_size() quant_config,
num_heads = config.num_attention_heads num_dummy_heads=num_dummy_heads)
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.attn = InternParallelAttention(config,
quant_config=quant_config)
else:
self.attn = InternSdpaAttention(config)
self.mlp = InternMLP(config, quant_config=quant_config) self.mlp = InternMLP(config, quant_config=quant_config)
self.norm1 = NORM2FN[self.norm_type](self.embed_dim, self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
...@@ -260,6 +306,24 @@ class InternVisionEncoderLayer(nn.Module): ...@@ -260,6 +306,24 @@ class InternVisionEncoderLayer(nn.Module):
self.ls2 = nn.Parameter(config.initializer_factor * self.ls2 = nn.Parameter(config.initializer_factor *
torch.ones(self.embed_dim)) torch.ones(self.embed_dim))
def _init_attn(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    num_dummy_heads: int,
):
    """Build the attention module for this encoder layer.

    Args:
        config: Vision config; ``num_attention_heads`` is read from it.
        quant_config: Quantization config forwarded to the parallel
            attention implementation.
        num_dummy_heads: Extra heads added on top of the configured ones.

    Returns:
        ``InternParallelAttention`` when ``USE_XFORMERS_OPS`` is truthy and
        the total head count (real + dummy) is divisible by the
        tensor-parallel world size; otherwise ``InternSdpaAttention``.
    """
    # fallback to sdpa attention if tp unavailable
    tp_size = get_tensor_model_parallel_world_size()
    num_heads = config.num_attention_heads

    if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
        return InternParallelAttention(config,
                                       quant_config=quant_config,
                                       num_dummy_heads=num_dummy_heads)

    return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -275,19 +339,27 @@ class InternVisionEncoderLayer(nn.Module): ...@@ -275,19 +339,27 @@ class InternVisionEncoderLayer(nn.Module):
class InternVisionEncoder(nn.Module): class InternVisionEncoder(nn.Module):
def __init__(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    *,
    num_hidden_layers_override: Optional[int] = None,
    num_dummy_heads: int = 0,
):
    """Create the stack of vision encoder layers.

    Args:
        config: Vision config; supplies ``num_hidden_layers`` unless
            overridden.
        quant_config: Quantization config forwarded to every layer.
        num_hidden_layers_override: When not ``None``, the number of layers
            to build instead of ``config.num_hidden_layers``.
        num_dummy_heads: Forwarded to every ``InternVisionEncoderLayer``.
    """
    super().__init__()

    self.config = config

    if num_hidden_layers_override is None:
        num_hidden_layers = config.num_hidden_layers
    else:
        num_hidden_layers = num_hidden_layers_override

    self.layers = nn.ModuleList([
        InternVisionEncoderLayer(config,
                                 quant_config,
                                 num_dummy_heads=num_dummy_heads)
        for _ in range(num_hidden_layers)
    ])
...@@ -302,35 +374,25 @@ class InternVisionEncoder(nn.Module): ...@@ -302,35 +374,25 @@ class InternVisionEncoder(nn.Module):
class InternVisionModel(nn.Module): class InternVisionModel(nn.Module):
def __init__(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    *,
    num_hidden_layers_override: Optional[int] = None,
    num_dummy_heads: int = 0,
):
    """Assemble the vision model from its embeddings and encoder.

    Args:
        config: Vision config shared by the embeddings and the encoder.
        quant_config: Quantization config forwarded to the encoder.
        num_hidden_layers_override: Forwarded to ``InternVisionEncoder``.
        num_dummy_heads: Forwarded to ``InternVisionEncoder``.
    """
    super().__init__()

    self.config = config

    self.embeddings = InternVisionEmbeddings(config)
    self.encoder = InternVisionEncoder(
        config=config,
        quant_config=quant_config,
        num_hidden_layers_override=num_hidden_layers_override,
        num_dummy_heads=num_dummy_heads,
    )
def get_input_embeddings(self):
    """Return the module stored in ``self.embeddings``."""
    return self.embeddings
......
...@@ -18,8 +18,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -18,8 +18,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -28,6 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -28,6 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers)
...@@ -266,7 +266,7 @@ class InternLM2Model(nn.Module): ...@@ -266,7 +266,7 @@ class InternLM2Model(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: IntermediateTensors = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
...@@ -297,7 +297,7 @@ class InternLM2Model(nn.Module): ...@@ -297,7 +297,7 @@ class InternLM2Model(nn.Module):
return hidden_states return hidden_states
class InternLM2ForCausalLM(nn.Module): class InternLM2ForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -325,7 +325,7 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -325,7 +325,7 @@ class InternLM2ForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: IntermediateTensors, intermediate_tensors: Optional[IntermediateTensors],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
import re import re
from functools import cached_property, partial
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -16,11 +17,10 @@ from transformers import PretrainedConfig ...@@ -16,11 +17,10 @@ from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs token_inputs)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.intern_vit import InternVisionModel from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -31,9 +31,9 @@ from vllm.utils import is_list_of ...@@ -31,9 +31,9 @@ from vllm.utils import is_list_of
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches) get_clip_num_patches)
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, group_weights_with_prefix, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
init_vllm_registered_model, merge_multimodal_embeddings) merge_multimodal_embeddings)
IMG_START = '<img>' IMG_START = '<img>'
IMG_END = '</img>' IMG_END = '</img>'
...@@ -122,6 +122,20 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, ...@@ -122,6 +122,20 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
return blocks, target_width, target_height return blocks, target_width, target_height
def calculate_num_blocks_wrapper(hf_config: PretrainedConfig,
                                 max_dynamic_patch: Optional[int] = None):
    """Bind the config-derived arguments of ``calculate_num_blocks``.

    Args:
        hf_config: Model config supplying ``min_dynamic_patch``,
            ``max_dynamic_patch``, ``use_thumbnail`` and
            ``vision_config.image_size``.
        max_dynamic_patch: Overrides ``hf_config.max_dynamic_patch`` when
            not ``None``.

    Returns:
        A ``functools.partial`` of ``calculate_num_blocks`` with
        ``min_num``, ``max_num``, ``image_size`` and ``use_thumbnail``
        already bound.
    """
    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch
    min_num = hf_config.min_dynamic_patch
    image_size = hf_config.vision_config.image_size
    use_thumbnail = hf_config.use_thumbnail
    return partial(calculate_num_blocks,
                   min_num=min_num,
                   max_num=max_dynamic_patch,
                   image_size=image_size,
                   use_thumbnail=use_thumbnail)
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int, def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
image_size: int, image_size: int,
...@@ -168,172 +182,231 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int, ...@@ -168,172 +182,231 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
return pixel_values return pixel_values
def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
                                  max_dynamic_patch: Optional[int] = None):
    """Bind the config-derived arguments of ``image_to_pixel_values``.

    Args:
        hf_config: Model config supplying ``vision_config.image_size``,
            ``min_dynamic_patch``, ``max_dynamic_patch`` and
            ``use_thumbnail``.
        max_dynamic_patch: Overrides ``hf_config.max_dynamic_patch`` when
            not ``None``.

    Returns:
        A ``functools.partial`` of ``image_to_pixel_values`` with
        ``input_size``, ``min_num``, ``max_num`` and ``use_thumbnail``
        already bound.
    """
    image_size = hf_config.vision_config.image_size
    min_num = hf_config.min_dynamic_patch
    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch
    use_thumbnail = hf_config.use_thumbnail
    return partial(image_to_pixel_values,
                   input_size=image_size,
                   min_num=min_num,
                   max_num=max_dynamic_patch,
                   use_thumbnail=use_thumbnail)
def get_internvl_num_patches(hf_config: PretrainedConfig):
    """Return the per-tile patch count after downsampling.

    The value is ``get_clip_num_patches(image_size, patch_size)`` scaled by
    ``hf_config.downsample_ratio ** 2`` and truncated with ``int``.

    Args:
        hf_config: Model config supplying ``downsample_ratio`` and
            ``vision_config.image_size`` / ``vision_config.patch_size``.
    """
    vision_config = hf_config.vision_config
    downsample_ratio = hf_config.downsample_ratio
    image_size = vision_config.image_size
    patch_size = vision_config.patch_size
    return int(
        get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
        (downsample_ratio**2))
def get_max_internvl_image_tokens(ctx: InputContext,
                                  *,
                                  max_dynamic_patch: Optional[int] = None):
    """Return the largest number of image tokens one image can occupy.

    Args:
        ctx: Input context; only ``ctx.get_hf_config()`` is used.
        max_dynamic_patch: Overrides ``hf_config.max_dynamic_patch`` when
            not ``None``.

    Returns:
        ``get_internvl_num_patches(hf_config)`` multiplied by the maximum
        tile count, where the tile count is incremented by one when
        ``use_thumbnail`` is set and more than one tile is allowed.
    """
    hf_config = ctx.get_hf_config()

    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch
    use_thumbnail = hf_config.use_thumbnail
    # The extra thumbnail tile is only counted when tiling is in effect
    # (more than one tile allowed).
    if use_thumbnail and max_dynamic_patch > 1:
        max_dynamic_patch += 1

    num_patches = get_internvl_num_patches(hf_config)
    return num_patches * max_dynamic_patch
def get_max_internvl_image_size(ctx: InputContext,
                                *,
                                max_dynamic_patch: Optional[int] = None):
    """Return the ``(width, height)`` used for the largest dummy image.

    Args:
        ctx: Input context; only ``ctx.get_hf_config()`` is used.
        max_dynamic_patch: Overrides ``hf_config.max_dynamic_patch`` when
            not ``None``.

    Returns:
        A tuple ``(image_size * tiles, image_size)`` where ``tiles`` is the
        maximum tile count, incremented by one when ``use_thumbnail`` is
        set and more than one tile is allowed.
    """
    hf_config = ctx.get_hf_config()
    image_size = hf_config.vision_config.image_size

    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch
    use_thumbnail = hf_config.use_thumbnail
    # Same thumbnail rule as get_max_internvl_image_tokens: the extra tile
    # only counts when more than one tile is allowed.
    if use_thumbnail and max_dynamic_patch > 1:
        max_dynamic_patch += 1

    # All tiles are laid out side by side: one tile high, N tiles wide.
    width = image_size * max_dynamic_patch
    height = image_size
    return width, height
image_feature_size = [num_blocks * num_patches]
elif is_list_of(image_data, Image.Image):
image_feature_size = []
for image in image_data:
width, height = image.size
num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
max_num, image_size,
use_thumbnail)
image_feature_size.append(num_blocks * num_patches)
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
new_prompt = prompt
image_idx = sorted(map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
for idx, feature_size in enumerate(image_feature_size, start=1):
image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END
if not image_idx:
image_prompt = f"Image-{idx}: {image_prompt}"
new_prompt = new_prompt.replace('<image>', image_prompt, 1)
new_prompt_token_ids = tokenizer.encode(new_prompt)
return LLMInputs(prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data)
def input_mapper_for_internvl(ctx: InputContext, data: object):
hf_config = ctx.get_hf_config()
use_thumbnail = hf_config.use_thumbnail
min_num = hf_config.min_dynamic_patch
max_num = hf_config.max_dynamic_patch
image_size = hf_config.vision_config.image_size
class InternVLInputPipeline:
    """Input processor, input mapper and dummy-data factory for InternVL.

    The three marker strings are stored on the instance so that the same
    pipeline logic can be reused with different image tokens.
    """

    def __init__(
        self,
        img_start_token: str,
        img_end_token: str,
        img_context_token: str,
    ) -> None:
        """Store the image marker tokens.

        Args:
            img_start_token: Text emitted before an image's context tokens.
            img_end_token: Text emitted after an image's context tokens.
            img_context_token: Text repeated once per image feature.
        """
        super().__init__()

        self.img_start_token = img_start_token
        self.img_end_token = img_end_token
        self.img_context_token = img_context_token

    def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
        """Return the expanded text for a single image.

        Args:
            feature_size: Number of context tokens to emit.
            num_patches: Not used by this implementation.
        """
        return (self.img_start_token + self.img_context_token * feature_size +
                self.img_end_token)

    def _expand_image_prompt(
        self,
        prompt: str,
        feature_sizes: List[int],
        num_patches: int,
    ) -> str:
        """Replace ``<image>`` placeholders with expanded image prompts.

        Placeholders are replaced left to right, one per entry of
        ``feature_sizes``. When the prompt contains no pre-numbered
        ``Image-N: <image>\\n`` placeholder, an ``Image-N: `` prefix is
        added to every expansion.

        Args:
            prompt: Prompt text containing ``<image>`` placeholders.
            feature_sizes: Context-token count for each image, in order.
            num_patches: Forwarded to ``_create_image_prompt``.

        Returns:
            The prompt with up to ``len(feature_sizes)`` placeholders
            expanded.
        """
        image_idx = sorted(
            map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))

        new_prompt = prompt
        for idx, feature_size in enumerate(feature_sizes, start=1):
            image_prompt = self._create_image_prompt(feature_size, num_patches)
            # Only label the images ourselves when the caller has not
            # already numbered them in the prompt.
            if not image_idx:
                image_prompt = f"Image-{idx}: {image_prompt}"

            new_prompt = new_prompt.replace('<image>', image_prompt, 1)

        return new_prompt

    def input_processor(
        self,
        ctx: InputContext,
        inputs: DecoderOnlyInputs,
        *,
        max_dynamic_patch: Optional[int] = None,
    ) -> DecoderOnlyInputs:
        """Expand image placeholders and re-tokenize the prompt.

        Args:
            ctx: Input context giving access to the model and HF configs.
            inputs: Decoder-only inputs; returned unchanged when they carry
                no ``"image"`` entry in ``"multi_modal_data"``.
            max_dynamic_patch: Optional override of the config's maximum
                tile count.

        Returns:
            New inputs holding the original prompt text, the token ids of
            the expanded prompt, and the untouched multi-modal data.

        Raises:
            TypeError: If the image data is neither a PIL image, a list of
                PIL images, nor a ``torch.Tensor``.
        """
        multi_modal_data = inputs.get("multi_modal_data")
        if multi_modal_data is None or "image" not in multi_modal_data:
            return inputs

        model_config = ctx.model_config
        hf_config = ctx.get_hf_config()

        image_data = multi_modal_data["image"]
        num_patches = get_internvl_num_patches(hf_config)
        num_blocks_calculator = calculate_num_blocks_wrapper(
            hf_config, max_dynamic_patch)
        if isinstance(image_data, Image.Image):
            width, height = image_data.size
            num_blocks, _, _ = num_blocks_calculator(width, height)
            image_feature_sizes = [num_blocks * num_patches]
        elif is_list_of(image_data, Image.Image):
            image_feature_sizes = []
            for image in image_data:
                width, height = image.size
                num_blocks, _, _ = num_blocks_calculator(width, height)
                image_feature_sizes.append(num_blocks * num_patches)
        elif isinstance(image_data, torch.Tensor):
            # The unpack requires a 3-D tensor; only the middle dimension
            # is used.
            # NOTE(review): a single feature size is recorded regardless of
            # the leading dimension, so only the first placeholder is
            # expanded for tensor input - confirm this is intended.
            _, image_feature_size, _ = image_data.shape
            image_feature_sizes = [image_feature_size]
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")

        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)

        prompt = inputs.get("prompt")
        prompt_token_ids = inputs["prompt_token_ids"]
        if prompt is None:
            prompt = tokenizer.decode(prompt_token_ids)

        new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
                                               num_patches)
        new_prompt_token_ids = tokenizer.encode(new_prompt)

        # The prompt text handed back is the unexpanded one; only the token
        # ids reflect the expansion.
        return token_inputs(prompt=prompt,
                            prompt_token_ids=new_prompt_token_ids,
                            multi_modal_data=multi_modal_data)

    def input_mapper(
        self,
        ctx: InputContext,
        data: object,
        *,
        max_dynamic_patch: Optional[int] = None,
    ):
        """Convert image data into model keyword inputs.

        Args:
            ctx: Input context giving access to the model and HF configs.
            data: A PIL image, a list of PIL images, or anything else
                (passed through as ``image_embeds``).
            max_dynamic_patch: Optional override of the config's maximum
                tile count.

        Returns:
            ``MultiModalInputs`` with ``pixel_values`` and
            ``image_token_id`` for PIL input, or with ``image_embeds`` for
            any other input.
        """
        hf_config = ctx.get_hf_config()

        image_pixel_values_mapper = image_to_pixel_values_wrapper(
            hf_config, max_dynamic_patch)
        if isinstance(data, Image.Image):
            data = image_pixel_values_mapper(data)
            # Add an N dimension for number of images per prompt (currently 1).
            data = data.unsqueeze(0)
        elif is_list_of(data, Image.Image):
            # we can't stack here because images may have different num_patches
            data = [image_pixel_values_mapper(img) for img in data]
        else:
            return MultiModalInputs({"image_embeds": data})

        model_config = ctx.model_config
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)
        image_token_id = tokenizer.encode(self.img_context_token,
                                          add_special_tokens=False,
                                          return_tensors="pt")[0]

        return MultiModalInputs({
            "pixel_values": data,
            "image_token_id": image_token_id
        })

    def dummy_data(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
        *,
        max_dynamic_patch: Optional[int] = None,
    ):
        """Build dummy sequence and image data sized for the worst case.

        Args:
            ctx: Input context giving access to the model and HF configs.
            seq_len: Sequence length forwarded to
                ``dummy_seq_data_for_clip``.
            mm_counts: Per-modality item counts; ``mm_counts["image"]`` is
                read.
            max_dynamic_patch: Optional override of the config's maximum
                tile count.

        Returns:
            A ``(seq_data, mm_data)`` tuple.
        """
        num_images = mm_counts["image"]

        hf_config = ctx.get_hf_config()

        image_feature_size = get_max_internvl_image_tokens(
            ctx, max_dynamic_patch=max_dynamic_patch)
        model_config = ctx.model_config
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)

        seq_data = dummy_seq_data_for_clip(
            hf_config.vision_config,
            seq_len,
            num_images,
            image_token_id=tokenizer.encode(self.img_context_token,
                                            add_special_tokens=False)[0],
            image_feature_size_override=image_feature_size,
        )

        max_image_width, max_image_height = get_max_internvl_image_size(
            ctx, max_dynamic_patch=max_dynamic_patch)

        mm_data = dummy_image_for_clip(
            hf_config.vision_config,
            num_images,
            image_width_override=max_image_width,
            image_height_override=max_image_height,
        )

        return seq_data, mm_data
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_internvl) input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl) @INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl) @INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class InternVLChatModel(nn.Module, SupportsMultiModal): class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
...@@ -360,29 +433,40 @@ class InternVLChatModel(nn.Module, SupportsMultiModal): ...@@ -360,29 +433,40 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
+ vision_feature_layer + 1 + vision_feature_layer + 1
else: else:
num_hidden_layers = vision_feature_layer + 1 num_hidden_layers = vision_feature_layer + 1
self.vision_model = InternVisionModel( self.vision_model = self._init_vision_model(config, num_hidden_layers)
config.vision_config, num_hidden_layers_override=num_hidden_layers)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config) config.text_config, cache_config, quant_config)
vit_hidden_size = config.vision_config.hidden_size self.mlp1 = self._init_mlp1(config)
llm_hidden_size = config.text_config.hidden_size
self.mlp1 = nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
llm_hidden_size), nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size))
self.img_context_token_id = None self.img_context_token_id = None
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"): if hasattr(self.language_model, "sampler"):
self.sampler = self.language_model.sampler return self.language_model.sampler
else:
self.sampler = Sampler() return Sampler()
def _init_vision_model(self, config: PretrainedConfig,
                       num_hidden_layers: int):
    """Construct the vision encoder for this model.

    Args:
        config: Top-level model config; only ``config.vision_config`` is
            read here.
        num_hidden_layers: Forwarded to ``InternVisionModel`` as
            ``num_hidden_layers_override``.

    Returns:
        The newly created ``InternVisionModel`` instance.
    """
    vision_config = config.vision_config
    vision_encoder = InternVisionModel(
        vision_config,
        num_hidden_layers_override=num_hidden_layers,
    )
    return vision_encoder
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
    """Build the LayerNorm -> Linear -> GELU -> Linear projector.

    The input width is the vision config's hidden size multiplied by
    ``int(1 / self.downsample_ratio) ** 2``; both linear layers emit the
    text config's hidden size.

    Args:
        config: Top-level model config; ``vision_config.hidden_size`` and
            ``text_config.hidden_size`` are read.

    Returns:
        An ``nn.Sequential`` holding the four layers in order.
    """
    vision_width = config.vision_config.hidden_size
    text_width = config.text_config.hidden_size
    # The norm and the first projection share the same input width, so
    # compute it once.
    projector_in = vision_width * int(1 / self.downsample_ratio)**2

    stages = [
        nn.LayerNorm(projector_in),
        nn.Linear(projector_in, text_width),
        nn.GELU(),
        nn.Linear(text_width, text_width),
    ]
    return nn.Sequential(*stages)
def pixel_shuffle(self, x, scale_factor=0.5): def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size() n, w, h, c = x.size()
...@@ -470,7 +554,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal): ...@@ -470,7 +554,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
self, self,
image_input: InternVLImageInputs, image_input: InternVLImageInputs,
) -> torch.Tensor: ) -> torch.Tensor:
if image_input["type"] == "image_embeds": if image_input["type"] == "image_embeds":
return image_input["data"] return image_input["data"]
...@@ -487,18 +570,22 @@ class InternVLChatModel(nn.Module, SupportsMultiModal): ...@@ -487,18 +570,22 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> Union[SamplerOutput, IntermediateTensors]:
image_input = self._parse_and_validate_image_input(**kwargs) if intermediate_tensors is not None:
if image_input is not None and get_pp_group().is_first_rank:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id)
input_ids = None input_ids = None
else:
inputs_embeds = None inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
...@@ -524,19 +611,5 @@ class InternVLChatModel(nn.Module, SupportsMultiModal): ...@@ -524,19 +611,5 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components loader = AutoWeightsLoader(self)
weights_group = group_weights_with_prefix(weights) loader.load_weights(weights)
# load vision encoder
self.vision_model.load_weights(weights_group["vision_model"])
# load mlp projector
mlp_params_dict = dict(self.mlp1.named_parameters())
for name, loaded_weight in weights_group["mlp1"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment