Unverified Commit 8685ba1a authored by manikandan.tm@zucisystems.com's avatar manikandan.tm@zucisystems.com Committed by GitHub
Browse files

Inclusion of InternVLChatModel In PP_SUPPORTED_MODELS(Pipeline Parallelism) (#7860)

parent 288a9388
...@@ -18,23 +18,26 @@ logger = init_logger("test_pipeline_parallel") ...@@ -18,23 +18,26 @@ logger = init_logger("test_pipeline_parallel")
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, " @pytest.mark.parametrize(
"MODEL_NAME, DIST_BACKEND"), ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, "
[ "MODEL_NAME, DIST_BACKEND"),
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), [
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
]) (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"),
],
)
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
DIST_BACKEND): TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND):
if VLLM_MULTI_NODE and DIST_BACKEND == "mp": if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for " pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend") "multiprocessing distributed backend")
...@@ -71,6 +74,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, ...@@ -71,6 +74,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
if EAGER_MODE: if EAGER_MODE:
pp_args.append("--enforce-eager") pp_args.append("--enforce-eager")
tp_args.append("--enforce-eager") tp_args.append("--enforce-eager")
if TRUST_REMOTE_CODE:
pp_args.append("--trust-remote-code")
tp_args.append("--trust-remote-code")
pp_env = None pp_env = None
if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2 if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2
and CHUNKED_PREFILL): and CHUNKED_PREFILL):
......
...@@ -178,7 +178,12 @@ def compare_two_settings(model: str, ...@@ -178,7 +178,12 @@ def compare_two_settings(model: str,
env2: The second set of environment variables to pass to the API server. env2: The second set of environment variables to pass to the API server.
""" """
tokenizer = AutoTokenizer.from_pretrained(model) trust_remote_code = "--trust-remote-code"
if trust_remote_code in arg1 or trust_remote_code in arg2:
tokenizer = AutoTokenizer.from_pretrained(model,
trust_remote_code=True)
else:
tokenizer = AutoTokenizer.from_pretrained(model)
prompt = "Hello, my name is" prompt = "Hello, my name is"
token_ids = tokenizer(prompt)["input_ids"] token_ids = tokenizer(prompt)["input_ids"]
......
...@@ -35,18 +35,20 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 ...@@ -35,18 +35,20 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096
_PP_SUPPORTED_MODELS = [ _PP_SUPPORTED_MODELS = [
"AquilaModel",
"AquilaForCausalLM", "AquilaForCausalLM",
"AquilaModel",
"DeepseekV2ForCausalLM", "DeepseekV2ForCausalLM",
"GPT2LMHeadModel",
"InternLM2ForCausalLM",
"InternLMForCausalLM", "InternLMForCausalLM",
"InternVLChatModel",
"JAISLMHeadModel", "JAISLMHeadModel",
"LlamaForCausalLM", "LlamaForCausalLM",
"LLaMAForCausalLM", "LLaMAForCausalLM",
"MistralForCausalLM", "MistralForCausalLM",
"Phi3ForCausalLM",
"GPT2LMHeadModel",
"MixtralForCausalLM", "MixtralForCausalLM",
"NemotronForCausalLM", "NemotronForCausalLM",
"Phi3ForCausalLM",
"Qwen2ForCausalLM", "Qwen2ForCausalLM",
"Qwen2MoeForCausalLM", "Qwen2MoeForCausalLM",
"QWenLMHeadModel", "QWenLMHeadModel",
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from functools import partial from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -8,7 +8,7 @@ from transformers import PretrainedConfig ...@@ -8,7 +8,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_gather) tensor_model_parallel_all_gather)
...@@ -28,6 +28,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -28,6 +28,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class InternLM2MLP(nn.Module): class InternLM2MLP(nn.Module):
...@@ -234,6 +237,7 @@ class InternLM2Model(nn.Module): ...@@ -234,6 +237,7 @@ class InternLM2Model(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -243,11 +247,15 @@ class InternLM2Model(nn.Module): ...@@ -243,11 +247,15 @@ class InternLM2Model(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
InternLMDecoderLayer(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: InternLMDecoderLayer(config, cache_config,
]) quant_config),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.tok_embeddings(input_ids) return self.tok_embeddings(input_ids)
...@@ -260,21 +268,31 @@ class InternLM2Model(nn.Module): ...@@ -260,21 +268,31 @@ class InternLM2Model(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: IntermediateTensors = None, intermediate_tensors: IntermediateTensors = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is not None: if get_pp_group().is_first_rank:
hidden_states = inputs_embeds if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.tok_embeddings(input_ids)
residual = None
else: else:
hidden_states = self.tok_embeddings(input_ids) assert intermediate_tensors is not None
residual = None hidden_states = intermediate_tensors["hidden_states"]
for i in range(len(self.layers)): residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
...@@ -298,6 +316,8 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -298,6 +316,8 @@ class InternLM2ForCausalLM(nn.Module):
self.output.weight = self.model.tok_embeddings.weight self.output.weight = self.model.tok_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -308,7 +328,7 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -308,7 +328,7 @@ class InternLM2ForCausalLM(nn.Module):
intermediate_tensors: IntermediateTensors, intermediate_tensors: IntermediateTensors,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -345,6 +365,8 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -345,6 +365,8 @@ class InternLM2ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -353,6 +375,8 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -353,6 +375,8 @@ class InternLM2ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -341,6 +341,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal): ...@@ -341,6 +341,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
nn.Linear(llm_hidden_size, llm_hidden_size)) nn.Linear(llm_hidden_size, llm_hidden_size))
self.img_context_token_id = None self.img_context_token_id = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def pixel_shuffle(self, x, scale_factor=0.5): def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size() n, w, h, c = x.size()
...@@ -461,7 +463,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal): ...@@ -461,7 +463,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
......
...@@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.multimodal.base import NestedTensors from vllm.multimodal.base import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
...@@ -279,3 +280,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: ...@@ -279,3 +280,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
if name.startswith(missing_layer_name): if name.startswith(missing_layer_name):
return True return True
return False return False
def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
def make_empty_intermediate_tensors(
batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
key: torch.zeros((batch_size, hidden_size),
dtype=dtype,
device=device)
for key in keys
})
return make_empty_intermediate_tensors
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment