Unverified Commit 7025b11d authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)

parent 5469146b
...@@ -4,7 +4,8 @@ import os ...@@ -4,7 +4,8 @@ import os
import sys import sys
from collections import UserList from collections import UserList
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
TypeVar, Union)
import pytest import pytest
import torch import torch
...@@ -27,7 +28,7 @@ from vllm.logger import init_logger ...@@ -27,7 +28,7 @@ from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
is_cpu) identity, is_cpu)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -197,6 +198,8 @@ class HfRunner: ...@@ -197,6 +198,8 @@ class HfRunner:
is_embedding_model: bool = False, is_embedding_model: bool = False,
is_vision_model: bool = False, is_vision_model: bool = False,
is_encoder_decoder_model: bool = False, is_encoder_decoder_model: bool = False,
postprocess_inputs: Callable[[BatchEncoding],
BatchEncoding] = identity,
) -> None: ) -> None:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
...@@ -242,12 +245,14 @@ class HfRunner: ...@@ -242,12 +245,14 @@ class HfRunner:
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
trust_remote_code=True, trust_remote_code=True,
) )
except Exception: except Exception as exc:
logger.warning( logger.warning(
"Unable to auto-load processor from HuggingFace for " "Unable to auto-load HuggingFace processor for model (%s). "
"model %s. Using tokenizer instead.", model_name) "Using tokenizer instead. Reason: %s", model_name, exc)
self.processor = self.tokenizer self.processor = self.tokenizer
self.postprocess_inputs = postprocess_inputs
def generate( def generate(
self, self,
prompts: List[str], prompts: List[str],
...@@ -267,6 +272,7 @@ class HfRunner: ...@@ -267,6 +272,7 @@ class HfRunner:
processor_kwargs["images"] = images[i] processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs) inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
output_ids = self.model.generate( output_ids = self.model.generate(
**self.wrap_device(inputs), **self.wrap_device(inputs),
...@@ -336,6 +342,7 @@ class HfRunner: ...@@ -336,6 +342,7 @@ class HfRunner:
processor_kwargs["images"] = images[i] processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs) inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
output = self.model.generate( output = self.model.generate(
**self.wrap_device(inputs), **self.wrap_device(inputs),
...@@ -420,6 +427,7 @@ class HfRunner: ...@@ -420,6 +427,7 @@ class HfRunner:
processor_kwargs["images"] = images[i] processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs) inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
output = self.model.generate( output = self.model.generate(
**self.wrap_device(inputs), **self.wrap_device(inputs),
...@@ -552,7 +560,8 @@ class VllmRunner: ...@@ -552,7 +560,8 @@ class VllmRunner:
self, self,
prompts: List[str], prompts: List[str],
sampling_params: SamplingParams, sampling_params: SamplingParams,
images: Optional[List[Image.Image]] = None, images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
) -> List[Tuple[List[List[int]], List[str]]]: ) -> List[Tuple[List[List[int]], List[str]]]:
if images is not None: if images is not None:
assert len(prompts) == len(images) assert len(prompts) == len(images)
...@@ -587,7 +596,7 @@ class VllmRunner: ...@@ -587,7 +596,7 @@ class VllmRunner:
for req_output in req_outputs: for req_output in req_outputs:
for sample in req_output.outputs: for sample in req_output.outputs:
output_str = sample.text output_str = sample.text
output_ids = sample.token_ids output_ids = list(sample.token_ids)
output_logprobs = sample.logprobs output_logprobs = sample.logprobs
outputs.append((output_ids, output_str, output_logprobs)) outputs.append((output_ids, output_str, output_logprobs))
return outputs return outputs
...@@ -596,7 +605,8 @@ class VllmRunner: ...@@ -596,7 +605,8 @@ class VllmRunner:
self, self,
prompts: List[str], prompts: List[str],
sampling_params: SamplingParams, sampling_params: SamplingParams,
images: Optional[List[Image.Image]] = None, images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
assert sampling_params.logprobs is not None assert sampling_params.logprobs is not None
......
...@@ -18,8 +18,10 @@ from ..utils import fork_new_process_for_each_test ...@@ -18,8 +18,10 @@ from ..utils import fork_new_process_for_each_test
@pytest.mark.parametrize("model, distributed_executor_backend", [ @pytest.mark.parametrize("model, distributed_executor_backend", [
("llava-hf/llava-1.5-7b-hf", "ray"), ("llava-hf/llava-1.5-7b-hf", "ray"),
("llava-hf/llava-v1.6-mistral-7b-hf", "ray"), ("llava-hf/llava-v1.6-mistral-7b-hf", "ray"),
("facebook/chameleon-7b", "ray"),
("llava-hf/llava-1.5-7b-hf", "mp"), ("llava-hf/llava-1.5-7b-hf", "mp"),
("llava-hf/llava-v1.6-mistral-7b-hf", "mp"), ("llava-hf/llava-v1.6-mistral-7b-hf", "mp"),
("facebook/chameleon-7b", "mp"),
]) ])
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_models(hf_runner, vllm_runner, image_assets, model: str, def test_models(hf_runner, vllm_runner, image_assets, model: str,
...@@ -34,6 +36,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str, ...@@ -34,6 +36,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str,
from ..models.test_llava import models, run_test from ..models.test_llava import models, run_test
elif model.startswith("llava-hf/llava-v1.6"): elif model.startswith("llava-hf/llava-v1.6"):
from ..models.test_llava_next import models, run_test from ..models.test_llava_next import models, run_test
elif model.startswith("facebook/chameleon"):
from ..models.test_chameleon import models, run_test
else: else:
raise NotImplementedError(f"Unsupported model: {model}") raise NotImplementedError(f"Unsupported model: {model}")
......
import sys import sys
import time import time
from typing import Optional
import torch import torch
from openai import OpenAI, OpenAIError from openai import OpenAI, OpenAIError
...@@ -17,8 +18,11 @@ assert chatml_jinja_path.exists() ...@@ -17,8 +18,11 @@ assert chatml_jinja_path.exists()
class MyOPTForCausalLM(OPTForCausalLM): class MyOPTForCausalLM(OPTForCausalLM):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token # this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata) logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_() logits.zero_()
......
import re
from typing import List, Optional, Type from typing import List, Optional, Type
import pytest import pytest
from transformers import BatchEncoding
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.utils import rescale_image_size
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ..conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_outputs_equal
pytestmark = pytest.mark.vlm pytestmark = pytest.mark.vlm
...@@ -19,9 +21,8 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ ...@@ -19,9 +21,8 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
models = ["facebook/chameleon-7b"] models = ["facebook/chameleon-7b"]
#TODO (ywang96): Add correctness test when chameleon is
# available on transformers.
def run_test( def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner], vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets, image_assets: _ImageAssets,
model: str, model: str,
...@@ -29,13 +30,20 @@ def run_test( ...@@ -29,13 +30,20 @@ def run_test(
size_factors: List[float], size_factors: List[float],
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int, tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None, distributed_executor_backend: Optional[str] = None,
): ):
"""Test if the model can generate text given """Inference result should be the same between hf and vllm.
a batch of images and prompts.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
""" """
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
images = [asset.pil_image for asset in image_assets] images = [asset.pil_image for asset in image_assets]
inputs_per_image = [( inputs_per_image = [(
...@@ -50,35 +58,49 @@ def run_test( ...@@ -50,35 +58,49 @@ def run_test(
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model: enforce_eager=True) as vllm_model:
for prompts, images in inputs_per_image: vllm_outputs_per_image = [
vllm_outputs = vllm_model.generate_greedy(prompts, vllm_model.generate_greedy_logprobs(prompts,
max_tokens, max_tokens,
images=images) num_logprobs=num_logprobs,
for i in range(len(vllm_outputs)): images=images)
for prompts, images in inputs_per_image
# format prompt back to original ]
replacements = {
"<racm3:break>": "", def process(hf_inputs: BatchEncoding):
"<eoss>": "", hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
"<reserved08706>": "" .to(torch_dtype) # type: ignore
} return hf_inputs
pattern = '|'.join(replacements.keys())
vllm_result = re.sub( with hf_runner(model,
pattern, dtype=dtype,
lambda match: replacements[match.group(0)], #noqa B023 postprocess_inputs=process,
vllm_outputs[i][1]) is_vision_model=True) as hf_model:
vllm_result = vllm_result.replace("<image>", "", 1023) hf_outputs_per_image = [
assert vllm_result[:len(prompts[i])] == prompts[i] hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
# assert at least 10 new characters are generated num_logprobs=num_logprobs,
# (to take stop token into account) images=images)
assert len(vllm_outputs[i][1]) - len(prompts[i]) > 10 for prompts, images in inputs_per_image
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
# HF Logprobs include image tokens, unlike vLLM, so we don't directly
# compare them
check_outputs_equal(
outputs_0_lst=[outputs[:2] for outputs in hf_outputs],
outputs_1_lst=[outputs[:2] for outputs in vllm_outputs],
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"size_factors", "size_factors",
[ [
# No image
[],
# Single-scale # Single-scale
[1.0], [1.0],
# Single-scale, batched # Single-scale, batched
...@@ -88,15 +110,18 @@ def run_test( ...@@ -88,15 +110,18 @@ def run_test(
], ],
) )
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("max_tokens", [8])
def test_models(vllm_runner, image_assets, model, size_factors, dtype: str, @pytest.mark.parametrize("num_logprobs", [5])
max_tokens: int) -> None: def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype, max_tokens, num_logprobs) -> None:
run_test( run_test(
hf_runner,
vllm_runner, vllm_runner,
image_assets, image_assets,
model, model,
size_factors=size_factors, size_factors=size_factors,
dtype=dtype, dtype=dtype,
max_tokens=max_tokens, max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1, tensor_parallel_size=1,
) )
from typing import List, Optional, Tuple, Type from typing import List, Optional, Tuple, Type
import pytest import pytest
from transformers import AutoConfig, AutoTokenizer from transformers import AutoConfig, AutoTokenizer, BatchEncoding
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
...@@ -110,16 +110,21 @@ def run_test( ...@@ -110,16 +110,21 @@ def run_test(
for prompts, images in inputs_per_image for prompts, images in inputs_per_image
] ]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model: if mantis_processor is not None:
if mantis_processor is not None:
def process(*args, **kwargs): def process(hf_inputs: BatchEncoding):
output = mantis_processor(*args, **kwargs) hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
output["pixel_values"] = output["pixel_values"].to(torch_dtype) .to(torch_dtype) # type: ignore
return output return hf_inputs
else:
hf_model.processor = process def process(hf_inputs: BatchEncoding):
return hf_inputs
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
is_vision_model=True) as hf_model:
hf_outputs_per_image = [ hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts, hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens, max_tokens,
......
from collections import UserDict
from typing import List, Optional, Tuple, Type from typing import List, Optional, Tuple, Type
import pytest import pytest
import torch import torch
import torch.types import torch.types
from transformers import BatchFeature from transformers import BatchEncoding
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
...@@ -14,18 +13,6 @@ from .utils import check_logprobs_close ...@@ -14,18 +13,6 @@ from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm pytestmark = pytest.mark.vlm
class NestedInputs(UserDict):
def __init__(self, model_inputs: BatchFeature):
super().__init__({"model_inputs": model_inputs})
self.model_inputs = model_inputs
def to(self, device: torch.types.Device):
return NestedInputs(self.model_inputs.to(device))
# The image token is placed before "user" on purpose so that the test can pass # The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign": "stop_sign":
...@@ -41,6 +28,10 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ ...@@ -41,6 +28,10 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
models = ["openbmb/MiniCPM-Llama3-V-2_5"] models = ["openbmb/MiniCPM-Llama3-V-2_5"]
def _wrap_inputs(hf_inputs: BatchEncoding) -> BatchEncoding:
return BatchEncoding({"model_inputs": hf_inputs})
def trunc_hf_output(hf_output: Tuple[List[int], str, def trunc_hf_output(hf_output: Tuple[List[int], str,
Optional[SampleLogprobs]]): Optional[SampleLogprobs]]):
output_ids, output_str, out_logprobs = hf_output output_ids, output_str, out_logprobs = hf_output
...@@ -105,11 +96,8 @@ def run_test( ...@@ -105,11 +96,8 @@ def run_test(
for prompts, images in inputs_per_image for prompts, images in inputs_per_image
] ]
with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad(): hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs)
hf_processor = hf_model.processor with hf_model, torch.no_grad():
hf_model.processor = lambda **kw: NestedInputs(
hf_processor(**kw) # type: ignore
)
hf_outputs_per_image = [ hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts, hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens, max_tokens,
...@@ -224,11 +212,8 @@ def run_multi_image_test( ...@@ -224,11 +212,8 @@ def run_multi_image_test(
for prompts, images in inputs_per_case for prompts, images in inputs_per_case
] ]
with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad(): hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs)
hf_processor = hf_model.processor with hf_model, torch.no_grad():
hf_model.processor = lambda **kw: NestedInputs(
hf_processor(**kw) # type: ignore
)
hf_outputs_per_case = [ hf_outputs_per_case = [
hf_model.generate_greedy_logprobs_limit(prompts, hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens, max_tokens,
......
from typing import Optional
import torch import torch
from vllm import LLM, ModelRegistry, SamplingParams from vllm import LLM, ModelRegistry, SamplingParams
...@@ -7,8 +9,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -7,8 +9,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
class MyOPTForCausalLM(OPTForCausalLM): class MyOPTForCausalLM(OPTForCausalLM):
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token # this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata) logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_() logits.zero_()
......
...@@ -19,7 +19,7 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor, ...@@ -19,7 +19,7 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor,
def tensor_model_parallel_gather(input_: torch.Tensor, def tensor_model_parallel_gather(input_: torch.Tensor,
dst: int = 0, dst: int = 0,
dim: int = -1) -> torch.Tensor: dim: int = -1) -> Optional[torch.Tensor]:
"""Gather the input tensor across model parallel group.""" """Gather the input tensor across model parallel group."""
return get_tp_group().gather(input_, dst, dim) return get_tp_group().gather(input_, dst, dim)
......
...@@ -329,7 +329,7 @@ class GroupCoordinator: ...@@ -329,7 +329,7 @@ class GroupCoordinator:
def gather(self, def gather(self,
input_: torch.Tensor, input_: torch.Tensor,
dst: int = 0, dst: int = 0,
dim: int = -1) -> torch.Tensor: dim: int = -1) -> Optional[torch.Tensor]:
""" """
NOTE: We assume that the input tensor is on the same device across NOTE: We assume that the input tensor is on the same device across
all the ranks. all the ranks.
......
...@@ -50,7 +50,7 @@ class LogitsProcessor(nn.Module): ...@@ -50,7 +50,7 @@ class LogitsProcessor(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None, embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Optional[torch.Tensor]:
if self.logits_as_input: if self.logits_as_input:
logits = hidden_states logits = hidden_states
else: else:
...@@ -73,14 +73,18 @@ class LogitsProcessor(nn.Module): ...@@ -73,14 +73,18 @@ class LogitsProcessor(nn.Module):
return logits return logits
def _get_logits(self, hidden_states: torch.Tensor, def _get_logits(
lm_head: VocabParallelEmbedding, self,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens. # Get the logits for the next tokens.
logits = lm_head.linear_method.apply(lm_head, logits = lm_head.linear_method.apply(lm_head,
hidden_states, hidden_states,
bias=embedding_bias) bias=embedding_bias)
if self.use_gather: if self.use_gather:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits) logits = tensor_model_parallel_gather(logits)
else: else:
# Gather is not supported for some devices such as TPUs. # Gather is not supported for some devices such as TPUs.
......
...@@ -19,6 +19,7 @@ from tqdm.auto import tqdm ...@@ -19,6 +19,7 @@ from tqdm.auto import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import LoadConfig, ModelConfig from vllm.config import LoadConfig, ModelConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig, from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config) get_quantization_config)
...@@ -514,8 +515,30 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: ...@@ -514,8 +515,30 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
def default_weight_loader(param: torch.Tensor, def default_weight_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None: loaded_weight: torch.Tensor) -> None:
"""Default weight loader.""" """Default weight loader."""
assert param.size() == loaded_weight.size() try:
param.data.copy_(loaded_weight) assert param.size() == loaded_weight.size(), (
f"Attempted to load weight ({loaded_weight.size()}) "
f"into parameter ({param.size()})")
param.data.copy_(loaded_weight)
except Exception:
# NOTE: This exception is added for the purpose of setting breakpoint to
# debug weight loading issues.
raise
def row_parallel_weight_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
"""Load weights that are row-parallelized."""
tp_rank = get_tensor_model_parallel_rank()
shard_dim = 0 if param.dim() != 1 else None
if shard_dim is not None:
shard_size = param.data.shape[shard_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
return default_weight_loader(param, loaded_weight)
def initialize_dummy_weights( def initialize_dummy_weights(
......
...@@ -433,8 +433,11 @@ class ArcticForCausalLM(nn.Module): ...@@ -433,8 +433,11 @@ class ArcticForCausalLM(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -346,8 +346,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): ...@@ -346,8 +346,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -872,8 +872,11 @@ class BartForConditionalGeneration(nn.Module): ...@@ -872,8 +872,11 @@ class BartForConditionalGeneration(nn.Module):
return self.model(input_ids, positions, encoder_input_ids, return self.model(input_ids, positions, encoder_input_ids,
encoder_positions, kv_caches, attn_metadata) encoder_positions, kv_caches, attn_metadata)
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -637,8 +637,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision): ...@@ -637,8 +637,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.get_lm_head(), hidden_states, logits = self.logits_processor(self.get_lm_head(), hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -292,8 +292,11 @@ class BloomForCausalLM(nn.Module): ...@@ -292,8 +292,11 @@ class BloomForCausalLM(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -25,8 +25,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -25,8 +25,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_tokenizer, from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens) repeat_and_pad_image_tokens)
...@@ -141,6 +143,11 @@ class ChameleonLayerNorm(nn.LayerNorm): ...@@ -141,6 +143,11 @@ class ChameleonLayerNorm(nn.LayerNorm):
super().__init__(hidden_size, *args, **kwargs) super().__init__(hidden_size, *args, **kwargs)
self.normalized_shape = (hidden_size[-1], ) self.normalized_shape = (hidden_size[-1], )
set_weight_attrs(self.weight,
{"weight_loader": row_parallel_weight_loader})
set_weight_attrs(self.bias,
{"weight_loader": row_parallel_weight_loader})
def forward(self, hidden_states): def forward(self, hidden_states):
hidden_states = F.layer_norm(hidden_states, hidden_states = F.layer_norm(hidden_states,
self.normalized_shape, self.normalized_shape,
...@@ -697,6 +704,8 @@ class ChameleonVQVAEEncoder(nn.Module): ...@@ -697,6 +704,8 @@ class ChameleonVQVAEEncoder(nn.Module):
) )
def forward(self, pixel_values: torch.Tensor): def forward(self, pixel_values: torch.Tensor):
pixel_values = pixel_values.to(self.conv_in.weight.dtype)
# downsampling # downsampling
hidden_states = [self.conv_in(pixel_values)] hidden_states = [self.conv_in(pixel_values)]
for i_level in range(self.num_resolutions): for i_level in range(self.num_resolutions):
...@@ -959,15 +968,19 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsVision): ...@@ -959,15 +968,19 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
# Disallow image tokens which does not include special # Disallow image tokens which does not include special
# begin-image and end-image tokens # begin-image and end-image tokens
image_tokens = self.model.vocabulary_mapping.image_tokens if logits is not None:
logits[:, image_tokens] = torch.finfo(logits.dtype).min image_tokens = self.model.vocabulary_mapping.image_tokens
logits[:, image_tokens] = torch.finfo(logits.dtype).min
return logits return logits
......
...@@ -372,8 +372,11 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): ...@@ -372,8 +372,11 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -25,13 +25,11 @@ from typing import Iterable, List, Optional, Set, Tuple ...@@ -25,13 +25,11 @@ from typing import Iterable, List, Optional, Set, Tuple
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn.parameter import Parameter
from transformers import CohereConfig from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -43,7 +41,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -43,7 +41,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
...@@ -67,25 +66,14 @@ class LayerNorm(nn.Module): ...@@ -67,25 +66,14 @@ class LayerNorm(nn.Module):
super().__init__() super().__init__()
self.weight = nn.Parameter(torch.ones(param_shape)) self.weight = nn.Parameter(torch.ones(param_shape))
self.variance_epsilon = eps self.variance_epsilon = eps
set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) set_weight_attrs(self.weight,
{"weight_loader": row_parallel_weight_loader})
def forward(self, hidden_states, residuals=None): def forward(self, hidden_states, residuals=None):
hidden_states = layer_norm_func(hidden_states, self.weight, hidden_states = layer_norm_func(hidden_states, self.weight,
self.variance_epsilon) self.variance_epsilon)
return hidden_states, residuals return hidden_states, residuals
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
shard_dim = 0 if param.dim() != 1 else None
param_data = param.data
if shard_dim is not None:
shard_size = param_data.shape[shard_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module): class CohereMLP(nn.Module):
...@@ -359,8 +347,11 @@ class CohereForCausalLM(nn.Module): ...@@ -359,8 +347,11 @@ class CohereForCausalLM(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
is_not_lora = hasattr(self.model.embed_tokens, 'weight') is_not_lora = hasattr(self.model.embed_tokens, 'weight')
if is_not_lora: if is_not_lora:
logits = self.logits_processor(self.model.embed_tokens, logits = self.logits_processor(self.model.embed_tokens,
......
...@@ -388,8 +388,11 @@ class DbrxForCausalLM(nn.Module): ...@@ -388,8 +388,11 @@ class DbrxForCausalLM(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment