Unverified Commit 7025b11d authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)

parent 5469146b
......@@ -4,7 +4,8 @@ import os
import sys
from collections import UserList
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
TypeVar, Union)
import pytest
import torch
......@@ -27,7 +28,7 @@ from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sequence import SampleLogprobs
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
is_cpu)
identity, is_cpu)
logger = init_logger(__name__)
......@@ -197,6 +198,8 @@ class HfRunner:
is_embedding_model: bool = False,
is_vision_model: bool = False,
is_encoder_decoder_model: bool = False,
postprocess_inputs: Callable[[BatchEncoding],
BatchEncoding] = identity,
) -> None:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
......@@ -242,12 +245,14 @@ class HfRunner:
torch_dtype=torch_dtype,
trust_remote_code=True,
)
except Exception:
except Exception as exc:
logger.warning(
"Unable to auto-load processor from HuggingFace for "
"model %s. Using tokenizer instead.", model_name)
"Unable to auto-load HuggingFace processor for model (%s). "
"Using tokenizer instead. Reason: %s", model_name, exc)
self.processor = self.tokenizer
self.postprocess_inputs = postprocess_inputs
def generate(
self,
prompts: List[str],
......@@ -267,6 +272,7 @@ class HfRunner:
processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
output_ids = self.model.generate(
**self.wrap_device(inputs),
......@@ -336,6 +342,7 @@ class HfRunner:
processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
output = self.model.generate(
**self.wrap_device(inputs),
......@@ -420,6 +427,7 @@ class HfRunner:
processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
output = self.model.generate(
**self.wrap_device(inputs),
......@@ -552,7 +560,8 @@ class VllmRunner:
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[List[Image.Image]] = None,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
if images is not None:
assert len(prompts) == len(images)
......@@ -587,7 +596,7 @@ class VllmRunner:
for req_output in req_outputs:
for sample in req_output.outputs:
output_str = sample.text
output_ids = sample.token_ids
output_ids = list(sample.token_ids)
output_logprobs = sample.logprobs
outputs.append((output_ids, output_str, output_logprobs))
return outputs
......@@ -596,7 +605,8 @@ class VllmRunner:
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[List[Image.Image]] = None,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
assert sampling_params.logprobs is not None
......
......@@ -18,8 +18,10 @@ from ..utils import fork_new_process_for_each_test
@pytest.mark.parametrize("model, distributed_executor_backend", [
("llava-hf/llava-1.5-7b-hf", "ray"),
("llava-hf/llava-v1.6-mistral-7b-hf", "ray"),
("facebook/chameleon-7b", "ray"),
("llava-hf/llava-1.5-7b-hf", "mp"),
("llava-hf/llava-v1.6-mistral-7b-hf", "mp"),
("facebook/chameleon-7b", "mp"),
])
@fork_new_process_for_each_test
def test_models(hf_runner, vllm_runner, image_assets, model: str,
......@@ -34,6 +36,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str,
from ..models.test_llava import models, run_test
elif model.startswith("llava-hf/llava-v1.6"):
from ..models.test_llava_next import models, run_test
elif model.startswith("facebook/chameleon"):
from ..models.test_chameleon import models, run_test
else:
raise NotImplementedError(f"Unsupported model: {model}")
......
import sys
import time
from typing import Optional
import torch
from openai import OpenAI, OpenAIError
......@@ -17,8 +18,11 @@ assert chatml_jinja_path.exists()
class MyOPTForCausalLM(OPTForCausalLM):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_()
......
import re
from typing import List, Optional, Type
import pytest
from transformers import BatchEncoding
from vllm.multimodal.utils import rescale_image_size
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ..conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_outputs_equal
pytestmark = pytest.mark.vlm
......@@ -19,9 +21,8 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
models = ["facebook/chameleon-7b"]
#TODO (ywang96): Add correctness test when chameleon is
# available on transformers.
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
......@@ -29,13 +30,20 @@ def run_test(
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Test if the model can generate text given
a batch of images and prompts.
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
......@@ -50,35 +58,49 @@ def run_test(
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
for prompts, images in inputs_per_image:
vllm_outputs = vllm_model.generate_greedy(prompts,
max_tokens,
images=images)
for i in range(len(vllm_outputs)):
# format prompt back to original
replacements = {
"<racm3:break>": "",
"<eoss>": "",
"<reserved08706>": ""
}
pattern = '|'.join(replacements.keys())
vllm_result = re.sub(
pattern,
lambda match: replacements[match.group(0)], #noqa B023
vllm_outputs[i][1])
vllm_result = vllm_result.replace("<image>", "", 1023)
assert vllm_result[:len(prompts[i])] == prompts[i]
# assert at least 10 new characters are generated
# (to take stop token into account)
assert len(vllm_outputs[i][1]) - len(prompts[i]) > 10
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
def process(hf_inputs: BatchEncoding):
hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
.to(torch_dtype) # type: ignore
return hf_inputs
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
is_vision_model=True) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
# HF Logprobs include image tokens, unlike vLLM, so we don't directly
# compare them
check_outputs_equal(
outputs_0_lst=[outputs[:2] for outputs in hf_outputs],
outputs_1_lst=[outputs[:2] for outputs in vllm_outputs],
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
......@@ -88,15 +110,18 @@ def run_test(
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(vllm_runner, image_assets, model, size_factors, dtype: str,
max_tokens: int) -> None:
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype, max_tokens, num_logprobs) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, AutoTokenizer, BatchEncoding
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
......@@ -110,16 +110,21 @@ def run_test(
for prompts, images in inputs_per_image
]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
if mantis_processor is not None:
if mantis_processor is not None:
def process(*args, **kwargs):
output = mantis_processor(*args, **kwargs)
output["pixel_values"] = output["pixel_values"].to(torch_dtype)
return output
def process(hf_inputs: BatchEncoding):
hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
.to(torch_dtype) # type: ignore
return hf_inputs
else:
hf_model.processor = process
def process(hf_inputs: BatchEncoding):
return hf_inputs
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
is_vision_model=True) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
......
from collections import UserDict
from typing import List, Optional, Tuple, Type
import pytest
import torch
import torch.types
from transformers import BatchFeature
from transformers import BatchEncoding
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
......@@ -14,18 +13,6 @@ from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
class NestedInputs(UserDict):
def __init__(self, model_inputs: BatchFeature):
super().__init__({"model_inputs": model_inputs})
self.model_inputs = model_inputs
def to(self, device: torch.types.Device):
return NestedInputs(self.model_inputs.to(device))
# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
......@@ -41,6 +28,10 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
models = ["openbmb/MiniCPM-Llama3-V-2_5"]
def _wrap_inputs(hf_inputs: BatchEncoding) -> BatchEncoding:
return BatchEncoding({"model_inputs": hf_inputs})
def trunc_hf_output(hf_output: Tuple[List[int], str,
Optional[SampleLogprobs]]):
output_ids, output_str, out_logprobs = hf_output
......@@ -105,11 +96,8 @@ def run_test(
for prompts, images in inputs_per_image
]
with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
hf_processor = hf_model.processor
hf_model.processor = lambda **kw: NestedInputs(
hf_processor(**kw) # type: ignore
)
hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs)
with hf_model, torch.no_grad():
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
......@@ -224,11 +212,8 @@ def run_multi_image_test(
for prompts, images in inputs_per_case
]
with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
hf_processor = hf_model.processor
hf_model.processor = lambda **kw: NestedInputs(
hf_processor(**kw) # type: ignore
)
hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs)
with hf_model, torch.no_grad():
hf_outputs_per_case = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
......
from typing import Optional
import torch
from vllm import LLM, ModelRegistry, SamplingParams
......@@ -7,8 +9,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
class MyOPTForCausalLM(OPTForCausalLM):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_()
......
......@@ -19,7 +19,7 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor,
def tensor_model_parallel_gather(input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> torch.Tensor:
dim: int = -1) -> Optional[torch.Tensor]:
"""Gather the input tensor across model parallel group."""
return get_tp_group().gather(input_, dst, dim)
......
......@@ -329,7 +329,7 @@ class GroupCoordinator:
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> torch.Tensor:
dim: int = -1) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
......
......@@ -50,7 +50,7 @@ class LogitsProcessor(nn.Module):
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Optional[torch.Tensor]:
if self.logits_as_input:
logits = hidden_states
else:
......@@ -73,14 +73,18 @@ class LogitsProcessor(nn.Module):
return logits
def _get_logits(self, hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = lm_head.linear_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
if self.use_gather:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits)
else:
# Gather is not supported for some devices such as TPUs.
......
......@@ -19,6 +19,7 @@ from tqdm.auto import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import LoadConfig, ModelConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
......@@ -514,8 +515,30 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
def default_weight_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
"""Default weight loader."""
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
try:
assert param.size() == loaded_weight.size(), (
f"Attempted to load weight ({loaded_weight.size()}) "
f"into parameter ({param.size()})")
param.data.copy_(loaded_weight)
except Exception:
# NOTE: This exception is added for the purpose of setting breakpoint to
# debug weight loading issues.
raise
def row_parallel_weight_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
"""Load weights that are row-parallelized."""
tp_rank = get_tensor_model_parallel_rank()
shard_dim = 0 if param.dim() != 1 else None
if shard_dim is not None:
shard_size = param.data.shape[shard_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
return default_weight_loader(param, loaded_weight)
def initialize_dummy_weights(
......
......@@ -433,8 +433,11 @@ class ArcticForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -346,8 +346,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -872,8 +872,11 @@ class BartForConditionalGeneration(nn.Module):
return self.model(input_ids, positions, encoder_input_ids,
encoder_positions, kv_caches, attn_metadata)
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -637,8 +637,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.get_lm_head(), hidden_states,
sampling_metadata)
return logits
......
......@@ -292,8 +292,11 @@ class BloomForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -25,8 +25,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
......@@ -141,6 +143,11 @@ class ChameleonLayerNorm(nn.LayerNorm):
super().__init__(hidden_size, *args, **kwargs)
self.normalized_shape = (hidden_size[-1], )
set_weight_attrs(self.weight,
{"weight_loader": row_parallel_weight_loader})
set_weight_attrs(self.bias,
{"weight_loader": row_parallel_weight_loader})
def forward(self, hidden_states):
hidden_states = F.layer_norm(hidden_states,
self.normalized_shape,
......@@ -697,6 +704,8 @@ class ChameleonVQVAEEncoder(nn.Module):
)
def forward(self, pixel_values: torch.Tensor):
pixel_values = pixel_values.to(self.conv_in.weight.dtype)
# downsampling
hidden_states = [self.conv_in(pixel_values)]
for i_level in range(self.num_resolutions):
......@@ -959,15 +968,19 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
# Disallow image tokens which does not include special
# begin-image and end-image tokens
image_tokens = self.model.vocabulary_mapping.image_tokens
logits[:, image_tokens] = torch.finfo(logits.dtype).min
if logits is not None:
image_tokens = self.model.vocabulary_mapping.image_tokens
logits[:, image_tokens] = torch.finfo(logits.dtype).min
return logits
......
......@@ -372,8 +372,11 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -25,13 +25,11 @@ from typing import Iterable, List, Optional, Set, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn.parameter import Parameter
from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
......@@ -43,7 +41,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors, SamplerOutput
......@@ -67,25 +66,14 @@ class LayerNorm(nn.Module):
super().__init__()
self.weight = nn.Parameter(torch.ones(param_shape))
self.variance_epsilon = eps
set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
set_weight_attrs(self.weight,
{"weight_loader": row_parallel_weight_loader})
def forward(self, hidden_states, residuals=None):
hidden_states = layer_norm_func(hidden_states, self.weight,
self.variance_epsilon)
return hidden_states, residuals
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
shard_dim = 0 if param.dim() != 1 else None
param_data = param.data
if shard_dim is not None:
shard_size = param_data.shape[shard_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
......@@ -359,8 +347,11 @@ class CohereForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
is_not_lora = hasattr(self.model.embed_tokens, 'weight')
if is_not_lora:
logits = self.logits_processor(self.model.embed_tokens,
......
......@@ -388,8 +388,11 @@ class DbrxForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment