Commit 99b471c2 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.1

parents 1925d2e9 468d761b
......@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm._C import ops
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
......@@ -68,15 +68,18 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
def __init__(self, quant_config: SqueezeLLMConfig):
self.quant_config = quant_config
def create_weights(self, input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
......@@ -103,17 +106,18 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
set_weight_attrs(lookup_table, {
"output_dim": 0,
})
return {
"qweight": qweight,
"lookup_table": lookup_table,
}
layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("lookup_table", lookup_table)
set_weight_attrs(lookup_table, extra_weight_attrs)
def apply_weights(self,
weights: Dict[str, Any],
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"]
lookup_table = weights["lookup_table"]
qweight = layer.qweight
lookup_table = layer.lookup_table
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
if is_hip():
......@@ -126,5 +130,5 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
if bias is not None:
out = out + bias
out.add_(bias)
return out.reshape(out_shape)
......@@ -144,6 +144,7 @@ class RejectionSampler(nn.Module):
recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(recovered_probs,
num_samples=1).reshape(
batch_size, k)
......@@ -307,6 +308,12 @@ class RejectionSampler(nn.Module):
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1)
# We disable bonus tokens because it causes corrupt KV cache for
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
output_with_bonus_tokens[:, -1] = -1
# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
recovered_token_ids.mul(after_false_mask))
......
......@@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm._C import ops
from vllm import _custom_ops as ops
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
......@@ -108,7 +108,8 @@ class RotaryEmbedding(nn.Module):
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1)
......@@ -247,11 +248,12 @@ def _yarn_find_correction_dim(num_rotations: int,
# Find dim range bounds based on rotations
def _yarn_find_correction_range(low_rot: int,
high_rot: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> int:
def _yarn_find_correction_range(
low_rot: int,
high_rot: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> Tuple[int, int]:
low = math.floor(
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(
......@@ -293,8 +295,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: float = 32,
beta_slow: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
......
......@@ -27,8 +27,22 @@ class Sampler(nn.Module):
6. Sample the next tokens.
Here, each sequence group within the batch can have different sampling
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
The structure of the logits tensor is coupled with the seq_groups in
sampling_metadata. Typically, each sequence in each seq_group has one row in
logits for the next token to be sampled; however, for a seq_group with a
prompt request with the prompt_logprobs sampling parameter, there are rows
in logits for each token in the input prompt.
"""
def __init__(self):
super().__init__()
# Whether or not the SamplerOutput should have on-device tensors
# containing the sampled token ids and probabilities. This is used by
# speculative decoding.
self.include_gpu_probs_tensor = False
def forward(
self,
logits: torch.Tensor,
......@@ -73,13 +87,45 @@ class Sampler(nn.Module):
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
sample_results = _sample(probs, logprobs, sampling_metadata,
sampling_tensors)
sample_results, maybe_sampled_tokens_tensor = _sample(
probs,
logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
)
if self.include_gpu_probs_tensor:
assert maybe_sampled_tokens_tensor is not None
sampled_tokens_tensor = maybe_sampled_tokens_tensor
on_device_tensors = (probs, sampled_tokens_tensor)
else:
on_device_tensors = None
# Get the logprobs query results.
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, sampling_metadata, sample_results)
return _build_sampler_output(sample_results, sampling_metadata,
prompt_logprobs, sample_logprobs)
return _build_sampler_output(sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors)
@property
def _should_modify_greedy_probs_inplace(self) -> bool:
"""Whether or not the sampler should modify the probability distribution
of greedily-sampled tokens such that multinomial sampling would sample
the greedily-sampled token.
In other words, if True then we set the probability of the greedily-
sampled token to 1.
This is used by speculative decoding, which requires that the sampling
method be encoded into the probability distribution.
"""
# Modify greedy probs if include_gpu_probs_tensor is set.
return self.include_gpu_probs_tensor
def _get_bin_counts_and_mask(
......@@ -106,7 +152,16 @@ def _apply_min_tokens_penalty(
# list of indices in logits that will be set to -inf
logits_to_penalize = []
start_idx = 0
for seq_ids, sampling_params in sampling_metadata.seq_groups:
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
# handle prompt_logprobs by skipping rows in logits added for the prompt
# tokens (prompt logprobs are not penalized)
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
assert len(seq_ids) == 1
start_idx += sampling_metadata.prompt_lens[i] - 1
min_tokens = sampling_params.min_tokens
if min_tokens > 0:
seqs_to_penalize = []
......@@ -132,6 +187,8 @@ def _apply_min_tokens_penalty(
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
logits[tuple(zip(*logits_to_penalize))] = -float("inf")
# verifies that no rows in logits were missed unexpectedly
assert start_idx == logits.shape[0]
return logits
......@@ -342,7 +399,9 @@ def _sample_with_torch(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> List[Tuple[List[int], List[int]]]:
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
......@@ -354,6 +413,15 @@ def _sample_with_torch(
sample_metadata = {}
multinomial_samples = {}
# Create output tensor for sampled token ids.
if include_gpu_probs_tensor:
sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
1,
dtype=torch.long,
device=logprobs.device)
else:
sampled_token_ids_tensor = None
# Counterintiutively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType:
......@@ -366,9 +434,25 @@ def _sample_with_torch(
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
is_prompts, sample_indices)
long_sample_indices = sample_indices.long()
if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[sample_indices.long()],
greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1)
if include_gpu_probs_tensor:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[
long_sample_indices] = greedy_samples.unsqueeze(-1)
if modify_greedy_probs:
# If required, modify the probabilities such that sampling from
# the modified distribution would always sample the argmax
# token id.
_modify_greedy_probs_inplace(logprobs, probs,
long_sample_indices,
greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of_in_batch = 1
for seq_group, is_prompt in zip(seq_groups, is_prompts):
......@@ -380,15 +464,23 @@ def _sample_with_torch(
"seq_groups": seq_groups,
"generators": sampling_metadata.generators,
}
multinomial_samples[sampling_type] = _multinomial(
probs[sample_indices.long()], max_best_of_in_batch,
probs[long_sample_indices], max_best_of_in_batch,
**seeded_args)
if include_gpu_probs_tensor:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[
long_sample_indices] = multinomial_samples[sampling_type]
elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")
# GPU<->CPU sync happens in the loop below.
# This also converts the sample output to Python objects.
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
......@@ -410,7 +502,7 @@ def _sample_with_torch(
sample_results_dict[i]
for i in range(len(sampling_metadata.seq_groups))
]
return sample_results
return sample_results, sampled_token_ids_tensor
def _sample_with_triton_kernel(
......@@ -494,12 +586,17 @@ def _sample_with_triton_kernel(
def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
return _sample_with_torch(probs, logprobs, sampling_metadata)
probs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
return _sample_with_torch(
probs,
logprobs,
sampling_metadata,
include_gpu_probs_tensor=include_gpu_probs_tensor,
modify_greedy_probs=modify_greedy_probs,
)
# TODO: Enable once Triton kernel & associated code is faster.
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
......@@ -663,12 +760,73 @@ def _get_logprobs(
return result_prompt_logprobs, result_sample_logprobs
def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
sample_indices: torch.Tensor,
greedy_samples: torch.Tensor) -> None:
"""Modify the probability distributions of the greedily-sampled tokens such
that each sampled token has a "probability" of 1.0. This is required by
speculative decoding, which depends on the sampling method being encoded
within the probability distribution for correctness.
# Why do we only need to do this for greedy sampling?
vLLM's sampler performs the following steps for greedy or multinomial
(random) sampling:
1. Get logits from model.
2. Modify logits according to per-sequence sampling parameters.
- Multiply by temperature, top-k and top-p masking, penalize tokens
according to their frequency, etc.
3. Sample a token.
- Random sampling simply samples from the modified probability
distribution.
- Greedy sampling performs `argmax` to obtain the token with the
highest likelihood.
Ignoring greedy sampling for a moment, we find that the computed probability
distribution has the following property: we can sample from it independently
and find that the token sampled by the Sampler has a frequency corresponding
to how often we see it in our sampling. In other words, for tokens sampled
with vLLM's random SamplingType, the computed probability distribution
encodes the sampling methodology completely.
Greedy sampling does not normally have this property. vLLM modifies logits
according to sampling params, then performs `argmax`, then returns the
sampled token and the computed probability distribution. If we sample from
the distribution, we'll find the likelihood of the greedily-sampled token
is not always 1.0.
Since lossless speculative decoding requires that the sampling methodology
be encoded within the probability distribution, we are motivated to modify
the probability distribution such that the sampled token has probability 1
when speculative decoding is used.
NOTE: Alternatively, we could use an extremely low temperature to achieve
greedy sampling using multinomial computation and unite the codepaths. This
has implications on the overall design of the sampler, e.g. how to record
accurate logprobs for the user, so this improvement is deferred to later.
"""
logprobs[sample_indices, :] = -float('inf')
logprobs[sample_indices, greedy_samples] = 0.0
probs[sample_indices, :] = 0
probs[sample_indices, greedy_samples] = 1.0
def _build_sampler_output(
sample_results: List[Tuple[List[int], List[int]]],
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs],
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> SamplerOutput:
"""Construct Python objects with the output of sampling.
Args:
on_device_tensors: Tuple containing on-device tensors with the
probabilities used in sampling and the sampled token ids. This
allows post-processing without copies to CPU/serialization, e.g. in
speculative decoding rejection sampling.
"""
sampler_output = []
for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
......@@ -684,4 +842,15 @@ def _build_sampler_output(
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
sampler_output.append(
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
return SamplerOutput(outputs=sampler_output)
# If not specified, store None values in SamplerOutput.
if on_device_tensors is not None:
sampled_token_probs, sampled_token_ids = on_device_tensors
else:
sampled_token_probs, sampled_token_ids = (None, None)
return SamplerOutput(
outputs=sampler_output,
sampled_token_probs=sampled_token_probs,
sampled_token_ids=sampled_token_ids,
)
......@@ -4,11 +4,9 @@ import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64
......
"""Utilities for selecting and loading models."""
import contextlib
from typing import Tuple, Type
import torch
import torch.nn as nn
from vllm.config import DeviceConfig, ModelConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights)
import os
_VISION_MODEL_CLASSES = [
LlavaForConditionalGeneration,
]
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(old_dtype)
def _get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
if architectures == ['LlamaForCausalLM']:
os.environ['LLAMA_NN'] = '1'
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_architecture_class_name(model_config: ModelConfig) -> str:
return _get_model_architecture(model_config)[1]
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> nn.Module:
lora_config = kwargs.get("lora_config", None)
vision_language_config = kwargs.get("vision_language_config", None)
model_class = _get_model_architecture(model_config)[0]
# Get the (maybe quantized) linear method.
linear_method = None
if model_config.quantization is not None:
quant_config = get_quant_config(model_config)
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} is not "
"supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}.")
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
linear_method = quant_config.get_linear_method()
if linear_method != None:
os.environ['LLAMA_NN'] = '0'
with _set_default_torch_dtype(model_config.dtype):
# Create a model instance.
# The weights will be initialized as empty tensors.
with torch.device(device_config.device):
if hasattr(model_class, "supported_lora_modules"):
model = model_class(model_config.hf_config, linear_method,
lora_config)
elif lora_config:
raise ValueError(
f"Model {model_class.__name__} does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
else:
if model_class not in _VISION_MODEL_CLASSES:
model = model_class(model_config.hf_config, linear_method)
else:
model = model_class(model_config.hf_config,
vision_language_config, linear_method)
if model_config.load_format == "dummy":
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
else:
# Load the weights from the cached or downloaded files.
model.load_weights(model_config.model, model_config.download_dir,
model_config.load_format, model_config.revision)
return model.eval()
from typing import Optional
from torch import nn
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
get_model_loader)
from vllm.model_executor.model_loader.utils import (
get_architecture_class_name, get_model_architecture)
def get_model(
*, model_config: ModelConfig, load_config: LoadConfig,
device_config: DeviceConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
loader = get_model_loader(load_config)
return loader.load_model(model_config=model_config,
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config)
__all__ = [
"get_model", "get_model_loader", "BaseModelLoader",
"get_architecture_class_name", "get_model_architecture"
]
# ruff: noqa: SIM117
import copy
import glob
import os
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple,
Type)
import torch
from torch import nn
from vllm.config import (VLLM_USE_MODELSCOPE, DeviceConfig, LoadConfig,
LoadFormat, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer,
tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (get_model_architecture,
set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf, filter_files_not_needed_for_inference,
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
if TYPE_CHECKING:
from vllm.model_executor.layers.linear import LinearMethodBase
_VISION_MODEL_CLASSES = [
LlavaForConditionalGeneration,
]
logger = init_logger(__name__)
def _get_linear_method(
model_config: ModelConfig,
load_config: LoadConfig) -> Optional["LinearMethodBase"]:
"""Get the (maybe quantized) linear method."""
linear_method = None
if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config)
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} is not "
"supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}.")
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
linear_method = quant_config.get_linear_method()
if linear_method != None:
os.environ['LLAMA_NN'] = '0'
return linear_method
def _get_model_initialization_kwargs(
model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]
) -> Dict[str, Any]:
"""Get extra kwargs for model initialization."""
extra_kwargs = {}
if hasattr(model_class, "supported_lora_modules"):
extra_kwargs["lora_config"] = lora_config
elif lora_config:
raise ValueError(
f"Model {model_class.__name__} does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
elif model_class in _VISION_MODEL_CLASSES:
extra_kwargs["vision_language_config"] = vision_language_config
return extra_kwargs
def _initialize_model(
model_config: ModelConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
"""Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0]
linear_method = _get_linear_method(model_config, load_config)
return model_class(config=model_config.hf_config,
linear_method=linear_method,
**_get_model_initialization_kwargs(
model_class, lora_config, vision_language_config))
class BaseModelLoader(ABC):
"""Base class for model loaders."""
def __init__(self, load_config: LoadConfig):
self.load_config = load_config
@abstractmethod
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
"""Load a model with the given configurations."""
...
class DefaultModelLoader(BaseModelLoader):
"""Model loader that can load different file types from disk."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def _maybe_download_from_modelscope(
self, model: str, revision: Optional[str]) -> Optional[str]:
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
Returns the path to the downloaded model, or None if the model is not
downloaded from ModelScope."""
if VLLM_USE_MODELSCOPE:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from modelscope.hub.snapshot_download import snapshot_download
if not os.path.exists(model):
model_path = snapshot_download(
model_id=model,
cache_dir=self.load_config.download_dir,
revision=revision)
else:
model_path = model
return model_path
return None
def _prepare_weights(self, model_name_or_path: str,
revision: Optional[str],
fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
model_name_or_path = self._maybe_download_from_modelscope(
model_name_or_path, revision) or model_name_or_path
is_local = os.path.isdir(model_name_or_path)
load_format = self.load_config.load_format
use_safetensors = False
# Some quantized models use .pt files for storing the weights.
if load_format == LoadFormat.AUTO:
allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == LoadFormat.SAFETENSORS:
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == LoadFormat.PT:
allow_patterns = ["*.pt"]
elif load_format == LoadFormat.NPCACHE:
allow_patterns = ["*.bin"]
else:
raise ValueError(f"Unknown load_format: {load_format}")
if fall_back_to_pt:
allow_patterns += ["*.pt"]
if not is_local:
hf_folder = download_weights_from_hf(model_name_or_path,
self.load_config.download_dir,
allow_patterns, revision)
else:
hf_folder = model_name_or_path
hf_weights_files: List[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if len(hf_weights_files) > 0:
if pattern == "*.safetensors":
use_safetensors = True
break
if not use_safetensors:
hf_weights_files = filter_files_not_needed_for_inference(
hf_weights_files)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`")
return hf_folder, hf_weights_files, use_safetensors
def _get_weights_iterator(
self, model_name_or_path: str, revision: Optional[str],
fall_back_to_pt: bool
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
model_name_or_path, revision, fall_back_to_pt)
if self.load_config.load_format == LoadFormat.NPCACHE:
# Currently np_cache only support *.bin checkpoints
assert use_safetensors is False
return np_cache_weights_iterator(model_name_or_path,
self.load_config.download_dir,
hf_folder, hf_weights_files)
if use_safetensors:
return safetensors_weights_iterator(hf_weights_files)
return pt_weights_iterator(hf_weights_files)
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config)
model.load_weights(
self._get_weights_iterator(model_config.model,
model_config.revision,
fall_back_to_pt=getattr(
model,
"fall_back_to_pt_during_load",
True)), )
for _, module in model.named_modules():
linear_method = getattr(module, "linear_method", None)
if linear_method is not None:
linear_method.process_weights_after_loading(module)
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
return model.eval()
class DummyModelLoader(BaseModelLoader):
"""Model loader that will set model weights to random values."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
return model.eval()
class TensorizerLoader(BaseModelLoader):
"""Model loader using CoreWeave's tensorizer library."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
self.tensorizer_config = load_config.model_loader_extra_config
else:
self.tensorizer_config = TensorizerConfig(
**load_config.model_loader_extra_config)
def _verify_config(self, model_config: ModelConfig,
parallel_config: ParallelConfig):
self.tensorizer_config.verify_with_model_config(model_config)
self.tensorizer_config.verify_with_parallel_config(parallel_config)
def _get_weights_iterator(
self) -> Generator[Tuple[str, torch.Tensor], None, None]:
tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
return tensorizer_weights_iterator(tensorizer_args)
def _load_model_unserialized(
self, model_config: ModelConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]
) -> nn.Module:
"""Load an unserialized model with tensorizer.
Unserialized here means "not serialized with tensorizer". This
should still be faster than default HuggingFace loading, but will
be slower than loading a tensorizer-serialized model.
"""
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config)
model.load_weights(self._get_weights_iterator())
return model.eval()
def _load_model_serialized(
self, model_config: ModelConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]
) -> nn.Module:
"""Load a serialized model with tensorizer.
See the examples/tensorize_vllm_model.py example "
script for serializing vLLM models."""
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model_class = get_model_architecture(model_config)[0]
linear_method = _get_linear_method(model_config,
self.load_config)
extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, vision_language_config)
extra_kwargs["linear_method"] = linear_method
tensorizer_config = copy.copy(self.tensorizer_config)
tensorizer_config.model_class = model_class
tensorizer_config.hf_config = model_config.hf_config
tensorizer_config.dtype = model_config.dtype
model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
return model.eval()
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
self._verify_config(model_config, parallel_config)
if is_vllm_serialized_tensorizer(self.tensorizer_config):
return self._load_model_serialized(model_config, device_config,
lora_config,
vision_language_config)
return self._load_model_unserialized(model_config, device_config,
lora_config,
vision_language_config)
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
if isinstance(load_config.load_format, type):
return load_config.load_format(load_config)
if load_config.load_format == LoadFormat.DUMMY:
return DummyModelLoader(load_config)
if load_config.load_format == LoadFormat.TENSORIZER:
return TensorizerLoader(load_config)
return DefaultModelLoader(load_config)
"""Utilities for selecting and loading neuron models."""
import importlib
import os
from typing import Optional, Type
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
......@@ -27,7 +27,7 @@ TORCH_DTYPE_TO_NEURON_AMP = {
}
# Models supported by Neuron.
_NEURON_SUPPORTED_MODELS = {
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
"LlamaForCausalLM": ("transformers_neuronx.llama.model",
"LlamaForSampling", "LlamaForCausalLM"),
"MistralForCausalLM": ("transformers_neuronx.mistral.model",
......@@ -43,11 +43,13 @@ class NeuronCasualLM(nn.Module):
) -> None:
super().__init__()
self.config = config
self.model = None
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)
self.sampler = Sampler()
# Lazy initialized
self.model: nn.Module
def forward(
self,
input_ids: torch.Tensor,
......@@ -74,17 +76,17 @@ class NeuronCasualLM(nn.Module):
def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
neuronx_module_path, neuronx_model_cls, hf_model_cls = (
neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
split_model_dir = f"{model_name_or_path}-split"
if os.path.isdir(os.path.join(model_name_or_path,
"pytorch_model.bin")):
split_model_dir = model_name_or_path
elif not os.path.exists(f"{model_name_or_path}-split"):
hf_model_cls = getattr(transformers, hf_model_cls)
hf_model_cls = getattr(transformers, hf_model_cls_name)
from transformers_neuronx.module import save_pretrained_split
hf_model = hf_model_cls.from_pretrained(model_name_or_path,
......@@ -96,7 +98,7 @@ class NeuronCasualLM(nn.Module):
self.model.to_neuron()
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
def _get_model_architecture(config: PretrainedConfig) -> str:
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in _NEURON_SUPPORTED_MODELS:
......
import argparse
import dataclasses
import io
import os
import time
import typing
from dataclasses import dataclass
from typing import Generator, Optional, Tuple, Type, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.config import ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
tensorizer_load_fail = None
try:
from tensorizer import (DecryptionParams, EncryptionParams,
TensorDeserializer, TensorSerializer)
from tensorizer.stream_io import open_stream
from tensorizer.utils import (convert_bytes, get_mem_usage,
no_init_or_tensor)
except ImportError as e:
tensorizer_load_fail = e
__all__ = [
'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
'no_init_or_tensor', 'TensorizerConfig'
]
logger = init_logger(__name__)
@dataclass
class TensorizerConfig:
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
str, bytes, os.PathLike, int]
vllm_tensorized: bool
verify_hash: Optional[bool] = False
num_readers: Optional[int] = 1
encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None
s3_endpoint: Optional[str] = None
model_class: Optional[Type[torch.nn.Module]] = None
hf_config: Optional[PretrainedConfig] = None
dtype: Optional[Union[str, torch.dtype]] = None
def _construct_tensorizer_args(self) -> "TensorizerArgs":
tensorizer_args = {
"tensorizer_uri": self.tensorizer_uri,
"vllm_tensorized": self.vllm_tensorized,
"verify_hash": self.verify_hash,
"num_readers": self.num_readers,
"encryption_keyfile": self.encryption_keyfile,
"s3_access_key_id": self.s3_access_key_id,
"s3_secret_access_key": self.s3_secret_access_key,
"s3_endpoint": self.s3_endpoint,
}
return TensorizerArgs(**tensorizer_args)
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
if (parallel_config.tensor_parallel_size > 1
and self.tensorizer_uri is not None):
raise ValueError(
"Loading to multiple GPUs is not currently supported with "
"vLLM-serialized models. Please set tensor_parallel_size=1."
" or use a non-vLLM-serialized model, such as a "
"serialized Hugging Face `PretrainedModel`.")
def verify_with_model_config(self, model_config: "ModelConfig") -> None:
if (model_config.quantization is not None
and self.tensorizer_uri is not None):
logger.warning(
"Loading a model using Tensorizer with quantization on vLLM"
" is unstable and may lead to errors.")
def load_with_tensorizer(tensorizer_config: TensorizerConfig,
**extra_kwargs) -> nn.Module:
tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs)
return tensorizer.deserialize()
def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool:
if tensorizer_config is None:
return False
return tensorizer_config.vllm_tensorized
@dataclass
class TensorizerArgs:
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
str, bytes, os.PathLike, int]
vllm_tensorized: bool
verify_hash: Optional[bool] = False
num_readers: Optional[int] = 1
encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None
s3_endpoint: Optional[str] = None
"""
Args for the TensorizerAgent class. These are used to configure the behavior
of the TensorDeserializer when loading tensors from a serialized model.
Args:
tensorizer_uri: Path to serialized model tensors. Can be a local file
path or a S3 URI.
vllm_tensorized: If True, indicates that the serialized model is a
vLLM model. This is used to determine the behavior of the
TensorDeserializer when loading tensors from a serialized model.
It is far faster to deserialize a vLLM model as it utilizes
tensorizer's optimized GPU loading.
verify_hash: If True, the hashes of each tensor will be verified against
the hashes stored in the metadata. A `HashMismatchError` will be
raised if any of the hashes do not match.
num_readers: Controls how many threads are allowed to read concurrently
from the source file. Default is 1. This greatly increases
performance.
encryption_keyfile: File path to a binary file containing a
binary key to use for decryption. `None` (the default) means
no decryption. See the example script in
examples/tensorize_vllm_model.py.
s3_access_key_id: The access key for the S3 bucket. Can also be set via
the S3_ACCESS_KEY_ID environment variable.
s3_secret_access_key: The secret access key for the S3 bucket. Can also
be set via the S3_SECRET_ACCESS_KEY environment variable.
s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
S3_ENDPOINT_URL environment variable.
"""
def __post_init__(self):
self.file_obj = self.tensorizer_uri
self.s3_access_key_id = (self.s3_access_key_id
or os.environ.get("S3_ACCESS_KEY_ID")) or None
self.s3_secret_access_key = (
self.s3_secret_access_key
or os.environ.get("S3_SECRET_ACCESS_KEY")) or None
self.s3_endpoint = (self.s3_endpoint
or os.environ.get("S3_ENDPOINT_URL")) or None
self.stream_params = {
"s3_access_key_id": self.s3_access_key_id,
"s3_secret_access_key": self.s3_secret_access_key,
"s3_endpoint": self.s3_endpoint,
}
self.deserializer_params = {
"verify_hash": self.verify_hash,
"encryption": self.encryption_keyfile,
"num_readers": self.num_readers
}
if self.encryption_keyfile:
with open_stream(
self.encryption_keyfile,
**self.stream_params,
) as stream:
key = stream.read()
decryption_params = DecryptionParams.from_key(key)
self.deserializer_params['encryption'] = decryption_params
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Tensorizer CLI arguments"""
# Tensorizer options arg group
group = parser.add_argument_group(
'tensorizer options',
description=('Options for configuring the behavior of the'
' tensorizer deserializer when '
'--load-format=tensorizer'))
group.add_argument(
"--tensorizer-uri",
help="Path to serialized model tensors. Can be a local file path,"
" or an HTTP(S) or S3 URI.",
)
group.add_argument(
"--verify-hash",
action="store_true",
help="If enabled, the hashes of each tensor will be verified"
" against the hashes stored in the file metadata. An exception"
" will be raised if any of the hashes do not match.",
)
group.add_argument(
"--encryption-keyfile",
default=None,
help="The file path to a binary file containing a binary key to "
"use for decryption. Can be a file path or S3 network URI.")
group.add_argument(
"--num-readers",
default=1,
type=int,
help="Controls how many threads are allowed to read concurrently "
"from the source file.")
group.add_argument(
"--s3-access-key-id",
default=None,
help="The access key for the S3 bucket. Can also be set via the "
"S3_ACCESS_KEY_ID environment variable.",
)
group.add_argument(
"--s3-secret-access-key",
default=None,
help="The secret access key for the S3 bucket. Can also be set via "
"the S3_SECRET_ACCESS_KEY environment variable.",
)
group.add_argument(
"--s3-endpoint",
default=None,
help="The endpoint for the S3 bucket. Can also be set via the "
"S3_ENDPOINT_URL environment variable.",
)
group.add_argument(
"--vllm-tensorized",
action="store_true",
help="If enabled, indicates that the serialized model is a vLLM "
"model. This is used to determine the behavior of the "
"TensorDeserializer when loading tensors from a "
"serialized model.")
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
attrs = [attr.name for attr in dataclasses.fields(cls)]
tensorizer_args = cls(**{
attr: getattr(args, attr)
for attr in attrs if hasattr(args, attr)
})
return tensorizer_args
class TensorizerAgent:
"""
A class for performing tensorizer deserializations specifically for
vLLM models using plaid_mode. Uses TensorizerArgs to configure the
behavior of the TensorDeserializer when loading tensors from a serialized
model. For deserializations of HuggingFace models, TensorDeserializer is
instead used as an iterator directly in the func hf_model_weights_iterator
in vllm/model_executor/model_loader/weight_utils.py
"""
def __init__(self, tensorizer_config: TensorizerConfig,
linear_method: LinearMethodBase, **extra_kwargs):
if tensorizer_load_fail is not None:
raise ImportError(
"Tensorizer is not installed. Please install tensorizer "
"to use this feature with `pip install vllm[tensorizer]`."
) from tensorizer_load_fail
self.tensorizer_config = tensorizer_config
self.tensorizer_args = (
self.tensorizer_config._construct_tensorizer_args())
self.extra_kwargs = extra_kwargs
if extra_kwargs.get("linear_method", None) is not None:
self.linear_method = extra_kwargs["linear_method"]
else:
self.linear_method = linear_method
self.model = self._init_model()
def _init_model(self):
model_args = self.tensorizer_config.hf_config
model_args.torch_dtype = self.tensorizer_config.dtype
with no_init_or_tensor():
return self.tensorizer_config.model_class(
config=model_args,
linear_method=self.linear_method,
**self.extra_kwargs)
def _resize_lora_embeddings(self):
"""Modify LoRA embedding layers to use bigger tensors
to allow for adapter added tokens."""
for child in self.model.modules():
if (isinstance(child, VocabParallelEmbedding)
and child.weight.shape[0] <
child.num_embeddings_per_partition):
new_weight = torch.empty(child.num_embeddings_per_partition,
child.embedding_dim,
dtype=child.weight.dtype,
device=child.weight.device)
new_weight[:child.weight.shape[0]].copy_(child.weight.data)
new_weight[child.weight.shape[0]:].fill_(0)
child.weight.data = new_weight
def _check_tensors_on_meta_device(self):
for tensor in self.model.state_dict().values():
if tensor.device.type == 'meta':
raise ValueError(
"The serialized model contains tensors on the meta device,"
" indicating that some tensors were not loaded properly."
" Please check that the parameters of the model being"
" specified match that of the serialized model, such as"
" its quantization.")
def deserialize(self):
"""
Deserialize the model using the TensorDeserializer. This method is
specifically for vLLM models using tensorizer's plaid_mode.
The deserializer makes use of tensorizer_args.stream_params
to configure the behavior of the stream when loading tensors from a
serialized model. The deserializer_params are used to configure the
behavior of the TensorDeserializer when loading tensors themselves.
Documentation on these params can be found in TensorizerArgs
Returns:
nn.Module: The deserialized model.
"""
before_mem = get_mem_usage()
start = time.perf_counter()
with open_stream(
self.tensorizer_args.tensorizer_uri,
mode="rb",
**self.tensorizer_args.stream_params,
) as stream, TensorDeserializer(
stream,
dtype=self.tensorizer_config.dtype,
**self.tensorizer_args.deserializer_params) as deserializer:
deserializer.load_into_module(self.model)
end = time.perf_counter()
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
duration = end - start
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage()
deserializer.close()
logger.info(f"Deserialized {total_bytes_str} in "
f"{end - start:0.2f}s, {per_second}/s")
logger.info(f"Memory usage before: {before_mem}")
logger.info(f"Memory usage after: {after_mem}")
self._check_tensors_on_meta_device()
self._resize_lora_embeddings()
return self.model.eval()
def tensorizer_weights_iterator(
tensorizer_args: "TensorizerArgs"
) -> Generator[Tuple[str, torch.Tensor], None, None]:
logger.warning(
"Deserializing HuggingFace models is not optimized for "
"loading on vLLM, as tensorizer is forced to load to CPU. "
"Consider deserializing a vLLM model instead for faster "
"load times. See the examples/tensorize_vllm_model.py example "
"script for serializing vLLM models.")
deserializer_args = tensorizer_args.deserializer_params
stream_params = tensorizer_args.stream_params
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
with TensorDeserializer(stream, **deserializer_args,
device="cpu") as state:
for name, param in state.items():
yield name, param
del state
"""Utilities for selecting and loading models."""
import contextlib
from typing import Tuple, Type
import torch
from torch import nn
from vllm.config import ModelConfig
from vllm.model_executor.models import ModelRegistry
import os
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(old_dtype)
def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
if architectures == ['LlamaForCausalLM']:
os.environ['LLAMA_NN'] = '1'
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
and model_config.quantization != "fp8"
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_architecture_class_name(model_config: ModelConfig) -> str:
return get_model_architecture(model_config)[1]
......@@ -4,20 +4,23 @@ import glob
import hashlib
import json
import os
import tempfile
from collections import defaultdict
from typing import Any, Iterator, List, Optional, Tuple
from typing import Any, Generator, Iterable, List, Optional, Tuple
import filelock
import huggingface_hub.constants
import numpy as np
import torch
from huggingface_hub import HfFileSystem, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from vllm.config import ModelConfig
from vllm.config import LoadConfig, ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
logger = init_logger(__name__)
......@@ -25,11 +28,25 @@ logger = init_logger(__name__)
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
temp_dir = os.environ.get('TMPDIR') or os.environ.get(
'TEMP') or os.environ.get('TMP') or "/tmp/"
temp_dir = tempfile.gettempdir()
class Disabledtqdm(tqdm):
def enable_hf_transfer():
"""automatically activates hf_transfer
"""
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
try:
# enable hf hub transfer if available
import hf_transfer # type: ignore # noqa
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
except ImportError:
pass
enable_hf_transfer()
class DisabledTqdm(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, disable=True)
......@@ -97,7 +114,8 @@ def convert_bin_to_safetensor_file(
# TODO(woosuk): Move this to other place.
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
def get_quant_config(model_config: ModelConfig,
load_config: LoadConfig) -> QuantizationConfig:
quant_cls = get_quantization_config(model_config.quantization)
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
......@@ -108,19 +126,26 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
is_local = os.path.isdir(model_name_or_path)
if not is_local:
# Download the config files.
with get_lock(model_name_or_path, model_config.download_dir):
with get_lock(model_name_or_path, load_config.download_dir):
hf_folder = snapshot_download(model_name_or_path,
revision=model_config.revision,
allow_patterns="*.json",
cache_dir=model_config.download_dir,
tqdm_class=Disabledtqdm)
cache_dir=load_config.download_dir,
tqdm_class=DisabledTqdm)
else:
hf_folder = model_name_or_path
possible_config_filenames = quant_cls.get_config_filenames()
# If the quantization config is not found, use the default config.
if not possible_config_filenames:
return quant_cls()
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
quant_config_files = [
f for f in config_files if any(
f.endswith(x) for x in quant_cls.get_config_filenames())
f.endswith(x) for x in possible_config_filenames)
]
if len(quant_config_files) == 0:
raise ValueError(
......@@ -136,143 +161,167 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
return quant_cls.from_config(config)
def prepare_hf_model_weights(
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
fall_back_to_pt: bool = True,
revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
# Download model weights from huggingface.
is_local = os.path.isdir(model_name_or_path)
use_safetensors = False
# Some quantized models use .pt files for storing the weights.
if load_format == "auto":
allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == "safetensors":
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == "pt":
allow_patterns = ["*.pt"]
elif load_format == "npcache":
allow_patterns = ["*.bin"]
else:
raise ValueError(f"Unknown load_format: {load_format}")
if fall_back_to_pt:
allow_patterns += ["*.pt"]
def download_weights_from_hf(model_name_or_path: str,
cache_dir: Optional[str],
allow_patterns: List[str],
revision: Optional[str] = None) -> str:
"""Download model weights from Hugging Face Hub.
Args:
model_name_or_path (str): The model name or path.
cache_dir (Optional[str]): The cache directory to store the model
weights. If None, will use HF defaults.
allow_patterns (List[str]): The allowed patterns for the
weight files. Files matched by any of the patterns will be
downloaded.
revision (Optional[str]): The revision of the model.
Returns:
str: The path to the downloaded model weights.
"""
# Before we download we look at that is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
if not is_local:
# Before we download we look at that is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
# depending on what is available we download different things
for pattern in allow_patterns:
matching = fnmatch.filter(file_list, pattern)
if len(matching) > 0:
allow_patterns = [pattern]
break
logger.info(f"Using model weights format {allow_patterns}")
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
hf_folder = snapshot_download(model_name_or_path,
allow_patterns=allow_patterns,
cache_dir=cache_dir,
tqdm_class=Disabledtqdm,
revision=revision)
else:
hf_folder = model_name_or_path
hf_weights_files: List[str] = []
# depending on what is available we download different things
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if len(hf_weights_files) > 0:
if pattern == "*.safetensors":
use_safetensors = True
matching = fnmatch.filter(file_list, pattern)
if len(matching) > 0:
allow_patterns = [pattern]
break
if not use_safetensors:
# Exclude files that are not needed for inference.
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
blacklist = [
"training_args.bin",
"optimizer.bin",
"optimizer.pt",
"scheduler.pt",
"scaler.pt",
]
hf_weights_files = [
f for f in hf_weights_files
if not any(f.endswith(x) for x in blacklist)
]
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`")
return hf_folder, hf_weights_files, use_safetensors
def hf_model_weights_iterator(
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
fall_back_to_pt: Optional[bool] = True,
) -> Iterator[Tuple[str, torch.Tensor]]:
hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
model_name_or_path,
cache_dir=cache_dir,
load_format=load_format,
fall_back_to_pt=fall_back_to_pt,
revision=revision)
if load_format == "npcache":
# Currently np_cache only support *.bin checkpoints
assert use_safetensors is False
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
np_folder = os.path.join(hf_folder, "np")
os.makedirs(np_folder, exist_ok=True)
weight_names_file = os.path.join(np_folder, "weight_names.json")
# Use file lock to prevent multiple processes from
# dumping the same model weights to numpy at the same time.
with get_lock(model_name_or_path, cache_dir):
if not os.path.exists(weight_names_file):
weight_names = []
for bin_file in hf_weights_files:
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
param_path = os.path.join(np_folder, name)
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())
weight_names.append(name)
with open(weight_names_file, "w") as f:
json.dump(weight_names, f)
with open(weight_names_file, "r") as f:
weight_names = json.load(f)
for name in weight_names:
param_path = os.path.join(np_folder, name)
with open(param_path, "rb") as f:
param = np.load(f)
yield name, torch.from_numpy(param)
elif use_safetensors:
for st_file in hf_weights_files:
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
yield name, param
else:
for bin_file in hf_weights_files:
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
logger.info(f"Using model weights format {allow_patterns}")
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
hf_folder = snapshot_download(model_name_or_path,
allow_patterns=allow_patterns,
cache_dir=cache_dir,
tqdm_class=DisabledTqdm,
revision=revision)
return hf_folder
def filter_files_not_needed_for_inference(
hf_weights_files: List[str]) -> List[str]:
"""
Exclude files that are not needed for inference.
See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
"""
blacklist = [
"training_args.bin",
"optimizer.bin",
"optimizer.pt",
"scheduler.pt",
"scaler.pt",
]
hf_weights_files = [
f for f in hf_weights_files
if not any(f.endswith(x) for x in blacklist)
]
return hf_weights_files
def np_cache_weights_iterator(
model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model np files.
Will dump the model weights to numpy files if they are not already dumped.
"""
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
np_folder = os.path.join(hf_folder, "np")
os.makedirs(np_folder, exist_ok=True)
weight_names_file = os.path.join(np_folder, "weight_names.json")
# Use file lock to prevent multiple processes from
# dumping the same model weights to numpy at the same time.
with get_lock(model_name_or_path, cache_dir):
if not os.path.exists(weight_names_file):
weight_names = []
for bin_file in hf_weights_files:
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
param_path = os.path.join(np_folder, name)
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())
weight_names.append(name)
with open(weight_names_file, "w") as f:
json.dump(weight_names, f)
with open(weight_names_file, "r") as f:
weight_names = json.load(f)
for name in weight_names:
param_path = os.path.join(np_folder, name)
with open(param_path, "rb") as f:
param = np.load(f)
yield name, torch.from_numpy(param)
def safetensors_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
for st_file in hf_weights_files:
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
yield name, param
del state
torch.cuda.empty_cache()
def pt_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files."""
for bin_file in hf_weights_files:
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
yield name, param
del state
torch.cuda.empty_cache()
def kv_cache_scales_loader(
filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
"""
A simple utility to read in KV cache scaling factors that have been
previously serialized to disk. Used by the model to populate the appropriate
KV cache scaling factors. The serialization should represent a dictionary
whose keys are the TP ranks and values are another dictionary mapping layers
to their KV cache scaling factors.
Keep this function in sync with the output of examples/fp8/extract_scales.py
"""
try:
with open(filename) as f:
context = {
"model_type": model_type,
"num_hidden_layers": num_hidden_layers,
"tp_rank": tp_rank,
"tp_size": tp_size,
}
schema_dct = json.load(f)
schema = QuantParamSchema.model_validate(schema_dct,
context=context)
layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
return layer_scales_map.items()
except FileNotFoundError:
logger.error(f"File or directory '{filename}' not found.")
except json.JSONDecodeError:
logger.error(f"Error decoding JSON in file '{filename}'.")
except Exception as e:
logger.error(f"An error occurred while reading '{filename}': {e}")
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
logger.warning("Defaulting to KV cache scaling factors = 1.0 "
f"for all layers in TP rank {tp_rank} "
"as an error occurred during loading.")
return []
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
......
import importlib
from typing import List, Optional, Type
from typing import Dict, List, Optional, Type
import torch.nn as nn
......@@ -41,6 +41,7 @@ _MODELS = {
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
......@@ -55,6 +56,10 @@ _MODELS = {
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
}
# Architecture -> type.
# out of tree models
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS = []
......@@ -74,6 +79,8 @@ class ModelRegistry:
@staticmethod
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
if model_arch not in _MODELS:
return None
if is_hip():
......@@ -95,6 +102,16 @@ class ModelRegistry:
def get_supported_archs() -> List[str]:
return list(_MODELS.keys())
@staticmethod
def register_model(model_arch: str, model_cls: Type[nn.Module]):
if model_arch in _MODELS:
logger.warning(
f"Model architecture {model_arch} is already registered, "
"and will be overwritten by the new model "
f"class {model_cls.__name__}.")
global _OOT_MODELS
_OOT_MODELS[model_arch] = model_cls
__all__ = [
"ModelRegistry",
......
......@@ -19,7 +19,7 @@
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
from typing import List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
......@@ -27,6 +27,8 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
......@@ -38,11 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -340,19 +339,14 @@ class BaiChuanBaseForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if name == "lm_head.weight":
......
......@@ -17,13 +17,15 @@
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import List, Optional
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import BloomConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
......@@ -33,11 +35,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -298,14 +297,9 @@ class BloomForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
if name == "lm_head.weight":
continue
if not name.startswith("transformer."):
......
......@@ -2,7 +2,7 @@
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import List, Optional
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
......@@ -10,6 +10,7 @@ from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
......@@ -21,11 +22,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
......@@ -371,14 +369,9 @@ class ChatGLMForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
if "rotary_pos_emb.inv_freq" in name:
continue
if "word_embeddings" in name:
......
......@@ -20,14 +20,17 @@
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
from typing import List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn.parameter import Parameter
from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
......@@ -38,33 +41,48 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput
@torch.compile
def layer_norm_func(hidden_states, weight, variance_epsilon):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
variance_epsilon)
hidden_states = weight.to(torch.float32) * hidden_states
return hidden_states.to(input_dtype)
class LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-5, bias=False):
def __init__(self, param_shape=None, eps=1e-5):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
self.weight = nn.Parameter(torch.ones(param_shape))
self.variance_epsilon = eps
set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
def forward(self, hidden_states, residuals=None):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
hidden_states = (hidden_states -
mean) * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = self.weight.to(torch.float32) * hidden_states
if self.bias is not None:
hidden_states = hidden_states + self.bias.to(torch.float32)
return hidden_states.to(input_dtype), residuals
hidden_states = layer_norm_func(hidden_states, self.weight,
self.variance_epsilon)
return hidden_states, residuals
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
shard_dim = 0 if param.dim() != 1 else None
param_data = param.data
if shard_dim is not None:
shard_size = param_data.shape[shard_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
......@@ -128,9 +146,12 @@ class CohereAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.max_position_embeddings = config.max_position_embeddings
self.max_position_embeddings = getattr(
config, "model_max_length", None) or getattr(
config, "max_position_embeddings", 8192)
self.rope_theta = config.rope_theta
self.rope_scaling = getattr(config, "rope_scaling", None)
self.use_qk_norm = getattr(config, "use_qk_norm", False)
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
......@@ -159,6 +180,22 @@ class CohereAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
)
if self.use_qk_norm:
self.q_norm = LayerNorm(param_shape=(self.num_heads,
self.head_dim),
eps=config.layer_norm_eps)
self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
self.head_dim),
eps=config.layer_norm_eps)
def _apply_qk_norm(self, q, k):
q = q.view(*q.shape[:-1], -1, self.head_dim)
k = k.view(*k.shape[:-1], -1, self.head_dim)
q, _ = self.q_norm(q)
k, _ = self.k_norm(k)
q = q.view(*q.shape[:-2], -1)
k = k.view(*k.shape[:-2], -1)
return q, k
def forward(
self,
......@@ -169,6 +206,8 @@ class CohereAttention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
......@@ -186,7 +225,7 @@ class CohereDecoderLayer(nn.Module):
self.self_attn = CohereAttention(config, linear_method=linear_method)
self.mlp = CohereMLP(config, linear_method=linear_method)
self.input_layernorm = LayerNorm(config.hidden_size,
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps)
def forward(
......@@ -229,7 +268,8 @@ class CohereModel(nn.Module):
CohereDecoderLayer(config, linear_method=linear_method)
for _ in range(config.num_hidden_layers)
])
self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.norm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps)
def forward(
self,
......@@ -294,13 +334,7 @@ class CohereForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -311,17 +345,26 @@ class CohereForCausalLM(nn.Module):
]
params_dict = dict(self.named_parameters())
loaded_params = set()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
# coding=utf-8
from typing import List, Optional
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
......@@ -15,14 +18,9 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.dbrx import DbrxConfig
......@@ -392,20 +390,13 @@ class DbrxForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = [(
"ws" if weight_name in ["w1", "v1"] else "w2s",
f"experts.mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
for param_name, weight_name in expert_params_mapping:
if weight_name not in name:
continue
......
......@@ -23,16 +23,15 @@
# limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights."""
from typing import Optional
from typing import Iterable, Optional, Tuple
import torch
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
class DeciLMForCausalLM(LlamaForCausalLM):
......@@ -65,11 +64,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
linear_method=linear_method,
lora_config=lora_config)
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -79,8 +74,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......
......@@ -21,13 +21,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -41,13 +44,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -317,6 +315,8 @@ class DeepseekDecoderLayer(nn.Module):
class DeepseekModel(nn.Module):
fall_back_to_pt_during_load = False
def __init__(
self,
config: PretrainedConfig,
......@@ -396,11 +396,7 @@ class DeepseekForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -411,12 +407,7 @@ class DeepseekForCausalLM(nn.Module):
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path,
cache_dir,
load_format,
revision,
fall_back_to_pt=False):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment