Commit 6d2051cc authored by zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
......@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
......@@ -40,8 +40,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -50,6 +49,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class DeepseekMLP(nn.Module):
......@@ -329,6 +332,7 @@ class DeepseekModel(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
......@@ -338,14 +342,17 @@ class DeepseekModel(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
DeepseekDecoderLayer(config,
layer_idx,
cache_config,
quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekDecoderLayer(config,
int(prefix.split(".")[-1]),
cache_config,
quant_config=quant_config),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
......@@ -353,19 +360,29 @@ class DeepseekModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], attn_metadata,
residual)
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class DeepseekForCausalLM(nn.Module):
class DeepseekForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
......@@ -384,6 +401,8 @@ class DeepseekForCausalLM(nn.Module):
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
......@@ -392,9 +411,9 @@ class DeepseekForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -439,6 +458,8 @@ class DeepseekForCausalLM(nn.Module):
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -451,6 +472,8 @@ class DeepseekForCausalLM(nn.Module):
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
......@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeepseekV2 model."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -40,8 +40,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -50,7 +49,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class DeepseekV2MLP(nn.Module):
......@@ -241,7 +242,7 @@ class DeepseekV2Attention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
rope_scaling['type'] = 'deepseek_yarn'
rope_scaling["rope_type"] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
......@@ -439,6 +440,9 @@ class DeepseekV2Model(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
......@@ -447,7 +451,7 @@ class DeepseekV2Model(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
......@@ -472,7 +476,7 @@ class DeepseekV2Model(nn.Module):
return hidden_states
class DeepseekV2ForCausalLM(nn.Module):
class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
......@@ -492,6 +496,8 @@ class DeepseekV2ForCausalLM(nn.Module):
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
......@@ -500,7 +506,7 @@ class DeepseekV2ForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
......
......@@ -38,8 +38,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
......@@ -53,8 +52,9 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.exaone import ExaoneConfig
from vllm.utils import is_hip
from .interfaces import SupportsLoRA
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class ExaoneGatedMLP(nn.Module):
......@@ -354,6 +354,10 @@ class ExaoneModel(nn.Module):
else:
self.ln_f = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the embeddings produced by ``self.wte`` for ``input_ids``."""
    return self.wte(input_ids)
......@@ -397,7 +401,7 @@ class ExaoneModel(nn.Module):
return hidden_states
class ExaoneForCausalLM(nn.Module, SupportsLoRA):
class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......@@ -477,6 +481,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA):
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
......@@ -506,24 +513,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype,
        device: torch.device) -> IntermediateTensors:
    """Create zero-filled placeholder intermediate tensors.

    Args:
        batch_size: Size of the first dimension of every placeholder.
        dtype: Data type of the placeholders.
        device: Device on which the placeholders are allocated.

    Returns:
        An ``IntermediateTensors`` holding ``"hidden_states"`` and
        ``"residual"``, each a zero tensor of shape
        ``(batch_size, self.config.hidden_size)``.
    """
    # NOTE(review): __init__ also binds an instance attribute with this
    # name (taken from self.transformer); confirm whether this method is
    # still reachable on instances.
    shape = (batch_size, self.config.hidden_size)
    return IntermediateTensors({
        key: torch.zeros(shape, dtype=dtype, device=device)
        for key in ("hidden_states", "residual")
    })
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
......
......@@ -28,7 +28,7 @@ from transformers import FalconConfig as HF_FalconConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import get_act_fn
......@@ -36,8 +36,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -47,6 +46,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import RWConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
FalconConfig = Union[HF_FalconConfig, RWConfig]
......@@ -333,6 +336,7 @@ class FalconModel(nn.Module):
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
......@@ -347,35 +351,56 @@ class FalconModel(nn.Module):
)
# Transformer blocks
self.h = nn.ModuleList([
FalconDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: FalconDecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.h")
# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the decoder layers owned by this pipeline-parallel rank.

    Args:
        input_ids: Token ids; embedded only on the first pipeline rank.
        positions: Forwarded unchanged to every decoder layer.
        kv_caches: Per-layer caches, indexed relative to
            ``self.start_layer``.
        attn_metadata: Forwarded unchanged to every decoder layer.
        intermediate_tensors: Carries ``"hidden_states"`` from the
            previous rank; read on every rank except the first.

    Returns:
        ``IntermediateTensors({"hidden_states": ...})`` on every rank
        except the last; on the last rank, the hidden states after the
        final layer norm ``self.ln_f``.
    """
    if get_pp_group().is_first_rank:
        hidden_states = self.word_embeddings(input_ids)
    else:
        hidden_states = intermediate_tensors["hidden_states"]

    # Only layers in [start_layer, end_layer) are run here; kv_caches is
    # indexed from zero, hence the offset.
    for i in range(self.start_layer, self.end_layer):
        layer = self.h[i]
        hidden_states = layer(
            positions,
            hidden_states,
            kv_caches[i - self.start_layer],
            attn_metadata,
        )

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({"hidden_states": hidden_states})
    hidden_states = self.ln_f(hidden_states)
    return hidden_states
class FalconForCausalLM(nn.Module):
class FalconForCausalLM(nn.Module, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {}
default_bitsandbytes_target_modules = [
".query_key_value.",
".dense.",
".dense_h_to_4h.",
".dense_4h_to_h.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".dense_4h_to_h.", ".dense."]
def __init__(
self,
......@@ -403,6 +428,8 @@ class FalconForCausalLM(nn.Module):
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
......@@ -412,12 +439,8 @@ class FalconForCausalLM(nn.Module):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
positions,
kv_caches,
attn_metadata,
)
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -454,6 +477,8 @@ class FalconForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if "query_key_value" in name:
output_dim = getattr(param, "output_dim", None)
......
......@@ -27,11 +27,11 @@ from transformers import FuyuConfig, FuyuImageProcessor
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
......@@ -41,8 +41,8 @@ from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
......@@ -150,10 +150,10 @@ def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
return model_image_input
def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
return inputs
model_config = ctx.model_config
image_data = multi_modal_data["image"]
......@@ -165,7 +165,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
model_config.model)
model_image_input = _fuyu_image_preprocess(image_processor, image_data)
image_patches = torch.stack([
image_patches = torch.cat([
image_patch[0]
for image_patch in model_image_input["image_patches"]
])
......@@ -177,8 +177,8 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
raise TypeError(f"Invalid image type: {type(image_data)}")
# process prompts
prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
prompt = inputs.get("prompt")
prompt_token_ids = inputs["prompt_token_ids"]
tokenizer = cached_get_tokenizer(model_config.model)
# dim0 is batch_size, dim1 is subseq_size which will always be 1
image_input_ids: List[List[
......@@ -191,9 +191,9 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
1:] + boa_token
return LLMInputs(prompt=new_prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=new_multi_modal_data)
return token_inputs(prompt=new_prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=new_multi_modal_data)
def input_mapper_for_fuyu(ctx: InputContext, data: object):
......@@ -210,14 +210,14 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
])
# image has been processed with prompt in input processor
return MultiModalInputs({"image_patches": data})
return MultiModalInputs({"pixel_values": data})
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsMultiModal):
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self,
config: FuyuConfig,
......@@ -237,28 +237,54 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
self.image_feature_size,
config.hidden_size,
quant_config=quant_config,
gather_output=True,
)
self.language_model = PersimmonForCausalLM(config,
self.language_model = PersimmonForCausalLM(config.text_config,
cache_config=cache_config,
quant_config=quant_config)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@property
def sampler(self):
    """Expose the sampler of the wrapped ``language_model``."""
    return self.language_model.sampler
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
    """Validate the flattened patch size and cast to the embedding dtype.

    Args:
        data: Pixel values; every entry obtained by iterating ``data``
            must have a last dimension equal to
            ``num_channels * patch_size * patch_size``.

    Returns:
        ``data`` converted to the dtype of
        ``self.vision_embed_tokens.weight``.

    Raises:
        ValueError: If any entry has a different last dimension.
    """
    h = w = self.config.patch_size
    num_channels = self.config.num_channels
    expected_dims = num_channels * h * w

    def _validate_shape(d: torch.Tensor):
        actual_dims = d.size(-1)

        if actual_dims != expected_dims:
            raise ValueError(
                "The expected shape of pixel values per image per batch "
                f"per patch is {expected_dims}. "
                f"You supplied {tuple(d.shape)}.")

    for d in data:
        _validate_shape(d)

    return data.to(self.vision_embed_tokens.weight.dtype)
def _parse_and_validate_image_input(
        self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
    """Extract and validate the optional ``pixel_values`` keyword input.

    Returns:
        ``None`` when ``pixel_values`` is absent (or ``None``); otherwise
        a ``FuyuImagePixelInputs`` of type ``"pixel_values"`` whose data
        has been passed through ``flatten_bn(..., concat=True)`` and
        ``self._validate_pixel_values``.

    Raises:
        ValueError: If ``pixel_values`` is neither a tensor nor a list.
    """
    # The input mapper emits the patches under "pixel_values", so that is
    # the only key read here.
    pixel_values = kwargs.pop("pixel_values", None)

    if pixel_values is not None:
        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image patches. "
                             f"Got type: {type(pixel_values)}")

        return FuyuImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(
                flatten_bn(pixel_values, concat=True)),
        )

    return None
def _process_image_input(
......@@ -277,23 +303,29 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
):
image_input = self._parse_and_validate_image_input(**kwargs)
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
else:
inputs_embeds = None
else:
inputs_embeds = None
hidden_states = self.language_model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
......@@ -316,34 +348,5 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights by delegating to ``AutoWeightsLoader``.

    Args:
        weights: Iterable of ``(parameter_name, tensor)`` pairs.
    """
    # NOTE(review): the hand-written loop (rotary-cache skipping and the
    # fused query_key_value re-layout) is no longer performed here;
    # presumably the sub-modules' own loaders cover it -- confirm.
    loader = AutoWeightsLoader(self)
    loader.load_weights(weights)
......@@ -15,7 +15,7 @@
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from functools import lru_cache
from typing import Iterable, List, Optional, Set, Tuple
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -23,7 +23,7 @@ from transformers import GemmaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
......@@ -31,8 +31,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__)
......@@ -245,6 +246,7 @@ class GemmaModel(nn.Module):
config: GemmaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
......@@ -253,10 +255,11 @@ class GemmaModel(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
GemmaDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: GemmaDecoderLayer(config, cache_config, quant_config
),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Normalize the embedding by sqrt(hidden_size)
......@@ -265,6 +268,9 @@ class GemmaModel(nn.Module):
# See https://github.com/huggingface/transformers/pull/29402
normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer))
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the embeddings produced by ``self.embed_tokens`` for ``input_ids``."""
    return self.embed_tokens(input_ids)
......@@ -275,29 +281,38 @@ class GemmaModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states *= self.normalizer
residual = None
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states *= self.normalizer
residual = None
for i in range(len(self.layers)):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class GemmaForCausalLM(nn.Module, SupportsLoRA):
class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......@@ -317,6 +332,28 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
"gate_up_proj",
"down_proj",
]
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
......@@ -339,6 +376,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
self.model = GemmaModel(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
......@@ -347,9 +386,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -388,6 +427,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -400,6 +441,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
......@@ -14,15 +14,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, List, Optional, Set, Tuple
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import Gemma2Config
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
......@@ -30,17 +31,20 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__)
......@@ -237,6 +241,13 @@ class Gemma2DecoderLayer(nn.Module):
return hidden_states, residual
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": 0,
"inputs_embeds": 0,
"intermediate_tensors": 0,
})
class Gemma2Model(nn.Module):
def __init__(
......@@ -244,6 +255,7 @@ class Gemma2Model(nn.Module):
config: Gemma2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
......@@ -252,10 +264,11 @@ class Gemma2Model(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[
-1]), config, cache_config, quant_config),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Normalize the embedding by sqrt(hidden_size)
......@@ -264,32 +277,92 @@ class Gemma2Model(nn.Module):
# See https://github.com/huggingface/transformers/pull/29402
normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer))
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
hidden_states *= self.normalizer
residual = None
for i in range(len(self.layers)):
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_tokens(input_ids)
hidden_states *= self.normalizer
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
logger.warning(
"Some weights are not initialized from checkpoints: %s",
unloaded_params)
class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......@@ -312,6 +385,19 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
# BitsAndBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
......@@ -338,6 +424,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(
config.vocab_size, soft_cap=config.final_logit_softcapping)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
......@@ -346,9 +434,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -369,44 +457,56 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
loader.load_weights(weights)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
logger.warning(
"Some weights are not initialized from checkpoints: %s",
unloaded_params)
class Gemma2EmbeddingModel(nn.Module, SupportsPP):
    """
    Gemma2 backbone combined with a pooler, for embedding workloads.

    The class owns a Gemma2Model and forwards every call to it; pooling
    of the resulting hidden states is delegated to a Pooler instance.

    Attributes:
        model: The Gemma2Model that runs the forward pass.
        _pooler: The Pooler applied to the hidden states (configured with
            PoolingType.LAST and normalize=True).
        make_empty_intermediate_tensors: Re-export of the backbone's
            factory of the same name.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__()
        # All keyword arguments are handed to the backbone unchanged.
        backbone = Gemma2Model(**kwargs)
        self.model = backbone
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.make_empty_intermediate_tensors = (
            backbone.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unmodified."""
        backbone_output = self.model(input_ids, positions, kv_caches,
                                     attn_metadata, intermediate_tensors,
                                     inputs_embeds)
        return backbone_output

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Apply the configured pooler to ``hidden_states``."""
        pooled = self._pooler(hidden_states, pooling_metadata)
        return pooled

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Delegate weight loading to the backbone model."""
        self.model.load_weights(weights)
# coding=utf-8
# Adapted from
# https://github.com/THUDM/GLM-4
"""Inference-only GLM-4v model visual encoder compatible with THUDM weights."""
from argparse import Namespace
from typing import Optional
import torch
from torch import nn
from torch.nn import LayerNorm
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class PatchEmbedding(nn.Module):
    """Convert an image batch into a token sequence with a leading CLS token.

    A strided convolution cuts each image into patches, a class embedding
    is placed in front of the patch tokens, and the position-embedding
    table is added to the whole sequence.
    """

    def __init__(self, config):
        super().__init__()
        patch = config.patch_size
        width = config.hidden_size
        # Kernel size equals stride, so the patches do not overlap.
        self.proj = nn.Conv2d(config.in_channels,
                              width,
                              kernel_size=patch,
                              stride=patch)
        self.cls_embedding = nn.Parameter(torch.zeros(1, width))
        self.position_embedding = nn.Embedding(config.num_positions, width)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Input image tensor with shape (B, C, H, W)

        Returns:
        torch.Tensor
            Transformed tensor with shape (B, L, D)
        """
        # Move the input onto the device that holds the projection weights.
        images = images.to(self.proj.weight.device)
        # (B, D, H', W') -> (B, D, H' * W') -> (B, H' * W', D)
        patch_tokens = self.proj(images).flatten(2).transpose(1, 2)
        batch_size = patch_tokens.shape[0]
        cls_tokens = self.cls_embedding.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, patch_tokens), dim=1)
        # The full position table is added by broadcasting over the batch,
        # so the sequence length has to be compatible with num_positions.
        tokens += self.position_embedding.weight.unsqueeze(0)
        return tokens
class Attention(nn.Module):
    """Non-causal multi-head self-attention for the visual encoder.

    The fused QKV projection and the output projection use the parallel
    linear layers; the attention heads are divided by the tensor-parallel
    world size, so each rank handles ``num_heads // tp_size`` heads.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden = config.hidden_size
        self.hidden_size = hidden
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_rank = config.num_heads // self.tp_size
        self.head_dim = hidden // config.num_heads
        self.scale = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            hidden,
            self.head_dim,
            config.num_heads,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            hidden,
            hidden,
            quant_config=quant_config,
        )

        self.output_dropout = torch.nn.Dropout(config.dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` (unpacked as batch, length, features)."""
        batch, seq_len, _ = x.shape
        qkv, _ = self.query_key_value(x)  # B, L, 3 * H * D

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, L, H * D) -> (B, H, L, D)
            return t.reshape(batch, seq_len, self.num_heads_per_rank,
                             self.head_dim).permute(0, 2, 1, 3)

        q, k, v = map(to_heads, qkv.chunk(3, dim=-1))

        attended = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=0.,
            is_causal=False)

        # reshape (not view): the transposed tensor may be non-contiguous.
        merged = attended.transpose(1, 2).reshape(batch, seq_len, -1)
        output, _ = self.dense(merged)
        return self.output_dropout(output)
class MLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2.

    ``fc1`` is a ColumnParallelLinear and ``fc2`` a RowParallelLinear; the
    activation is looked up from ``config.hidden_act``.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        hidden = config.hidden_size
        inner = config.intermediate_size
        self.fc1 = ColumnParallelLinear(
            hidden,
            inner,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            inner,
            hidden,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward block to ``x``."""
        # Each linear layer returns a pair; only the first item is used.
        expanded, _ = self.fc1(x)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected
class TransformerLayer(nn.Module):
    """One encoder layer: attention and MLP, each with a residual add.

    Note that each LayerNorm is applied to the *output* of its sub-module
    (attention or MLP) before the result is added back to the input.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_eps
        self.input_layernorm = LayerNorm(width, eps=eps)
        self.attention = Attention(config, quant_config=quant_config)
        self.mlp = MLP(config, quant_config=quant_config)
        self.post_attention_layernorm = LayerNorm(width, eps=eps)

    def forward(self, hidden_states):
        """Return the layer output for ``hidden_states``."""
        # Attention sub-block: normalise the attention output, then add.
        normed_attention = self.input_layernorm(self.attention(hidden_states))
        after_attention = hidden_states + normed_attention

        # MLP sub-block: normalise the MLP output, then add.
        normed_mlp = self.post_attention_layernorm(self.mlp(after_attention))
        return after_attention + normed_mlp
class Transformer(nn.Module):
    """Stack of ``config.num_hidden_layers`` TransformerLayer modules."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        depth = config.num_hidden_layers
        stack = [
            TransformerLayer(config, quant_config=quant_config)
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(stack)

    def forward(self, hidden_states):
        """Feed ``hidden_states`` through every layer in order."""
        output = hidden_states
        for block in self.layers:
            output = block(output)
        return output
class GLU(nn.Module):
    """Gated projection head: linear -> LayerNorm -> GELU -> gated MLP."""

    def __init__(
        self,
        config,
        in_features,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """
        Build the projection layers.

        The original implementation uses two separate projections:
        ```python
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.ffn_hidden_size,
            bias=False,
            quant_config=quant_config
        )
        self.gate_proj = ColumnParallelLinear(
            config.hidden_size,
            config.ffn_hidden_size,
            bias=False,
            quant_config=quant_config
        )
        ```
        whose outputs are concatenated:
        ```
        gate_proj_output, _ = self.gate_proj(x)
        dense_h_to_4h_output, _ = self.dense_h_to_4h(x)
        x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1)
        ```

        Here the two ColumnParallelLinear layers are merged into a single
        MergedColumnParallelLinear:
        ```
        self.merged_proj = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=False,
            quant_config=quant_config
        )
        ```
        so one call produces the concatenated result:
        ```
        x, _ = self.merged_proj(x)
        ```
        """
        super().__init__()
        hidden = config.hidden_size
        ffn = config.ffn_hidden_size
        self.linear_proj = ReplicatedLinear(in_features,
                                            hidden,
                                            bias=False,
                                            quant_config=quant_config)
        self.norm1 = nn.LayerNorm(hidden)
        self.act1 = nn.GELU()
        self.act2 = SiluAndMul()

        self.merged_proj = MergedColumnParallelLinear(
            hidden, [ffn] * 2,
            bias=False,
            quant_config=quant_config)

        self.dense_4h_to_h = RowParallelLinear(ffn,
                                               hidden,
                                               bias=False,
                                               quant_config=quant_config)

    def forward(self, x):
        """Project ``x`` through the gated head and return the result."""
        # Each linear layer returns a pair; only the first item is used.
        projected, _ = self.linear_proj(x)
        normed = self.act1(self.norm1(projected))
        gate_and_up, _ = self.merged_proj(normed)
        gated = self.act2(gate_and_up)
        output, _ = self.dense_4h_to_h(gated)
        return output
class EVA2CLIPModel(nn.Module):
    """Visual encoder: patch embedding, transformer, 2x2 conv and GLU head.

    The projected patch tokens are wrapped in ``boi`` / ``eoi`` embeddings
    and divided by ``vision_config.scaling_factor``.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # config.vision_config is a mapping; expose its keys as attributes.
        vision_config = Namespace(**config.vision_config)
        text_width = config.hidden_size

        self.patch_embedding = PatchEmbedding(vision_config)
        self.transformer = Transformer(vision_config,
                                       quant_config=quant_config)
        self.linear_proj = GLU(config,
                               in_features=text_width,
                               quant_config=quant_config)
        self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
                              out_channels=text_width,
                              kernel_size=2,
                              stride=2)
        self.boi = nn.Parameter(torch.zeros(1, 1, text_width))
        self.eoi = nn.Parameter(torch.zeros(1, 1, text_width))
        self.scaling_factor = vision_config.scaling_factor

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Input image tensor with shape (B, C, H, W)

        Returns:
        torch.Tensor
            Transformed tensor with shape (B, L, D)
        """
        encoded = self.transformer(self.patch_embedding(images))
        # Skip index 0 on the sequence axis (the token that the patch
        # embedding places in front of the patch tokens).
        patches = encoded[:, 1:]

        batch, num_patches, width = patches.shape
        # The remaining tokens are laid out as a square grid.
        side = int(num_patches**0.5)
        grid = patches.view(batch, side, side, width).permute(0, 3, 1, 2)

        # 2x2 stride-2 convolution, then back to (B, L, D).
        reduced = self.conv(grid).flatten(2).transpose(1, 2)
        projected = self.linear_proj(reduced)

        count = projected.shape[0]
        boi = self.boi.expand(count, -1, -1)
        eoi = self.eoi.expand(count, -1, -1)
        framed = torch.cat((boi, projected, eoi), dim=1)
        return framed / self.scaling_factor
......@@ -32,8 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
......@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import is_pp_missing_parameter, make_layers
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPT2Attention(nn.Module):
......@@ -204,6 +205,9 @@ class GPT2Model(nn.Module):
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h")
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward(
self,
......@@ -234,7 +238,7 @@ class GPT2Model(nn.Module):
return hidden_states
class GPT2LMHeadModel(nn.Module):
class GPT2LMHeadModel(nn.Module, SupportsPP):
def __init__(
self,
......@@ -256,6 +260,8 @@ class GPT2LMHeadModel(nn.Module):
self.config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
......@@ -264,7 +270,7 @@ class GPT2LMHeadModel(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
......@@ -286,16 +292,6 @@ class GPT2LMHeadModel(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
......
......@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -26,14 +26,13 @@ from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
......@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPTBigCodeAttention(nn.Module):
......@@ -194,6 +195,7 @@ class GPTBigCodeModel(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
......@@ -207,11 +209,15 @@ class GPTBigCodeModel(nn.Module):
self.embed_dim,
org_num_embeddings=config.vocab_size)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([
GPTBigCodeBlock(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config),
prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward(
self,
......@@ -219,20 +225,28 @@ class GPTBigCodeModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(len(self.h)):
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {"c_attn": ["c_attn"]}
supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]
......@@ -272,6 +286,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
......@@ -280,9 +296,9 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -311,6 +327,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
......@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -24,14 +24,13 @@ from transformers import GPTJConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -40,6 +39,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPTJAttention(nn.Module):
......@@ -178,6 +181,7 @@ class GPTJModel(nn.Module):
config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
......@@ -186,11 +190,15 @@ class GPTJModel(nn.Module):
config.vocab_size,
self.embed_dim,
)
self.h = nn.ModuleList([
GPTJBlock(config, cache_config, quant_config)
for _ in range(config.n_layer)
])
self.start_layer, self.end_layer, self.h = make_layers(
config.n_layer,
lambda prefix: GPTJBlock(config, cache_config, quant_config),
prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward(
self,
......@@ -198,21 +206,27 @@ class GPTJModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.h)):
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.wte(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPTJForCausalLM(nn.Module):
class GPTJForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
......@@ -233,6 +247,8 @@ class GPTJForCausalLM(nn.Module):
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
......@@ -241,9 +257,9 @@ class GPTJForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -283,6 +299,8 @@ class GPTJForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -291,6 +309,8 @@ class GPTJForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
......@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -24,14 +24,13 @@ from transformers import GPTNeoXConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -40,6 +39,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPTNeoXAttention(nn.Module):
......@@ -191,6 +194,7 @@ class GPTNeoXModel(nn.Module):
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
......@@ -199,12 +203,16 @@ class GPTNeoXModel(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
GPTNeoXLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: GPTNeoXLayer(config, cache_config, quant_config),
prefix=f"{prefix}.layers",
)
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward(
self,
......@@ -212,21 +220,27 @@ class GPTNeoXModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)
for i in range(len(self.layers)):
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_in(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
class GPTNeoXForCausalLM(nn.Module):
class GPTNeoXForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
......@@ -247,6 +261,8 @@ class GPTNeoXForCausalLM(nn.Module):
self.embed_out.weight = self.gpt_neox.embed_in.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.gpt_neox.make_empty_intermediate_tensors)
def forward(
self,
......@@ -255,9 +271,9 @@ class GPTNeoXForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -288,6 +304,8 @@ class GPTNeoXForCausalLM(nn.Module):
# Models trained using OpenRLHF may include
# these tensors in the checkpoint. Skip them.
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if "query_key_value" in name:
......
......@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip
from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
......@@ -311,13 +311,13 @@ class GraniteModel(nn.Module):
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
hidden_states *= self.config.embedding_multiplier
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
hidden_states *= self.config.embedding_multiplier
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(
......@@ -337,7 +337,7 @@ class GraniteModel(nn.Module):
return hidden_states
class GraniteForCausalLM(nn.Module, SupportsLoRA):
class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......@@ -404,9 +404,12 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA):
self.lm_head.weight = self.model.embed_tokens.weight
logit_scale = getattr(config, "logit_scale", 1.0)
if hasattr(config, "logits_scaling"):
logit_scale /= config.logits_scaling
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
scale=logit_scale)
self.sampler = Sampler()
else:
self.lm_head = PPMissingLayer()
......@@ -428,8 +431,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
if logits is not None:
logits /= self.config.logits_scaling
return logits
def sample(
......
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GraniteMoe model."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers.models.granitemoe import GraniteMoeConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from . import mixtral
from .interfaces import SupportsLoRA, SupportsPP
from .utils import make_layers
class GraniteMoeMoE(nn.Module):
    """Tensor-parallel mixture-of-experts block for GraniteMoe.

    The weights of every expert are sharded across all ranks. The forward
    pass routes tokens with a replicated gate, runs a fused MoE kernel and
    reduces the per-rank outputs (``reduce_results=True``).
    """

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 tp_size: Optional[int] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = hidden_size

        # The router is never quantized: the gate always runs at
        # half / full precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            tp_size=tp_size,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and restore the input shape."""
        # NOTE: the input can be 1D or 2D, so remember its shape and work
        # on a (num_tokens, hidden_size) view.
        input_shape = hidden_states.shape
        flat_states = hidden_states.view(-1, self.hidden_size)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(flat_states)
        expert_output = self.experts(flat_states, router_logits)
        return expert_output.view(input_shape)
class GraniteMoeAttention(nn.Module):
    """Self-attention for GraniteMoe.

    Uses a fused QKV projection, neox-style rotary embeddings and a
    row-parallel output projection. Query heads are divided across the
    tensor-parallel ranks; KV heads are divided or replicated depending
    on how their count compares to the tensor-parallel size.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        attention_multiplier: Optional[float] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads must divide evenly over the TP ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            # Fewer KV heads than TP ranks: the KV heads are replicated
            # across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as TP ranks: the KV heads are
            # partitioned across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # An explicit multiplier wins; otherwise fall back to 1 / head_dim.
        if attention_multiplier is None:
            self.scaling = self.head_dim**-1
        else:
            self.scaling = attention_multiplier
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings, attend and project out."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.o_proj(context)
        return projected
class GraniteMoeDecoderLayer(nn.Module):
    """One GraniteMoe transformer layer: attention followed by sparse MoE.

    Both sub-blocks are pre-normalized with RMSNorm and their outputs are
    scaled by ``config.residual_multiplier`` before being added back to the
    residual stream.
    """

    def __init__(
        self,
        config: GraniteMoeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)

        self.self_attn = GraniteMoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attention_multiplier=config.attention_multiplier,
        )
        self.block_sparse_moe = GraniteMoeMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe",
        )

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

        self.residual_multiplier = config.residual_multiplier

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the attention and MoE sub-blocks with scaled residual adds."""
        # Self-attention sub-block
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.self_attn(
            positions=positions,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + attn_output * self.residual_multiplier

        # Sparse MoE sub-block
        moe_input = self.post_attention_layernorm(hidden_states)
        moe_output = self.block_sparse_moe(moe_input)
        return hidden_states + moe_output * self.residual_multiplier
class GraniteMoeModel(nn.Module):
    """GraniteMoe decoder stack: embeddings, decoder layers and final norm.

    Layers are created through ``make_layers``, which also yields the
    ``[start_layer, end_layer)`` range this instance iterates over in
    ``forward``.
    """

    def __init__(
        self,
        config: GraniteMoeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id

        # Extra vocabulary slots reserved for LoRA adapters, if any.
        if lora_config:
            lora_vocab = (lora_config.lora_extra_vocab_size *
                          (lora_config.max_loras or 1))
        else:
            lora_vocab = 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.embedding_multiplier = config.embedding_multiplier

        def build_layer(prefix: str) -> GraniteMoeDecoderLayer:
            return GraniteMoeDecoderLayer(config,
                                          cache_config,
                                          quant_config=quant_config,
                                          prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> torch.Tensor:
        """Run this rank's slice of the decoder stack.

        The first pipeline rank embeds ``input_ids`` and scales the result
        by ``embedding_multiplier``; other ranks read ``hidden_states`` and
        ``residual`` from ``intermediate_tensors``. Every rank except the
        last returns ``IntermediateTensors``; the last rank applies the
        final norm and returns the hidden states.
        """
        pp_group = get_pp_group()

        if pp_group.is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            hidden_states *= self.embedding_multiplier
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # kv_caches is indexed relative to the first layer handled here.
        layer_range = range(self.start_layer, self.end_layer)
        for cache_idx, layer_idx in enumerate(layer_range):
            decoder_layer = self.layers[layer_idx]
            hidden_states = decoder_layer(positions, hidden_states,
                                          kv_caches[cache_idx], attn_metadata)

        if not pp_group.is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        return self.norm(hidden_states)
class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GraniteMoe decoder with an LM head, logits processor and sampler."""

    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    # LoRA specific attributes
    supported_lora_modules = ["qkv_proj", "o_proj", "embed_tokens", "lm_head"]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: GraniteMoeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.model = GraniteMoeModel(config,
                                     cache_config,
                                     quant_config,
                                     lora_config=lora_config,
                                     prefix="model")

        extra_vocab = lora_config.lora_extra_vocab_size if lora_config else 0
        self.unpadded_vocab_size = config.vocab_size + extra_vocab

        # We need bigger padding if using lora for kernel compatibility
        if lora_config:
            vocab_padding = lora_config.lora_vocab_padding_size
        else:
            vocab_padding = DEFAULT_VOCAB_PADDING_SIZE
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=vocab_padding,
            quant_config=quant_config,
        )
        if config.tie_word_embeddings:
            # Share the input embedding matrix with the output head.
            self.lm_head.weight = self.model.embed_tokens.weight

        # Logits are divided by config.logits_scaling via the scale factor.
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            config.vocab_size,
            scale=1 / self.config.logits_scaling,
        )
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Delegate to the inner ``GraniteMoeModel`` and return its output."""
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors)

    def compute_logits(
            self, hidden_states: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
        """Run the logits processor over ``hidden_states`` with the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Build zero-filled ``hidden_states`` / ``residual`` placeholders.

        Each tensor has shape ``(batch_size, config.hidden_size)``.
        """
        shape = (batch_size, self.config.hidden_size)
        return IntermediateTensors({
            key: torch.zeros(shape, dtype=dtype, device=device)
            for key in ("hidden_states", "residual")
        })

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Rename checkpoint tensors and load them via the Mixtral loader.

        * ``input_linear.weight`` is split per expert (dim 0) and each
          expert slice is chunked in two, giving ``w1`` and ``w3``.
        * ``output_linear.weight`` is split per expert into ``w2``.
        * ``router.layer.weight`` becomes ``gate.weight``.
        * ``lm_head.weight`` is dropped when word embeddings are tied.

        Everything else is passed through unchanged.
        """
        input_suffix = '.block_sparse_moe.input_linear.weight'
        output_suffix = '.block_sparse_moe.output_linear.weight'
        router_suffix = '.block_sparse_moe.router.layer.weight'

        remapped = {}
        for name, tensor in weights:
            if name.endswith(input_suffix):
                for expert_idx in range(tensor.size(0)):
                    w1_name = name.replace(
                        input_suffix,
                        ".block_sparse_moe.experts.%d.w1.weight" % expert_idx)
                    w3_name = name.replace(
                        input_suffix,
                        ".block_sparse_moe.experts.%d.w3.weight" % expert_idx)
                    w1_param, w3_param = tensor[expert_idx].chunk(2, dim=0)
                    assert w1_name not in remapped
                    assert w3_name not in remapped
                    remapped[w1_name] = w1_param
                    remapped[w3_name] = w3_param
                continue

            if name.endswith(output_suffix):
                for expert_idx in range(tensor.size(0)):
                    w2_name = name.replace(
                        output_suffix,
                        ".block_sparse_moe.experts.%d.w2.weight" % expert_idx)
                    assert w2_name not in remapped
                    remapped[w2_name] = tensor[expert_idx]
                continue

            if name.endswith(router_suffix):
                gate_name = name.replace(router_suffix,
                                         ".block_sparse_moe.gate.weight")
                assert gate_name not in remapped
                remapped[gate_name] = tensor
                continue

            if name == 'lm_head.weight' and self.config.tie_word_embeddings:
                # The head shares the embedding weight; nothing to load.
                continue

            remapped[name] = tensor

        mixtral.MixtralForCausalLM.load_weights(self, remapped.items())
......@@ -65,11 +65,10 @@ class Idefics2VisionEmbeddings(nn.Module):
self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim)
def forward(
self,
pixel_values: torch.FloatTensor,
patch_attention_mask: torch.BoolTensor,
) -> torch.Tensor:
def forward(self,
pixel_values: torch.FloatTensor,
patch_attention_mask: torch.BoolTensor,
tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
batch_size, _, max_im_h, max_im_w = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
......@@ -84,8 +83,13 @@ class Idefics2VisionEmbeddings(nn.Module):
fill_value=0)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
if tgt_sizes is not None:
nb_patches_h = tgt_sizes[batch_idx][0]
nb_patches_w = tgt_sizes[batch_idx][1]
else:
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h,
......@@ -287,10 +291,12 @@ class Idefics2VisionTransformer(nn.Module):
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
) -> torch.tensor:
tgt_sizes: Optional[torch.IntTensor] = None,
) -> torch.Tensor:
hidden_states = self.embeddings(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask)
patch_attention_mask=patch_attention_mask,
tgt_sizes=tgt_sizes)
encoder_outputs = self.encoder(hidden_states)
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state
from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
Union, overload, runtime_checkable)
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
Protocol, Type, Union, overload, runtime_checkable)
import torch
from typing_extensions import TypeIs
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.utils import supports_kw
if TYPE_CHECKING:
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
......@@ -22,7 +27,7 @@ class SupportsMultiModal(Protocol):
MRO of your model class.
"""
def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
def __init__(self, *, multimodal_config: "MultiModalConfig") -> None:
...
......@@ -32,7 +37,7 @@ class SupportsMultiModal(Protocol):
class _SupportsMultiModalType(Protocol):
supports_multimodal: Literal[True]
def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
def __call__(self, *, multimodal_config: "MultiModalConfig") -> None:
...
......@@ -75,7 +80,7 @@ class SupportsLoRA(Protocol):
embedding_padding_modules: ClassVar[List[str]]
# lora_config is None when LoRA is not enabled
def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
def __init__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
...
......@@ -90,7 +95,7 @@ class _SupportsLoRAType(Protocol):
embedding_modules: Dict[str, str]
embedding_padding_modules: List[str]
def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
def __call__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
...
......@@ -136,15 +141,128 @@ def supports_lora(
return result
def _supports_lora(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
def _supports_lora(model: Union[Type[object], object]) -> bool:
if isinstance(model, type):
return isinstance(model, _SupportsLoRAType)
return isinstance(model, SupportsLoRA)
@runtime_checkable
class SupportsPP(Protocol):
    """The interface required for all models that support pipeline parallel."""

    supports_pp: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pipeline parallel.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        """Called when PP rank > 0 for profiling purposes."""
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
        """
        Accept :class:`IntermediateTensors` when PP rank > 0.

        Return :class:`IntermediateTensors` on every PP rank except the
        last one; the last PP rank returns the final hidden states.
        """
        ...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsPPType(Protocol):
    """Protocol used to test model *classes* for pipeline-parallel support.

    It declares ``supports_pp`` as a plain attribute and the same two
    methods as the instance-level interface, so that a class object can be
    passed to ``isinstance`` and checked structurally.
    """

    # Declared without ClassVar so the isinstance check works on a class.
    supports_pp: Literal[True]

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
        ...
@overload
def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]:
    ...


@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
    ...


def supports_pp(
    model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
    """Return whether ``model`` (class or instance) supports pipeline parallel.

    Two checks must both pass: the structural protocol check and the check
    that ``forward`` accepts ``intermediate_tensors``. Inconsistent
    combinations are reported through warnings so that model authors can
    spot a half-implemented interface.
    """
    has_interface = _supports_pp_attributes(model)
    accepts_kw = _supports_pp_inspect(model)

    if has_interface and not accepts_kw:
        logger.warning(
            "The model (%s) sets `supports_pp=True`, but does not accept "
            "`intermediate_tensors` in its `forward` method", model)

    if not has_interface:
        required_attrs = ("make_empty_intermediate_tensors", )
        missing_attrs = tuple(name for name in required_attrs
                              if not hasattr(model, name))
        flag_is_set = getattr(model, "supports_pp", False)

        if flag_is_set and missing_attrs:
            # Flag present, but the PP helpers are not all there.
            logger.warning(
                "The model (%s) sets `supports_pp=True`, "
                "but is missing PP-specific attributes: %s",
                model,
                missing_attrs,
            )
        elif not flag_is_set and not missing_attrs:
            # Helpers present, but the opt-in flag was forgotten.
            logger.warning(
                "The model (%s) contains all PP-specific attributes, "
                "but does not set `supports_pp=True`.", model)

    return has_interface and accepts_kw
def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
    """Structurally check ``model`` against the pipeline-parallel protocol.

    Class objects are matched against ``_SupportsPPType``; instances are
    matched against ``SupportsPP``.
    """
    protocol = _SupportsPPType if isinstance(model, type) else SupportsPP
    return isinstance(model, protocol)
def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
    """Check that ``model.forward`` exists and takes ``intermediate_tensors``.

    Returns ``False`` when ``forward`` is missing or not callable.
    """
    forward_fn = getattr(model, "forward", None)
    return callable(forward_fn) and supports_kw(forward_fn,
                                                "intermediate_tensors")
@runtime_checkable
class HasInnerState(Protocol):
"""The interface required for all models that has inner state."""
......@@ -153,12 +271,12 @@ class HasInnerState(Protocol):
"""
A flag that indicates this model has inner state.
Models that has inner state usually need access to the scheduler_config
for max_num_seqs ,etc... (Currently only used by Jamba)
for max_num_seqs, etc. True for e.g. both Mamba and Jamba.
"""
def __init__(self,
*,
scheduler_config: Optional[SchedulerConfig] = None) -> None:
scheduler_config: Optional["SchedulerConfig"] = None) -> None:
...
......@@ -168,7 +286,7 @@ class _HasInnerStateType(Protocol):
def __init__(self,
*,
scheduler_config: Optional[SchedulerConfig] = None) -> None:
scheduler_config: Optional["SchedulerConfig"] = None) -> None:
...
......@@ -189,3 +307,46 @@ def has_inner_state(
return isinstance(model, _HasInnerStateType)
return isinstance(model, HasInnerState)
@runtime_checkable
class IsAttentionFree(Protocol):
    """The interface required for all models like Mamba that lack attention,
    but do have state whose size is constant wrt the number of tokens."""

    # Defaults to True, so a class that has this protocol in its MRO
    # inherits the flag.
    is_attention_free: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has no attention.
    Used for block manager and attention backend selection.
    True for Mamba but not Jamba.
    """

    # No constructor arguments are declared by this interface.
    def __init__(self) -> None:
        ...
@runtime_checkable
class _IsAttentionFreeType(Protocol):
    """Protocol used to test model *classes* for the attention-free flag.

    A class object is passed to ``isinstance`` against this protocol, which
    only requires the ``is_attention_free`` attribute to be present.
    """

    is_attention_free: ClassVar[Literal[True]]

    def __init__(self) -> None:
        ...
@overload
def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]:
    ...


@overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
    ...


def is_attention_free(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
    """Return whether ``model`` (class or instance) is attention-free.

    Class objects are checked against ``_IsAttentionFreeType`` and
    instances against ``IsAttentionFree``.

    The ``Type[object]`` overload is listed first on purpose: overloads
    are matched in declaration order and ``object`` accepts every
    argument, so putting it first would make the class overload
    unreachable for type checkers.
    """
    if isinstance(model, type):
        return isinstance(model, _IsAttentionFreeType)

    return isinstance(model, IsAttentionFree)
from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union,
overload, runtime_checkable)
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger
from vllm.utils import supports_kw
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.pooler import PoolerOutput
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
logger = init_logger(__name__)
# The type of HF config
C_co = TypeVar("C_co", bound=PretrainedConfig, covariant=True)
# The type of hidden states
# Currently, T = torch.Tensor for all models except for Medusa
# which has T = List[torch.Tensor]
T = TypeVar("T", default=torch.Tensor)
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags
# for the base interfaces to avoid breaking OOT registration for existing models
# that don't inherit from the base interface classes
@runtime_checkable
class VllmModel(Protocol[C_co, T_co]):
    """Base structural interface for a vLLM model.

    ``C_co`` is the type of the HF config passed to the constructor and
    ``T_co`` is the type of the hidden states returned by ``forward``.
    """

    def __init__(
        self,
        config: C_co,
        *,
        cache_config: Optional["CacheConfig"],
        quant_config: Optional["QuantizationConfig"],
    ) -> None:
        # cache_config and quant_config are keyword-only in this interface.
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: "AttentionMetadata",
    ) -> T_co:
        """Run the model and return its hidden states."""
        ...
def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
    """Check that ``model.__init__`` accepts the vLLM constructor keywords.

    The required keywords are ``cache_config`` and ``quant_config``. A
    warning is logged only when ``model`` is an ``nn.Module`` subclass
    that lacks some of them.

    Returns:
        ``True`` when no required keyword is missing.
    """
    init_fn = model.__init__

    missing = []
    for keyword in ("cache_config", "quant_config"):
        if not supports_kw(init_fn, keyword):
            missing.append(keyword)
    missing_kws = tuple(missing)

    if missing_kws:
        if isinstance(model, type) and issubclass(model, nn.Module):
            logger.warning(
                "The model (%s) is missing "
                "vLLM-specific keywords from its initializer: %s",
                model,
                missing_kws,
            )

    return not missing_kws
def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
    """Check that ``model.forward`` accepts the vLLM forward keywords.

    The required keywords are ``input_ids``, ``positions``, ``kv_caches``
    and ``attn_metadata``. A warning is logged only when ``model`` is an
    ``nn.Module`` subclass that lacks some of them.

    Returns:
        ``False`` when ``forward`` is missing or not callable, otherwise
        ``True`` exactly when no required keyword is missing.
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata")
    missing_kws = tuple(kw for kw in vllm_kws
                        if not supports_kw(model_forward, kw))

    if missing_kws and (isinstance(model, type)
                        and issubclass(model, nn.Module)):
        # This message must name `forward`; the previous wording was
        # copied from the initializer check and pointed at the wrong
        # method.
        logger.warning(
            "The model (%s) is missing "
            "vLLM-specific keywords from its `forward` method: %s",
            model,
            missing_kws,
        )

    return len(missing_kws) == 0
@overload
def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]:
    ...


@overload
def is_vllm_model(model: object) -> TypeIs[VllmModel]:
    ...


def is_vllm_model(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]:
    """Return whether ``model`` has a vLLM-compatible init and forward.

    The forward check is only performed when the initializer check passes.
    """
    if not _check_vllm_model_init(model):
        return False
    return _check_vllm_model_forward(model)
@runtime_checkable
class VllmModelForTextGeneration(VllmModel[C_co, T], Protocol[C_co, T]):
    """Structural interface for vLLM models that generate text.

    Extends the base model interface with logits computation and sampling.
    """

    def compute_logits(
        self,
        hidden_states: T,
        sampling_metadata: "SamplingMetadata",
    ) -> Optional[T]:
        """Return `None` if TP rank > 0."""
        ...

    def sample(
        self,
        logits: T,
        sampling_metadata: "SamplingMetadata",
    ) -> "SamplerOutput":
        """Only called on TP rank 0."""
        ...
@overload
def is_text_generation_model(
        model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]:
    ...


@overload
def is_text_generation_model(
        model: object) -> TypeIs[VllmModelForTextGeneration]:
    ...


def is_text_generation_model(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForTextGeneration]],
           TypeIs[VllmModelForTextGeneration]]:
    """Return whether ``model`` is a vLLM text-generation model.

    ``model`` must pass the base vLLM model check and also satisfy the
    ``VllmModelForTextGeneration`` protocol. Classes and instances go
    through the same runtime protocol check.
    """
    return (is_vllm_model(model)
            and isinstance(model, VllmModelForTextGeneration))
@runtime_checkable
class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):
    """Structural interface for vLLM models that produce pooled outputs.

    Extends the base model interface with a ``pooler`` method.
    """

    def pooler(
        self,
        hidden_states: T,
        pooling_metadata: "PoolingMetadata",
    ) -> "PoolerOutput":
        """Only called on TP rank 0."""
        ...
@overload
def is_embedding_model(
        model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]:
    ...


@overload
def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]:
    ...


def is_embedding_model(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]:
    """Return whether ``model`` is a vLLM embedding model.

    ``model`` must pass the base vLLM model check and also satisfy the
    ``VllmModelForEmbedding`` protocol. Classes and instances go through
    the same runtime protocol check.
    """
    return is_vllm_model(model) and isinstance(model, VllmModelForEmbedding)
......@@ -4,6 +4,7 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from functools import partial
from typing import Iterable, Optional, Tuple
import torch
......@@ -11,7 +12,10 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -54,7 +58,7 @@ class InternVisionEmbeddings(nn.Module):
self.position_embedding = nn.Parameter(
torch.randn(1, self.num_positions, self.embed_dim))
def _get_pos_embed(self, pos_embed, H, W):
def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int):
target_dtype = pos_embed.dtype
pos_embed = pos_embed.float().reshape(
1, self.image_size // self.patch_size,
......@@ -63,9 +67,21 @@ class InternVisionEmbeddings(nn.Module):
size=(H, W),
mode='bicubic',
align_corners=False)
pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2,
1).to(target_dtype)
return pos_embed
return pos_embed.reshape(1, -1, H * W).permute(0, 2,
1).to(target_dtype)
def _get_position_embedding(self, H: int, W: int) -> torch.Tensor:
    """Return the position embedding for an ``H`` x ``W`` patch grid.

    When ``H * W`` equals ``self.num_patches`` the stored embedding is
    returned untouched. Otherwise the first position is kept as-is and
    the remaining positions are passed through ``_get_pos_embed`` for
    the requested grid size before being concatenated back.
    """
    pos_embed = self.position_embedding
    if self.num_patches != H * W:
        leading_part = pos_embed[:, :1, :]
        patch_part = self._get_pos_embed(pos_embed[:, 1:, :], H, W)
        pos_embed = torch.cat([leading_part, patch_part], dim=1)
    return pos_embed
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
target_dtype = self.patch_embedding.weight.dtype
......@@ -76,12 +92,7 @@ class InternVisionEmbeddings(nn.Module):
class_embeds = self.class_embedding.expand(batch_size, 1,
-1).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
position_embedding = torch.cat([
self.position_embedding[:, :1, :],
self._get_pos_embed(self.position_embedding[:, 1:, :], height,
width)
],
dim=1)
position_embedding = self._get_position_embedding(height, width)
embeddings = embeddings + position_embedding.to(target_dtype)
return embeddings
......@@ -93,8 +104,11 @@ class InternParallelAttention(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
*,
num_dummy_heads: int = 0,
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
......@@ -105,11 +119,19 @@ class InternParallelAttention(nn.Module):
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads,
self.tp_size)
self.scale = self.head_dim**-0.5
self.qkv = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
num_dummy_heads + self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
)
......@@ -117,34 +139,44 @@ class InternParallelAttention(nn.Module):
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.q_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.k_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.proj = RowParallelLinear(
self.embed_dim,
self.dummy_dim,
self.embed_dim,
quant_config=quant_config,
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
def forward(self, x):
B, N, C = x.shape
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
    """Normalize the query and key projections with ``q_norm`` / ``k_norm``.

    When ``self.tp_size > 1`` both tensors are first passed through
    ``tensor_model_parallel_all_gather``, normalized, and then split back
    along the last dimension so that only the ``self.tp_rank`` partition is
    kept. With a single partition the tensors are only normalized.

    Args:
        q: Query projection produced by this rank.
        k: Key projection produced by this rank.

    Returns:
        The ``(q, k)`` pair after normalization.
    """
    sharded = self.tp_size > 1

    if sharded:
        # NOTE(review): presumably the norm needs the full (un-sharded)
        # hidden dimension to compute its statistics -- confirm.
        q = tensor_model_parallel_all_gather(q.contiguous())
        k = tensor_model_parallel_all_gather(k.contiguous())

    q = self.q_norm.forward_native(q)
    k = self.k_norm.forward_native(k)

    if not sharded:
        return q, k

    # Keep only the slice of the last dimension that belongs to this rank.
    q = split_tensor_along_last_dim(
        q, num_partitions=self.tp_size)[self.tp_rank]
    k = split_tensor_along_last_dim(
        k, num_partitions=self.tp_size)[self.tp_rank]
    return q, k
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, _ = x.shape
qkv, _ = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
if self.qk_normalization:
q, k = self._apply_qk_norm(q, k)
q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
if self.qk_normalization:
B_, N_, H_, D_ = q.shape
q = self.q_norm.forward_native(q.flatten(-2,
-1)).view(B_, N_, H_, D_)
k = self.k_norm.forward_native(k.flatten(-2,
-1)).view(B_, N_, H_, D_)
x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
x = x.view(B, N, -1)
......@@ -155,8 +187,14 @@ class InternParallelAttention(nn.Module):
class InternSdpaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: PretrainedConfig):
def __init__(
self,
config: PretrainedConfig,
*,
num_dummy_heads: int = 0,
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
......@@ -167,20 +205,27 @@ class InternSdpaAttention(nn.Module):
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(self.embed_dim,
3 * self.embed_dim,
3 * self.dummy_dim,
bias=config.qkv_bias)
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.q_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.k_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
......@@ -233,22 +278,23 @@ class InternMLP(nn.Module):
class InternVisionEncoderLayer(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_dummy_heads: int = 0,
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type
# fallback to sdpa attention if tp unavailable
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.attn = InternParallelAttention(config,
quant_config=quant_config)
else:
self.attn = InternSdpaAttention(config)
self.attn = self._init_attn(config,
quant_config,
num_dummy_heads=num_dummy_heads)
self.mlp = InternMLP(config, quant_config=quant_config)
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps)
......@@ -260,6 +306,24 @@ class InternVisionEncoderLayer(nn.Module):
self.ls2 = nn.Parameter(config.initializer_factor *
torch.ones(self.embed_dim))
def _init_attn(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    num_dummy_heads: int,
):
    """Build the attention module for this encoder layer.

    ``InternParallelAttention`` is returned when ``USE_XFORMERS_OPS`` is
    truthy and the real plus dummy head count is divisible by the
    tensor-parallel world size; otherwise ``InternSdpaAttention`` is
    returned as the fallback.

    Args:
        config: Vision config; ``num_attention_heads`` is read from it.
        quant_config: Forwarded to the parallel attention implementation.
        num_dummy_heads: Extra heads added on top of the configured ones.
    """
    world_size = get_tensor_model_parallel_world_size()
    total_heads = config.num_attention_heads + num_dummy_heads

    shardable = USE_XFORMERS_OPS and total_heads % world_size == 0
    if not shardable:
        # fallback to sdpa attention if tp unavailable
        return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)

    return InternParallelAttention(
        config,
        quant_config=quant_config,
        num_dummy_heads=num_dummy_heads,
    )
def forward(
self,
hidden_states: torch.Tensor,
......@@ -275,19 +339,27 @@ class InternVisionEncoderLayer(nn.Module):
class InternVisionEncoder(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
):
super().__init__()
self.config = config
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
InternVisionEncoderLayer(config=config, quant_config=quant_config)
InternVisionEncoderLayer(config,
quant_config,
num_dummy_heads=num_dummy_heads)
for _ in range(num_hidden_layers)
])
......@@ -302,35 +374,25 @@ class InternVisionEncoder(nn.Module):
class InternVisionModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
):
super().__init__()
self.config = config
self.embeddings = InternVisionEmbeddings(config)
self.encoder = InternVisionEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override)
def resize_pos_embeddings(self, old_size, new_size, patch_size):
pos_emb = self.embeddings.position_embedding
_, num_positions, embed_dim = pos_emb.shape
cls_emb = pos_emb[:, :1, :]
pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size,
old_size // patch_size,
-1).permute(0, 3, 1, 2)
pos_emb = F.interpolate(pos_emb.float(),
size=new_size // patch_size,
mode='bicubic',
align_corners=False)
pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim,
-1).permute(0, 2, 1)
pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
self.embeddings.position_embedding = nn.Parameter(pos_emb)
self.embeddings.image_size = new_size
num_hidden_layers_override=num_hidden_layers_override,
num_dummy_heads=num_dummy_heads,
)
def get_input_embeddings(self):
return self.embeddings
......
......@@ -18,8 +18,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -28,6 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
......@@ -266,7 +266,7 @@ class InternLM2Model(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: IntermediateTensors = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
......@@ -297,7 +297,7 @@ class InternLM2Model(nn.Module):
return hidden_states
class InternLM2ForCausalLM(nn.Module):
class InternLM2ForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
......@@ -325,7 +325,7 @@ class InternLM2ForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: IntermediateTensors,
intermediate_tensors: Optional[IntermediateTensors],
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
......
......@@ -5,6 +5,7 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import re
from functools import cached_property, partial
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
......@@ -16,11 +17,10 @@ from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
......@@ -31,9 +31,9 @@ from vllm.utils import is_list_of
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches)
from .interfaces import SupportsMultiModal
from .utils import (flatten_bn, group_weights_with_prefix,
init_vllm_registered_model, merge_multimodal_embeddings)
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings)
IMG_START = '<img>'
IMG_END = '</img>'
......@@ -122,6 +122,20 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
return blocks, target_width, target_height
def calculate_num_blocks_wrapper(hf_config: PretrainedConfig,
                                 max_dynamic_patch: Optional[int] = None):
    """Pre-bind the tiling settings of ``hf_config`` to ``calculate_num_blocks``.

    Args:
        hf_config: Config providing ``min_dynamic_patch``,
            ``max_dynamic_patch``, ``use_thumbnail`` and
            ``vision_config.image_size``.
        max_dynamic_patch: Optional override for the upper bound; when
            ``None`` the value stored on ``hf_config`` is used.

    Returns:
        A ``functools.partial`` of ``calculate_num_blocks`` with ``min_num``,
        ``max_num``, ``image_size`` and ``use_thumbnail`` already filled in.
    """
    max_num = (hf_config.max_dynamic_patch
               if max_dynamic_patch is None else max_dynamic_patch)
    return partial(
        calculate_num_blocks,
        min_num=hf_config.min_dynamic_patch,
        max_num=max_num,
        image_size=hf_config.vision_config.image_size,
        use_thumbnail=hf_config.use_thumbnail,
    )
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
image_size: int,
......@@ -168,172 +182,231 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
return pixel_values
def get_internvl_num_patches(image_size: int, patch_size: int,
downsample_ratio: float):
def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
                                  max_dynamic_patch: Optional[int] = None):
    """Pre-bind the tiling settings of ``hf_config`` to ``image_to_pixel_values``.

    Args:
        hf_config: Config providing ``vision_config.image_size``,
            ``min_dynamic_patch``, ``max_dynamic_patch`` and
            ``use_thumbnail``.
        max_dynamic_patch: Optional override for the upper bound; when
            ``None`` the value stored on ``hf_config`` is used.

    Returns:
        A ``functools.partial`` of ``image_to_pixel_values`` with
        ``input_size``, ``min_num``, ``max_num`` and ``use_thumbnail``
        already filled in.
    """
    return partial(
        image_to_pixel_values,
        input_size=hf_config.vision_config.image_size,
        min_num=hf_config.min_dynamic_patch,
        max_num=(hf_config.max_dynamic_patch
                 if max_dynamic_patch is None else max_dynamic_patch),
        use_thumbnail=hf_config.use_thumbnail,
    )
def get_internvl_num_patches(hf_config: PretrainedConfig):
    """Return the patch count from ``get_clip_num_patches`` after downsampling.

    The count reported by ``get_clip_num_patches`` for the vision config's
    ``image_size`` / ``patch_size`` is scaled by ``downsample_ratio ** 2``
    and truncated to an ``int``.
    """
    vit_config = hf_config.vision_config
    ratio = hf_config.downsample_ratio
    raw_patches = get_clip_num_patches(image_size=vit_config.image_size,
                                       patch_size=vit_config.patch_size)
    return int(raw_patches * (ratio**2))
def get_max_internvl_image_tokens(ctx: InputContext):
def get_max_internvl_image_tokens(ctx: InputContext,
*,
max_dynamic_patch: Optional[int] = None):
hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
use_thumbnail = hf_config.use_thumbnail
max_dynamic_patch = hf_config.max_dynamic_patch
if use_thumbnail:
if use_thumbnail and max_dynamic_patch > 1:
max_dynamic_patch += 1
downsample_ratio = hf_config.downsample_ratio
image_size = vision_config.image_size
patch_size = vision_config.patch_size
num_patches = get_internvl_num_patches(image_size, patch_size,
downsample_ratio)
num_patches = get_internvl_num_patches(hf_config)
return num_patches * max_dynamic_patch
def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
model_config = ctx.model_config
def get_max_internvl_image_size(ctx: InputContext,
*,
max_dynamic_patch: Optional[int] = None):
hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config
image_size = vision_config.image_size
patch_size = vision_config.patch_size
downsample_ratio = hf_config.downsample_ratio
num_patches = get_internvl_num_patches(image_size, patch_size,
downsample_ratio)
image_size = hf_config.vision_config.image_size
image_data = multi_modal_data["image"]
min_num = hf_config.min_dynamic_patch
max_num = hf_config.max_dynamic_patch
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
use_thumbnail = hf_config.use_thumbnail
if isinstance(image_data, Image.Image):
width, height = image_data.size
num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
max_num, image_size,
use_thumbnail)
image_feature_size = [num_blocks * num_patches]
elif is_list_of(image_data, Image.Image):
image_feature_size = []
for image in image_data:
width, height = image.size
num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
max_num, image_size,
use_thumbnail)
image_feature_size.append(num_blocks * num_patches)
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
new_prompt = prompt
image_idx = sorted(map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
for idx, feature_size in enumerate(image_feature_size, start=1):
image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END
if not image_idx:
image_prompt = f"Image-{idx}: {image_prompt}"
new_prompt = new_prompt.replace('<image>', image_prompt, 1)
new_prompt_token_ids = tokenizer.encode(new_prompt)
return LLMInputs(prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data)
def input_mapper_for_internvl(ctx: InputContext, data: object):
hf_config = ctx.get_hf_config()
if use_thumbnail and max_dynamic_patch > 1:
max_dynamic_patch += 1
width = image_size * max_dynamic_patch
height = image_size
return width, height
use_thumbnail = hf_config.use_thumbnail
min_num = hf_config.min_dynamic_patch
max_num = hf_config.max_dynamic_patch
image_size = hf_config.vision_config.image_size
if isinstance(data, Image.Image):
data = image_to_pixel_values(data,
image_size,
min_num,
max_num,
use_thumbnail=use_thumbnail)
# Add an N dimension for number of images per prompt (currently 1).
data = data.unsqueeze(0)
elif is_list_of(data, Image.Image):
# we can't stack here because the images may have different num_patches
data = [
image_to_pixel_values(img,
image_size,
min_num,
max_num,
use_thumbnail=use_thumbnail) for img in data
]
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
image_token_id = tokenizer.encode(IMG_CONTEXT,
add_special_tokens=False,
return_tensors="pt")[0]
return MultiModalInputs({
"pixel_values": data,
"image_token_id": image_token_id
})
def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
image_feature_size = get_max_internvl_image_tokens(ctx)
model_config = ctx.model_config
hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
seq_data = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_images,
image_token_id=tokenizer.encode(IMG_CONTEXT,
add_special_tokens=False)[0],
image_feature_size_override=image_feature_size,
)
class InternVLInputPipeline:
image_size = vision_config.image_size
min_num = hf_config.min_dynamic_patch
max_num = hf_config.max_dynamic_patch
max_image_width = max_num * image_size
max_image_height = min_num * image_size
def __init__(
    self,
    img_start_token: str,
    img_end_token: str,
    img_context_token: str,
) -> None:
    """Store the three marker strings used to build image prompts.

    Args:
        img_start_token: Text emitted before the image placeholders.
        img_end_token: Text emitted after the image placeholders.
        img_context_token: Placeholder text repeated once per feature.
    """
    super().__init__()

    # Delimiters wrapped around the repeated context token.
    self.img_start_token = img_start_token
    self.img_end_token = img_end_token
    # Also tokenized by the mapper / dummy-data helpers of this pipeline.
    self.img_context_token = img_context_token
def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
    """Return the start token, ``feature_size`` context tokens, then the end token.

    ``num_patches`` is accepted but not used by this implementation.
    """
    placeholders = self.img_context_token * feature_size
    return self.img_start_token + placeholders + self.img_end_token
mm_data = dummy_image_for_clip(
vision_config,
num_images,
image_width_override=max_image_width,
image_height_override=max_image_height,
)
def _expand_image_prompt(
    self,
    prompt: str,
    feature_sizes: List[int],
    num_patches: int,
) -> str:
    """Replace successive ``<image>`` markers with expanded image prompts.

    For the i-th entry of ``feature_sizes`` the first remaining ``<image>``
    occurrence is replaced by ``self._create_image_prompt(size, num_patches)``.
    If the prompt contains no ``Image-<n>: <image>\\n`` label at all, each
    replacement is additionally prefixed with ``Image-<i>: `` (1-based).

    Args:
        prompt: Prompt text containing ``<image>`` markers.
        feature_sizes: Number of context tokens for each image, in order.
        num_patches: Forwarded unchanged to ``_create_image_prompt``.

    Returns:
        The prompt with up to ``len(feature_sizes)`` markers expanded.
    """
    existing_labels = [
        int(number)
        for number in re.findall(r"Image-(\d+): <image>\n", prompt)
    ]
    # Only add our own labels when the caller did not label any image.
    add_labels = not existing_labels

    expanded = prompt
    for position, size in enumerate(feature_sizes, start=1):
        replacement = self._create_image_prompt(size, num_patches)
        if add_labels:
            replacement = f"Image-{position}: {replacement}"
        expanded = expanded.replace('<image>', replacement, 1)
    return expanded
def input_processor(
    self,
    ctx: InputContext,
    inputs: DecoderOnlyInputs,
    *,
    max_dynamic_patch: Optional[int] = None,
) -> DecoderOnlyInputs:
    """Expand ``<image>`` markers in the prompt and re-tokenize it.

    Inputs without image data are returned untouched. Otherwise the number
    of context tokens per image is derived from the image data, the prompt
    (decoded from ``prompt_token_ids`` when no text prompt is given) is
    expanded through ``_expand_image_prompt``, and the expanded text is
    encoded with the model tokenizer.

    Args:
        ctx: Input context giving access to the model and HF configs.
        inputs: Inputs holding ``prompt_token_ids`` and optionally
            ``prompt`` / ``multi_modal_data``.
        max_dynamic_patch: Optional override forwarded to
            ``calculate_num_blocks_wrapper``.

    Returns:
        ``token_inputs`` carrying the un-expanded prompt text, the token ids
        of the expanded prompt and the original multi-modal data.

    Raises:
        TypeError: If the image data is neither a PIL image, a list of PIL
            images, nor a ``torch.Tensor``.
    """
    mm_data = inputs.get("multi_modal_data")
    if mm_data is None or "image" not in mm_data:
        return inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config()

    images = mm_data["image"]
    patches_per_block = get_internvl_num_patches(hf_config)
    count_blocks = calculate_num_blocks_wrapper(hf_config, max_dynamic_patch)

    def _feature_size(img) -> int:
        # Context tokens for one image: block count times patches per block.
        width, height = img.size
        blocks, _, _ = count_blocks(width, height)
        return blocks * patches_per_block

    if isinstance(images, Image.Image):
        feature_sizes = [_feature_size(images)]
    elif is_list_of(images, Image.Image):
        feature_sizes = [_feature_size(img) for img in images]
    elif isinstance(images, torch.Tensor):
        # The unpacking also enforces that the tensor is 3-D.
        # NOTE(review): only a single feature size is produced here even
        # when the first dimension is larger than 1 -- confirm this is the
        # intended handling for tensor inputs.
        _, per_image_size, _ = images.shape
        feature_sizes = [per_image_size]
    else:
        raise TypeError(f"Invalid image type: {type(images)}")

    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    prompt = inputs.get("prompt")
    prompt_token_ids = inputs["prompt_token_ids"]
    if prompt is None:
        prompt = tokenizer.decode(prompt_token_ids)

    expanded_prompt = self._expand_image_prompt(prompt, feature_sizes,
                                                patches_per_block)
    expanded_token_ids = tokenizer.encode(expanded_prompt)

    # The returned prompt text is the un-expanded one; only the ids change.
    return token_inputs(prompt=prompt,
                        prompt_token_ids=expanded_token_ids,
                        multi_modal_data=mm_data)
def input_mapper(
    self,
    ctx: InputContext,
    data: object,
    *,
    max_dynamic_patch: Optional[int] = None,
):
    """Convert image data into the multi-modal inputs consumed by the model.

    A single PIL image is mapped to pixel values and given a leading
    dimension; a list of PIL images is mapped element-wise and kept as a
    list. Any other ``data`` is passed through as ``image_embeds``.

    Args:
        ctx: Input context giving access to the model and HF configs.
        data: A PIL image, a list of PIL images, or other data that is
            forwarded as-is.
        max_dynamic_patch: Optional override forwarded to
            ``image_to_pixel_values_wrapper``.

    Returns:
        ``MultiModalInputs`` with either ``image_embeds`` or with
        ``pixel_values`` plus ``image_token_id``.
    """
    hf_config = ctx.get_hf_config()
    to_pixel_values = image_to_pixel_values_wrapper(hf_config,
                                                    max_dynamic_patch)

    if isinstance(data, Image.Image):
        # Add an N dimension for number of images per prompt (currently 1).
        pixel_values = to_pixel_values(data).unsqueeze(0)
    elif is_list_of(data, Image.Image):
        # we can't stack here because images may have different num_patches
        pixel_values = [to_pixel_values(img) for img in data]
    else:
        return MultiModalInputs({"image_embeds": data})

    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)
    encoded = tokenizer.encode(self.img_context_token,
                               add_special_tokens=False,
                               return_tensors="pt")
    image_token_id = encoded[0]

    return MultiModalInputs({
        "pixel_values": pixel_values,
        "image_token_id": image_token_id
    })
def dummy_data(
    self,
    ctx: InputContext,
    seq_len: int,
    mm_counts: Mapping[str, int],
    *,
    max_dynamic_patch: Optional[int] = None,
):
    """Build dummy sequence data and dummy images for ``mm_counts["image"]`` images.

    The sequence data is created by ``dummy_seq_data_for_clip`` using the id
    of the context token and the size from ``get_max_internvl_image_tokens``;
    the images are created by ``dummy_image_for_clip`` with the dimensions
    from ``get_max_internvl_image_size``.

    Args:
        ctx: Input context giving access to the model and HF configs.
        seq_len: Sequence length forwarded to ``dummy_seq_data_for_clip``.
        mm_counts: Per-modality counts; only the ``"image"`` entry is read.
        max_dynamic_patch: Optional override forwarded to the size helpers.

    Returns:
        A ``(seq_data, mm_data)`` tuple.
    """
    num_images = mm_counts["image"]
    hf_config = ctx.get_hf_config()

    feature_size = get_max_internvl_image_tokens(
        ctx, max_dynamic_patch=max_dynamic_patch)

    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    vision_config = hf_config.vision_config
    context_token_ids = tokenizer.encode(self.img_context_token,
                                         add_special_tokens=False)
    seq_data = dummy_seq_data_for_clip(
        vision_config,
        seq_len,
        num_images,
        image_token_id=context_token_ids[0],
        image_feature_size_override=feature_size,
    )

    width, height = get_max_internvl_image_size(
        ctx, max_dynamic_patch=max_dynamic_patch)
    mm_data = dummy_image_for_clip(
        hf_config.vision_config,
        num_images,
        image_width_override=width,
        image_height_override=height,
    )

    return seq_data, mm_data
return seq_data, mm_data
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_internvl)
input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
class InternVLChatModel(nn.Module, SupportsMultiModal):
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self,
config: PretrainedConfig,
......@@ -360,29 +433,40 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
self.vision_model = InternVisionModel(
config.vision_config, num_hidden_layers_override=num_hidden_layers)
self.vision_model = self._init_vision_model(config, num_hidden_layers)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
self.mlp1 = nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
llm_hidden_size), nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size))
self.mlp1 = self._init_mlp1(config)
self.img_context_token_id = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
self.sampler = self.language_model.sampler
else:
self.sampler = Sampler()
return self.language_model.sampler
return Sampler()
def _init_vision_model(self, config: PretrainedConfig,
                       num_hidden_layers: int):
    """Create the ``InternVisionModel`` from ``config.vision_config``.

    ``num_hidden_layers`` is passed as ``num_hidden_layers_override``.
    Construction is kept in its own method, separate from ``__init__``.
    """
    vision_config = config.vision_config
    return InternVisionModel(
        vision_config,
        num_hidden_layers_override=num_hidden_layers,
    )
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
    """Build the LayerNorm -> Linear -> GELU -> Linear projector.

    The input width is the vision hidden size multiplied by
    ``int(1 / self.downsample_ratio) ** 2``; the output width is the text
    model's hidden size.
    """
    vit_hidden_size = config.vision_config.hidden_size
    llm_hidden_size = config.text_config.hidden_size
    in_features = vit_hidden_size * int(1 / self.downsample_ratio)**2

    # Modules are created in the same order as they are applied.
    layers = [
        nn.LayerNorm(in_features),
        nn.Linear(in_features, llm_hidden_size),
        nn.GELU(),
        nn.Linear(llm_hidden_size, llm_hidden_size),
    ]
    return nn.Sequential(*layers)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
......@@ -470,7 +554,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
self,
image_input: InternVLImageInputs,
) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
......@@ -487,18 +570,22 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None and get_pp_group().is_first_rank:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id)
) -> Union[SamplerOutput, IntermediateTensors]:
if intermediate_tensors is not None:
input_ids = None
else:
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
......@@ -524,19 +611,5 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
weights_group = group_weights_with_prefix(weights)
# load vision encoder
self.vision_model.load_weights(weights_group["vision_model"])
# load mlp projector
mlp_params_dict = dict(self.mlp1.named_parameters())
for name, loaded_weight in weights_group["mlp1"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
loader = AutoWeightsLoader(self)
loader.load_weights(weights)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment