Commit 4eabe123 authored by zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori

parents 45840cd2 58738772
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
...@@ -74,7 +75,7 @@ class AutoRoundConfig(QuantizationConfig): ...@@ -74,7 +75,7 @@ class AutoRoundConfig(QuantizationConfig):
f"group_size={self.group_size}, sym={self.sym})") f"group_size={self.group_size}, sym={self.sym})")
@classmethod @classmethod
def get_name(cls): ## use str will trigger preci issue def get_name(cls) -> QuantizationMethods:
return "auto-round" return "auto-round"
@classmethod @classmethod
...@@ -142,18 +143,18 @@ class AutoRoundConfig(QuantizationConfig): ...@@ -142,18 +143,18 @@ class AutoRoundConfig(QuantizationConfig):
prefix, layer.__class__.__name__, weight_bits, group_size, prefix, layer.__class__.__name__, weight_bits, group_size,
sym) sym)
if backend == "auto" or "marlin" in backend: if backend == "auto" or "marlin" in backend:
AWQ_TYPE_MAP = {
4: scalar_types.uint4,
8: scalar_types.uint8,
}
use_marlin = (weight_bits
in AWQ_TYPE_MAP) and check_marlin_supported(
AWQ_TYPE_MAP[weight_bits], group_size, not sym)
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
use_marlin = check_moe_marlin_supports_layer(layer, group_size) use_marlin = use_marlin and check_moe_marlin_supports_layer(
else: layer, group_size)
AWQ_TYPE_MAP = {
4: scalar_types.uint4,
8: scalar_types.uint8,
}
use_marlin = ((weight_bits, sym) in AWQ_TYPE_MAP
and check_marlin_supported(
AWQ_TYPE_MAP[(weight_bits)], group_size,
not sym))
else: else:
use_marlin = False use_marlin = False
if use_marlin: if use_marlin:
...@@ -180,10 +181,11 @@ class AutoRoundConfig(QuantizationConfig): ...@@ -180,10 +181,11 @@ class AutoRoundConfig(QuantizationConfig):
from vllm.model_executor.layers.quantization.moe_wna16 import ( from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config) MoeWNA16Config)
config = { config = {
"linear_quant_method": "awq", "quant_method": "awq",
"weight_bits": weight_bits, "bits": weight_bits,
"group_size": group_size, "group_size": group_size,
"zero_point": not sym, "zero_point": not sym,
"lm_head": False,
} }
return MoeWNA16Config.from_config(config).get_quant_method( return MoeWNA16Config.from_config(config).get_quant_method(
layer, prefix) layer, prefix)
...@@ -213,18 +215,18 @@ class AutoRoundConfig(QuantizationConfig): ...@@ -213,18 +215,18 @@ class AutoRoundConfig(QuantizationConfig):
prefix, layer.__class__.__name__, weight_bits, group_size, prefix, layer.__class__.__name__, weight_bits, group_size,
sym) sym)
if backend == "auto" or "marlin" in backend: if backend == "auto" or "marlin" in backend:
GPTQ_TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP
and check_marlin_supported(
GPTQ_TYPE_MAP[(weight_bits, sym)],
group_size,
has_zp=not sym))
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
use_marlin = check_moe_marlin_supports_layer(layer, group_size) use_marlin = use_marlin and check_moe_marlin_supports_layer(
else: layer, group_size)
GPTQ_TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP
and check_marlin_supported(
GPTQ_TYPE_MAP[(weight_bits, sym)],
group_size,
has_zp=not sym))
else: else:
use_marlin = False use_marlin = False
if use_marlin: if use_marlin:
...@@ -251,11 +253,11 @@ class AutoRoundConfig(QuantizationConfig): ...@@ -251,11 +253,11 @@ class AutoRoundConfig(QuantizationConfig):
from vllm.model_executor.layers.quantization.moe_wna16 import ( from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config) MoeWNA16Config)
config = { config = {
"linear_quant_method": "gptq", "quant_method": "gptq",
"weight_bits": weight_bits, "bits": weight_bits,
"group_size": group_size, "group_size": group_size,
"sym": sym, "sym": sym,
"lm_head_quantized": False, "lm_head": False,
} }
return MoeWNA16Config.from_config(config).get_quant_method( return MoeWNA16Config.from_config(config).get_quant_method(
layer, prefix) layer, prefix)
......
...@@ -286,9 +286,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -286,9 +286,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
rocm_aiter_fused_experts, shuffle_weights) rocm_aiter_fused_experts, shuffle_weights)
# reshaping weights is required for aiter moe kernel. # reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data, shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w2_weight.data, layer.w13_weight.data, layer.w2_weight.data)
layout=(16, 16))
layer.w13_weight = torch.nn.Parameter(shuffled_w13, layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False) requires_grad=False)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import re
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from types import MappingProxyType from types import MappingProxyType
from typing import Optional from typing import Optional
import regex as re
from compressed_tensors import CompressionFormat from compressed_tensors import CompressionFormat
from torch.nn import Module from torch.nn import Module
......
...@@ -10,7 +10,6 @@ from torch.nn import Module ...@@ -10,7 +10,6 @@ from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -63,10 +62,9 @@ class Fp8Config(QuantizationConfig): ...@@ -63,10 +62,9 @@ class Fp8Config(QuantizationConfig):
weight_block_size: Optional[list[int]] = None, weight_block_size: Optional[list[int]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
logger.warning("Detected fp8 checkpoint. Please note that the "
"format is experimental and subject to change.")
if activation_scheme not in ACTIVATION_SCHEMES: if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError( raise ValueError(
f"Unsupported activation scheme {activation_scheme}") f"Unsupported activation scheme {activation_scheme}")
...@@ -461,7 +459,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -461,7 +459,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger.warning_once( logger.warning_once(
"DeepGemm not supported on the current platform.") "DeepGemm not supported on the current platform.")
self.fused_experts = functools.partial( self.fused_experts = functools.partial( # type: ignore
fused_experts, fused_experts,
block_shape=self.quant_config.weight_block_size, block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm) allow_deep_gemm=self.allow_deep_gemm)
...@@ -597,7 +595,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -597,7 +595,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
# Lazy import to avoid importing triton too early. # Lazy import to avoid importing triton too early.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights) is_rocm_aiter_moe_enabled, shuffle_weights)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
...@@ -629,9 +627,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -629,9 +627,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel. # reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights( shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w13_weight.data, layer.w2_weight.data)
layer.w2_weight.data,
layout=(16, 16))
layer.w13_weight = torch.nn.Parameter(shuffled_w13, layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False) requires_grad=False)
...@@ -677,20 +673,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -677,20 +673,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad=False) requires_grad=False)
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel. # reshaping weights is required for aiter moe kernel.
w13_scales, w2_scales = expand_weights( shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight_scale.data, layer.w13_weight, layer.w2_weight)
layer.w2_weight_scale.data,
expansion_dims=[
layer.w13_weight.shape[1], layer.w2_weight.shape[1]
])
layer.w13_weight_scale = torch.nn.Parameter(
w13_scales.contiguous(), requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(
w2_scales.contiguous(), requires_grad=False)
shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
layer.w2_weight,
layout=(16, 16))
layer.w13_weight = torch.nn.Parameter(shuffled_w13, layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False) requires_grad=False)
...@@ -762,20 +746,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -762,20 +746,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
start += shard_size start += shard_size
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel. shuffled_w13, shuffled_w2 = shuffle_weights(
expansion_dims = [ layer.w13_weight, layer.w2_weight)
layer.w13_weight.shape[1], layer.w2_weight.shape[1]
]
max_w13_scales, w2_scales = expand_weights(
max_w13_scales,
layer.w2_weight_scale.data,
expansion_dims=expansion_dims)
layer.w2_weight_scale = torch.nn.Parameter(
w2_scales.contiguous(), requires_grad=False)
shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
layer.w2_weight,
layout=(32, 32))
layer.w13_weight = torch.nn.Parameter(shuffled_w13, layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False) requires_grad=False)
...@@ -791,17 +763,12 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -791,17 +763,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale del layer.w13_input_scale
del layer.w2_input_scale del layer.w2_input_scale
def set_prepare_finalize( def select_gemm_impl(self, prepare_finalize):
self,
dp_size: int,
world_size: int,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
) -> bool:
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts) TritonOrDeepGemmExperts)
if self.use_marlin or self.rocm_aiter_moe_enabled: assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
return False "Marlin and ROCm AITER are not supported with all2all yet.")
experts = TritonOrDeepGemmExperts( experts = TritonOrDeepGemmExperts(
use_fp8_w8a8=True, use_fp8_w8a8=True,
...@@ -809,12 +776,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -809,12 +776,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
allow_deep_gemm=self.allow_deep_gemm, allow_deep_gemm=self.allow_deep_gemm,
) )
self.fused_experts = mk.FusedMoEModularKernel( return experts
prepare_finalize,
experts,
)
return True
def apply( def apply(
self, self,
......
...@@ -9,7 +9,6 @@ from torch.nn.parameter import Parameter, UninitializedParameter ...@@ -9,7 +9,6 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase) FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
...@@ -19,6 +18,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -19,6 +18,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -96,8 +96,8 @@ MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES ...@@ -96,8 +96,8 @@ MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor, def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
qweight_type: int) -> torch.Tensor: qweight_type: int) -> torch.Tensor:
# HACK: when doing chunked prefill we don't generate output tokens # HACK: when doing chunked prefill we don't generate output tokens
# so input to logits generator is empty which causes invalid parameter # so input to logits generator is empty which causes invalid parameter
if x.shape[0] == 0: if x.shape[0] == 0:
...@@ -130,6 +130,30 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor, ...@@ -130,6 +130,30 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
return y return y
def _fused_mul_mat_gguf_fake(
x: torch.Tensor,
qweight: torch.Tensor,
qweight_type: int,
) -> torch.Tensor:
return torch.empty(x.shape[0],
qweight.shape[0],
dtype=x.dtype,
device=x.device)
# Hand ``_fused_mul_mat_gguf`` and its fake implementation to
# ``direct_register_custom_op``, then bind the op looked up under
# ``torch.ops.vllm`` to the module-level name ``fused_mul_mat_gguf``.
try:
    direct_register_custom_op(
        op_name="_fused_mul_mat_gguf",
        op_func=_fused_mul_mat_gguf,
        mutates_args=[],
        fake_impl=_fused_mul_mat_gguf_fake,
    )
    fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf

# NOTE(review): the handler re-raises the caught AttributeError unchanged,
# so this try/except does not alter which exception propagates.
except AttributeError as error:
    raise error
def _fused_moe_gguf( def _fused_moe_gguf(
x: torch.Tensor, x: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
...@@ -138,8 +162,21 @@ def _fused_moe_gguf( ...@@ -138,8 +162,21 @@ def _fused_moe_gguf(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
qweight_type: int, qweight_type: int,
qweight_type2: int, qweight_type2: int,
act, activation: str,
) -> torch.Tensor: ) -> torch.Tensor:
def act(x: torch.Tensor):
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if activation == "silu":
torch.ops._C.silu_and_mul(out, x)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(out, x)
else:
raise ValueError(f"Unsupported activation: {activation}")
return out
# lazy import to avoid triggering triton import in CPU backend # lazy import to avoid triggering triton import in CPU backend
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
moe_align_block_size) moe_align_block_size)
...@@ -189,12 +226,12 @@ def _fused_moe_gguf( ...@@ -189,12 +226,12 @@ def _fused_moe_gguf(
for ww, ii in zip(w, idx): for ww, ii in zip(w, idx):
expert_up = w1[ii] expert_up = w1[ii]
out = _fuse_mul_mat(inp, expert_up, qweight_type) out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
out = act(out) out = act(out)
expert_down = w2[ii] expert_down = w2[ii]
current_state = _fuse_mul_mat(out, expert_down, current_state = fused_mul_mat_gguf(out, expert_down,
qweight_type2).mul_(ww) qweight_type2).mul_(ww)
if current_hidden_state is None: if current_hidden_state is None:
current_hidden_state = current_state current_hidden_state = current_state
else: else:
...@@ -203,6 +240,78 @@ def _fused_moe_gguf( ...@@ -203,6 +240,78 @@ def _fused_moe_gguf(
return out_hidden_states return out_hidden_states
def _fused_moe_gguf_fake(
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
qweight_type: int,
qweight_type2: int,
activation: str,
) -> torch.Tensor:
return torch.empty_like(x)
# Hand ``_fused_moe_gguf`` and its fake implementation to
# ``direct_register_custom_op``, then bind the op looked up under
# ``torch.ops.vllm`` to the module-level name ``fused_moe_gguf``.
try:
    direct_register_custom_op(
        op_name="_fused_moe_gguf",
        op_func=_fused_moe_gguf,
        mutates_args=[],
        fake_impl=_fused_moe_gguf_fake,
    )
    fused_moe_gguf = torch.ops.vllm._fused_moe_gguf

# NOTE(review): the handler re-raises the caught AttributeError unchanged,
# so this try/except does not alter which exception propagates.
except AttributeError as error:
    raise error
def _apply_gguf_embedding(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
    hidden_size: int,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Look up embedding rows for the indices in ``x`` from a GGUF weight.

    Weights whose type is in ``UNQUANTIZED_TYPES`` go straight through
    ``torch.embedding``. For types in ``DEQUANT_TYPES`` the indexed rows are
    gathered first and only those rows are passed to ``ops.ggml_dequantize``;
    the result is viewed as ``(*x.shape, hidden_size)``.

    Raises:
        NotImplementedError: if ``qweight_type`` is in neither set.
    """
    if qweight_type in UNQUANTIZED_TYPES:
        return torch.embedding(qweight, x)

    if qweight_type not in DEQUANT_TYPES:
        qweight_type = WeightType(qweight_type)
        raise NotImplementedError(
            f"Unsupported GGUF quantization type: {qweight_type}")

    block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
    token_ids = x.flatten()
    # The stored row width, converted through the type's
    # (block_size, type_size) pair, has to equal hidden_size.
    assert (hidden_size == qweight.shape[1] // type_size * block_size)
    packed_rows = torch.index_select(qweight, dim=0, index=token_ids)
    unpacked = ops.ggml_dequantize(packed_rows, qweight_type, hidden_size,
                                   token_ids.shape[0], dtype)
    return unpacked.view(*x.shape, hidden_size)
def _apply_gguf_embedding_fake(
x: torch.Tensor,
qweight: torch.Tensor,
qweight_type: int,
hidden_size: int,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)
# Hand ``_apply_gguf_embedding`` and its fake implementation to
# ``direct_register_custom_op``, then bind the op looked up under
# ``torch.ops.vllm`` to the module-level name ``apply_gguf_embedding``.
try:
    direct_register_custom_op(
        op_name="_apply_gguf_embedding",
        op_func=_apply_gguf_embedding,
        mutates_args=[],
        fake_impl=_apply_gguf_embedding_fake,
    )
    apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding

# NOTE(review): the handler re-raises the caught AttributeError unchanged,
# so this try/except does not alter which exception propagates.
except AttributeError as error:
    raise error
class GGUFLinearMethod(LinearMethodBase): class GGUFLinearMethod(LinearMethodBase):
"""Linear method for GGUF. """Linear method for GGUF.
...@@ -249,26 +358,76 @@ class GGUFLinearMethod(LinearMethodBase): ...@@ -249,26 +358,76 @@ class GGUFLinearMethod(LinearMethodBase):
set_weight_attrs(qweight_type, extra_weight_attrs) set_weight_attrs(qweight_type, extra_weight_attrs)
layer.register_parameter("qweight_type", qweight_type) layer.register_parameter("qweight_type", qweight_type)
def process_weights_after_loading(self, layer: torch.nn.Module):
    """Validate the layer's GGUF weight type, then pad its merged weight.

    Raises:
        ValueError: if the layer's weight type is in neither
            ``UNQUANTIZED_TYPES`` nor ``DEQUANT_TYPES``.
    """
    qweight_type = layer.qweight_type.weight_type
    if (qweight_type not in UNQUANTIZED_TYPES
            and qweight_type not in DEQUANT_TYPES):
        qweight_type = WeightType(qweight_type)
        raise ValueError(
            f"Unsupported GGUF quantization type {qweight_type} in "
            f"layer {layer}.")

    # MergedColumnParallelLinear and QKVParallelLinear need the padded weight
    # parameter materialized for CUDA Graph compatibility.
    self._create_padded_weight_param(layer)
def _create_padded_weight_param(self, layer: torch.nn.Module):
    """Create padded weight parameter for GGUF MergedLinear layer.

    When ``layer.qweight.data_container`` holds more than one shard, the
    shards are stacked along dim 0 into one zero-initialized dense tensor
    whose dim 1 is as wide as the widest shard. The tensor is registered as
    the layer's new ``qweight`` parameter, carrying the old parameter's
    attributes plus a ``shard_offset_map`` that records, per shard id,
    ``(dim0_start, dim0_end, dim1_size)``. With zero or one shard the layer
    is left untouched.

    Raises:
        ValueError: if the shards do not all share a single dtype.
    """
    qweight = layer.qweight
    shard_id_map = qweight.shard_id_map
    shard_id = qweight.shard_id
    data_container = qweight.data_container
    if len(data_container) <= 1:
        return

    dtypes = {data.dtype for data in data_container}
    # Raise explicitly: an ``assert`` would be stripped under ``python -O``
    # and would surface as AssertionError rather than ValueError.
    if len(dtypes) != 1:
        raise ValueError(f"Data container has mixed dtypes: {dtypes}")
    dtype = next(iter(dtypes))

    # concat dim0 and pad dim1; the row offset of every shard is computed
    # once here instead of re-summing a prefix for each shard id.
    row_starts: list[int] = []
    concat_side = 0
    for shard in data_container:
        row_starts.append(concat_side)
        concat_side += shard.size(0)
    padded_side = max(shard.size(1) for shard in data_container)

    # Pad the quantized weights to dense tensor, and create a map
    # with the location of each shard in the padded tensor.
    padded_data = torch.zeros((concat_side, padded_side),
                              dtype=dtype,
                              device=qweight.device)
    # (dim0_start, dim0_end, dim1_size)
    shard_offset_map: dict[str, tuple[int, int, int]] = {}
    for idx in shard_id:
        id_in_container = shard_id_map[idx]
        shard = data_container[id_in_container]
        start = row_starts[id_in_container]
        end = start + shard.size(0)
        size = shard.size(1)
        padded_data[start:end, :size] = shard
        shard_offset_map[idx] = (start, end, size)

    # Empty the container before copying attributes so the new parameter
    # receives the cleared list, as before.
    qweight.data_container.clear()
    padded_param = Parameter(padded_data, requires_grad=False)
    set_weight_attrs(padded_param, vars(qweight))
    set_weight_attrs(padded_param, {"shard_offset_map": shard_offset_map})
    layer.register_parameter("qweight", padded_param)
def apply(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
shard_id = getattr(layer.qweight, "shard_id", None) shard_id = layer.qweight.shard_id
if shard_id: if shard_id:
# dequantize shard weights respectively # dequantize shard weights respectively
shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
qweight = layer.qweight.unbind(0) qweight = layer.qweight
result = [] result = []
for idx in shard_id: for idx in shard_id:
q_idx = layer.qweight.shard_id_map[idx] start, end, offset = layer.qweight.shard_offset_map[idx]
qweight_type = layer.qweight_type.shard_weight_type[idx] qweight_type = layer.qweight_type.shard_weight_type[idx]
result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type)) result.append(
fused_mul_mat_gguf(
x, qweight[start:end, :offset].contiguous(),
qweight_type))
out = torch.cat(result, axis=1) out = torch.cat(result, axis=1)
else: else:
qweight = layer.qweight qweight = layer.qweight
qweight_type = layer.qweight_type.weight_type qweight_type = layer.qweight_type.weight_type
out = _fuse_mul_mat(x, qweight, qweight_type) out = fused_mul_mat_gguf(x, qweight, qweight_type)
if bias is not None: if bias is not None:
out.add_(bias) out.add_(bias)
return out return out
...@@ -338,7 +497,6 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -338,7 +497,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_qweight_type, extra_weight_attrs) set_weight_attrs(w2_qweight_type, extra_weight_attrs)
layer.register_parameter("w2_qweight_type", w2_qweight_type) layer.register_parameter("w2_qweight_type", w2_qweight_type)
self.act = SiluAndMul()
def apply( def apply(
self, self,
...@@ -375,10 +533,10 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -375,10 +533,10 @@ class GGUFMoEMethod(FusedMoEMethodBase):
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias)
return _fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight, return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
topk_weights, topk_ids, topk_weights, topk_ids,
layer.w13_qweight_type.weight_type, layer.w13_qweight_type.weight_type,
layer.w2_qweight_type.weight_type, self.act) layer.w2_qweight_type.weight_type, activation)
class GGUFEmbeddingMethod(GGUFLinearMethod): class GGUFEmbeddingMethod(GGUFLinearMethod):
...@@ -392,34 +550,15 @@ class GGUFEmbeddingMethod(GGUFLinearMethod): ...@@ -392,34 +550,15 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
x: torch.Tensor) -> torch.Tensor: x: torch.Tensor) -> torch.Tensor:
qweight = layer.qweight qweight = layer.qweight
qweight_type = layer.qweight_type.weight_type qweight_type = layer.qweight_type.weight_type
hidden_size = qweight.tensor_shape[1]
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] return apply_gguf_embedding(x,
hidden_size = qweight.shape[1] // type_size * block_size qweight,
if qweight_type < 2: qweight_type,
return torch.embedding(qweight, x) hidden_size,
x_flat = x.flatten() dtype=self.params_dtype)
quant = torch.index_select(qweight, dim=0, index=x_flat)
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
x_flat.shape[0], self.params_dtype)
return dequant.view(*x.shape, hidden_size)
class GGUFUninitializedParameter(UninitializedParameter): class GGUFUninitializedParameter(UninitializedParameter):
cls_to_become = Parameter cls_to_become = Parameter
data_container: list[torch.Tensor] data_container: list[torch.Tensor]
def materialize_nested(self) -> Parameter:
    """Build a ``cls_to_become`` instance around a nested tensor.

    The tensors in ``self.data_container`` are combined with
    ``torch.nested.nested_tensor``, the container is emptied, and every
    entry of ``self.__dict__`` is copied onto the returned object.
    """
    dtype = {data.dtype for data in self.data_container}
    # NOTE(review): on failure this raises AssertionError with a ValueError
    # instance as its message, and the check is skipped under ``python -O``.
    assert len(dtype) == 1, ValueError(
        f"Data container has mixed dtypes: {dtype}")
    dtype = next(iter(dtype))
    nested_data = torch.nested.nested_tensor(self.data_container,
                                             device=self.device,
                                             dtype=dtype)
    # The container list is emptied once the nested tensor has been built.
    self.data_container.clear()
    param = torch.Tensor._make_subclass(self.cls_to_become,
                                        nested_data,
                                        require_grad=False)
    # Carry the instance attributes over to the new object.
    for k, v in self.__dict__.items():
        setattr(param, k, v)
    return param
...@@ -14,7 +14,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -14,7 +14,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.platforms import current_platform from vllm.platforms import current_platform
MIN_IPEX_VERSION = "2.5.0" MIN_IPEX_VERSION = "2.7.0"
class IPEXConfig(QuantizationConfig): class IPEXConfig(QuantizationConfig):
...@@ -181,8 +181,6 @@ class IPEXGPTQLinearMethod(GPTQLinearMethod): ...@@ -181,8 +181,6 @@ class IPEXGPTQLinearMethod(GPTQLinearMethod):
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1]) reshaped_x = x.reshape(-1, x.shape[-1])
out = layer.ipex_qlinear(reshaped_x) out = layer.ipex_qlinear(reshaped_x)
if bias is not None:
out.add_(bias)
return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
......
...@@ -192,7 +192,7 @@ class ModelOptNvFp4Config(QuantizationConfig): ...@@ -192,7 +192,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
@classmethod @classmethod
def get_name(cls) -> QuantizationMethods: def get_name(cls) -> QuantizationMethods:
return "nvfp4" return "modelopt_fp4"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
...@@ -228,7 +228,7 @@ class ModelOptNvFp4Config(QuantizationConfig): ...@@ -228,7 +228,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
exclude_modules, group_size) exclude_modules, group_size)
def is_layer_excluded(self, prefix: str, exclude_modules: list): def is_layer_excluded(self, prefix: str, exclude_modules: list):
import re import regex as re
for pattern in exclude_modules: for pattern in exclude_modules:
regex_str = pattern.replace('.', r'\.').replace('*', r'.*') regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
if re.fullmatch(regex_str, prefix): if re.fullmatch(regex_str, prefix):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import re
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Optional from typing import Any, Optional
import regex as re
def deep_compare(dict1: Any, dict2: Any) -> bool: def deep_compare(dict1: Any, dict2: Any) -> bool:
if type(dict1) is not type(dict2): if type(dict1) is not type(dict2):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import re
from copy import deepcopy from copy import deepcopy
from typing import Optional, Union from typing import Optional, Union
import regex as re
import torch import torch
from vllm.config import QuantizationConfig from vllm.config import QuantizationConfig
......
...@@ -262,16 +262,16 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): ...@@ -262,16 +262,16 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
True, then a token can be accepted, else it should be True, then a token can be accepted, else it should be
rejected. rejected.
Given {math}`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of Given $q(\hat{x}_{n+1}|x_1, \dots, x_n)$, the probability of
{math}`\hat{x}_{n+1}` given context {math}`x_1, \dots, x_n` according $\hat{x}_{n+1}$ given context $x_1, \dots, x_n$ according
to the target model, and {math}`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the to the target model, and $p(\hat{x}_{n+1}|x_1, \dots, x_n)$, the
same conditional probability according to the draft model, the token same conditional probability according to the draft model, the token
is accepted with probability: is accepted with probability:
:::{math} $$
\min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)} \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
{p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right) {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
::: $$
This implementation does not apply causality. When using the output, This implementation does not apply causality. When using the output,
if a token is rejected, subsequent tokens should not be used. if a token is rejected, subsequent tokens should not be used.
...@@ -314,30 +314,31 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): ...@@ -314,30 +314,31 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target model is recovered (within hardware numerics). target model is recovered (within hardware numerics).
The probability distribution used in this rejection case is constructed The probability distribution used in this rejection case is constructed
as follows. Given {math}`q(x|x_1, \dots, x_n)`, the probability of as follows. Given $q(x|x_1, \dots, x_n)$, the probability of
{math}`x` given context {math}`x_1, \dots, x_n` according to the target $x$ given context $x_1, \dots, x_n$ according to the target
model and {math}`p(x|x_1, \dots, x_n)`, the same conditional probability model and $p(x|x_1, \dots, x_n)$, the same conditional probability
according to the draft model: according to the draft model:
:::{math} $$
x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+ x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
::: $$
where {math}`(f(x))_+` is defined as: where $(f(x))_+$ is defined as:
:::{math} $$
(f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))} (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
::: $$
See https://github.com/vllm-project/vllm/pull/2336 for a visualization See https://github.com/vllm-project/vllm/pull/2336 for a visualization
of the draft, target, and recovered probability distributions. of the draft, target, and recovered probability distributions.
Returns a tensor of shape [batch_size, k, vocab_size]. Returns a tensor of shape [batch_size, k, vocab_size].
Note: This batches operations on GPU and thus constructs the recovered Note:
distribution for all tokens, even if they are accepted. This causes This batches operations on GPU and thus constructs the recovered
division-by-zero errors, so we use self._smallest_positive_value to distribution for all tokens, even if they are accepted. This causes
avoid that. This introduces some drift to the distribution. division-by-zero errors, so we use self._smallest_positive_value to
avoid that. This introduces some drift to the distribution.
""" """
_, k, _ = draft_probs.shape _, k, _ = draft_probs.shape
......
...@@ -228,17 +228,19 @@ class Sampler(nn.Module): ...@@ -228,17 +228,19 @@ class Sampler(nn.Module):
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
""" """
Single-step scheduling: Single-step scheduling:
* Perform GPU-side sampling computation & compute * Perform GPU-side sampling computation & compute
GPU-side logprobs tensor GPU-side logprobs tensor
* Pythonize sampling result & logprobs tensor * Pythonize sampling result & logprobs tensor
Multi-step scheduling: Multi-step scheduling:
* Perform GPU-side sampling computation & compute * Perform GPU-side sampling computation & compute
GPU-side logprobs tensor GPU-side logprobs tensor
* Defer Pythonization of sampling result & logprobs * Defer Pythonization of sampling result & logprobs
tensor tensor
* Encapsulate arguments required for deferred Pythonization * Encapsulate arguments required for deferred Pythonization
in the {class}`SamplerOutput` structure in the
[`SamplerOutput`][vllm.model_executor.layers.sampler.SamplerOutput]
structure
Args: Args:
logits: (num_tokens, vocab_size). logits: (num_tokens, vocab_size).
......
...@@ -93,29 +93,27 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -93,29 +93,27 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
Evaluates and returns a mask of accepted tokens based on the Evaluates and returns a mask of accepted tokens based on the
posterior probabilities. posterior probabilities.
Parameters: Args:
---------- target_probs (torch.Tensor): A tensor of shape
target_probs : torch.Tensor (batch_size, k, vocab_size) representing the probabilities of
A tensor of shape (batch_size, k, vocab_size) representing each token in the vocabulary for each position in the proposed
the probabilities of each token in the vocabulary for each sequence. This is the distribution generated by the target
position in the proposed sequence. This is the distribution model.
generated by the target model. draft_token_ids (torch.Tensor): A tensor of shape (batch_size, k)
draft_token_ids : torch.Tensor representing the proposed token ids.
A tensor of shape (batch_size, k) representing the proposed
token ids.
A draft token_id x_{n+k} is accepted if it satisfies the A draft token_id x_{n+k} is accepted if it satisfies the
following condition following condition
:::{math} $$
p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) > p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
\min \left( \epsilon, \delta * \exp \left( \min \left( \epsilon, \delta * \exp \left(
-H(p_{\text{original}}( -H(p_{\text{original}}(
\cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)
::: $$
where {math}`p_{\text{original}}` corresponds to target_probs where $p_{\text{original}}$ corresponds to target_probs
and {math}`\epsilon` and {math}`\delta` correspond to hyperparameters and $\epsilon$ and $\delta$ correspond to hyperparameters
specified using self._posterior_threshold and self._posterior_alpha specified using self._posterior_threshold and self._posterior_alpha
This method computes the posterior probabilities for the given This method computes the posterior probabilities for the given
...@@ -126,13 +124,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -126,13 +124,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
returns a boolean mask indicating which tokens can be accepted. returns a boolean mask indicating which tokens can be accepted.
Returns: Returns:
------- torch.Tensor: A boolean tensor of shape (batch_size, k) where each
torch.Tensor element indicates whether the corresponding draft token has
A boolean tensor of shape (batch_size, k) where each element been accepted or rejected. True indicates acceptance and false
indicates whether the corresponding draft token has been accepted indicates rejection.
or rejected. True indicates acceptance and false indicates
rejection.
""" """
device = target_probs.device device = target_probs.device
candidates_prob = torch.gather( candidates_prob = torch.gather(
...@@ -156,17 +151,14 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -156,17 +151,14 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
The recovered token ids will fill the first unmatched token The recovered token ids will fill the first unmatched token
by the target token. by the target token.
Parameters Args:
---------- target_probs (torch.Tensor): A tensor of shape
target_probs : torch.Tensor (batch_size, k, vocab_size) containing the target probability
A tensor of shape (batch_size, k, vocab_size) containing distribution.
the target probability distribution
Returns:
Returns torch.Tensor: A tensor of shape (batch_size, k) with the recovered
------- token ids which are selected from target probs.
torch.Tensor
A tensor of shape (batch_size, k) with the recovered token
ids which are selected from target probs.
""" """
max_indices = torch.argmax(target_probs, dim=-1) max_indices = torch.argmax(target_probs, dim=-1)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional
from torch import nn from torch import nn
from vllm.config import LoadConfig, LoadFormat, VllmConfig from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.bitsandbytes_loader import ( from vllm.model_executor.model_loader.bitsandbytes_loader import (
BitsAndBytesModelLoader) BitsAndBytesModelLoader)
...@@ -47,9 +49,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: ...@@ -47,9 +49,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
return DefaultModelLoader(load_config) return DefaultModelLoader(load_config)
def get_model(*, vllm_config: VllmConfig) -> nn.Module: def get_model(*,
vllm_config: VllmConfig,
model_config: Optional[ModelConfig] = None) -> nn.Module:
loader = get_model_loader(vllm_config.load_config) loader = get_model_loader(vllm_config.load_config)
return loader.load_model(vllm_config=vllm_config) if model_config is None:
model_config = vllm_config.model_config
return loader.load_model(vllm_config=vllm_config,
model_config=model_config)
__all__ = [ __all__ = [
......
...@@ -18,6 +18,7 @@ class BaseModelLoader(ABC): ...@@ -18,6 +18,7 @@ class BaseModelLoader(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def load_model(self, *, vllm_config: VllmConfig) -> nn.Module: def load_model(self, *, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
"""Load a model with the given configurations.""" """Load a model with the given configurations."""
raise NotImplementedError raise NotImplementedError
...@@ -569,10 +569,9 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -569,10 +569,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision) self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, vllm_config: VllmConfig) -> nn.Module: def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
device_config = vllm_config.device_config device_config = vllm_config.device_config
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
......
...@@ -11,8 +11,8 @@ import torch ...@@ -11,8 +11,8 @@ import torch
from torch import nn from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm import envs
from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import ( from vllm.model_executor.model_loader.utils import (
...@@ -64,7 +64,7 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -64,7 +64,7 @@ class DefaultModelLoader(BaseModelLoader):
Returns the path to the downloaded model, or None if the model is not Returns the path to the downloaded model, or None if the model is not
downloaded from ModelScope.""" downloaded from ModelScope."""
if VLLM_USE_MODELSCOPE: if envs.VLLM_USE_MODELSCOPE:
# download model from ModelScope hub, # download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use. # lazy import so that modelscope is not required for normal use.
# pylint: disable=C. # pylint: disable=C.
...@@ -264,13 +264,14 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -264,13 +264,14 @@ class DefaultModelLoader(BaseModelLoader):
fall_back_to_pt=True, fall_back_to_pt=True,
allow_patterns_overrides=None) allow_patterns_overrides=None)
def load_model(self, vllm_config: VllmConfig) -> nn.Module: def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
device_config = vllm_config.device_config device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device) target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with target_device: with target_device:
model = initialize_model(vllm_config=vllm_config) model = initialize_model(vllm_config=vllm_config,
model_config=model_config)
weights_to_load = {name for name, _ in model.named_parameters()} weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights( loaded_weights = model.load_weights(
......
...@@ -22,9 +22,9 @@ class DummyModelLoader(BaseModelLoader): ...@@ -22,9 +22,9 @@ class DummyModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download pass # Nothing to download
def load_model(self, vllm_config: VllmConfig) -> nn.Module: def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
device_config = vllm_config.device_config device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device) target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with target_device: with target_device:
......
...@@ -92,9 +92,9 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -92,9 +92,9 @@ class GGUFModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model) self._prepare_weights(model_config.model)
def load_model(self, vllm_config: VllmConfig) -> nn.Module: def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
device_config = vllm_config.device_config device_config = vllm_config.device_config
model_config = vllm_config.model_config
local_model_path = self._prepare_weights(model_config.model) local_model_path = self._prepare_weights(model_config.model)
gguf_weights_map = self._get_gguf_weights_map(model_config) gguf_weights_map = self._get_gguf_weights_map(model_config)
# we can only know if tie word embeddings after mapping weights # we can only know if tie word embeddings after mapping weights
......
...@@ -87,16 +87,29 @@ class NeuronCausalLM(nn.Module): ...@@ -87,16 +87,29 @@ class NeuronCausalLM(nn.Module):
input_block_ids: torch.Tensor, input_block_ids: torch.Tensor,
sampling_params: torch.Tensor, sampling_params: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# sort block ids sequentially for perf/neuron support reasons
sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
input_ids = torch.index_select(input_ids, 0, sorted_indices)
positions = torch.index_select(positions, 0, sorted_indices)
sampling_params = torch.index_select(sampling_params, 0,
sorted_indices)
output = self.model(input_ids, output = self.model(input_ids,
attention_mask=None, attention_mask=None,
position_ids=positions, position_ids=positions,
seq_ids=input_block_ids, seq_ids=sorted_input_block_ids,
sampling_params=sampling_params) sampling_params=sampling_params)
# on-device sampling # on-device sampling
if self.config.neuron_config.on_device_sampling_config: if self.config.neuron_config.on_device_sampling_config:
return output.hidden_states output = output.hidden_states
else: else:
return output.logits[:, -1, :] output = output.logits[:, -1, :]
restored_indices = torch.argsort(sorted_indices)
if input_block_ids.shape[0] != 1:
output = torch.index_select(output, 0, restored_indices)
return output
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
...@@ -340,14 +353,26 @@ class NeuronSpeculationCausalLM(nn.Module): ...@@ -340,14 +353,26 @@ class NeuronSpeculationCausalLM(nn.Module):
input_block_ids: torch.Tensor, input_block_ids: torch.Tensor,
sampling_params: torch.Tensor, sampling_params: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# sort block ids sequentially for perf/neuron support reasons
sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
input_ids = torch.index_select(input_ids, 0, sorted_indices)
positions = torch.index_select(positions, 0, sorted_indices)
sampling_params = torch.index_select(sampling_params, 0,
sorted_indices)
output = self.model(input_ids, output = self.model(input_ids,
attention_mask=None, attention_mask=None,
position_ids=positions, position_ids=positions,
seq_ids=input_block_ids, seq_ids=sorted_input_block_ids,
sampling_params=sampling_params) sampling_params=sampling_params)
restored_indices = torch.argsort(sorted_indices)
# CTX encoding # CTX encoding
if (positions[:, 0]).sum().item() == 0: if (positions[:, 0]).sum().item() == 0:
return output.fused_outputs[0][:, 0:1] output = output.fused_outputs[0][:, 0:1]
if input_block_ids.shape[0] != 1:
output = torch.index_select(output, 0, restored_indices)
return output
# Fused Spec (Generation) # Fused Spec (Generation)
accepted_tokens_with_padding = output.fused_outputs[0] accepted_tokens_with_padding = output.fused_outputs[0]
...@@ -362,6 +387,10 @@ class NeuronSpeculationCausalLM(nn.Module): ...@@ -362,6 +387,10 @@ class NeuronSpeculationCausalLM(nn.Module):
-1) >= generated_token_counts -1) >= generated_token_counts
accepted_tokens_with_padding[mask] = -1 accepted_tokens_with_padding[mask] = -1
if input_block_ids.shape[0] != 1:
accepted_tokens_with_padding = torch.index_select(
accepted_tokens_with_padding, 0, restored_indices)
return accepted_tokens_with_padding return accepted_tokens_with_padding
def sample( def sample(
...@@ -416,6 +445,10 @@ class NeuronSpeculationCausalLM(nn.Module): ...@@ -416,6 +445,10 @@ class NeuronSpeculationCausalLM(nn.Module):
draft_neuron_config.speculation_length = 0 draft_neuron_config.speculation_length = 0
draft_neuron_config.trace_tokengen_model = True draft_neuron_config.trace_tokengen_model = True
draft_neuron_config.enable_fused_speculation = False draft_neuron_config.enable_fused_speculation = False
if getattr(config.neuron_config, "draft_model_modules_to_not_convert",
None):
draft_neuron_config.modules_to_not_convert = (
draft_neuron_config.draft_model_modules_to_not_convert)
if config.neuron_config.enable_eagle_speculation: if config.neuron_config.enable_eagle_speculation:
draft_neuron_config.is_eagle_draft = True draft_neuron_config.is_eagle_draft = True
draft_neuron_config.sequence_parallel_enabled = False draft_neuron_config.sequence_parallel_enabled = False
...@@ -502,7 +535,7 @@ def _get_default_neuron_config(model_config: ModelConfig, ...@@ -502,7 +535,7 @@ def _get_default_neuron_config(model_config: ModelConfig,
max_context_length=scheduler_config.max_model_len, max_context_length=scheduler_config.max_model_len,
seq_len=scheduler_config.max_model_len, seq_len=scheduler_config.max_model_len,
enable_bucketing=True, enable_bucketing=True,
is_continuous_batching=(batch_size > 1), is_continuous_batching=True,
quantized=False, quantized=False,
torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
padding_side="right", padding_side="right",
...@@ -520,6 +553,7 @@ def _get_default_speculation_config(model_config: ModelConfig, ...@@ -520,6 +553,7 @@ def _get_default_speculation_config(model_config: ModelConfig,
args.""" args."""
neuron_config = dict( neuron_config = dict(
tp_degree=parallel_config.tensor_parallel_size, tp_degree=parallel_config.tensor_parallel_size,
ctx_batch_size=1,
batch_size=scheduler_config.max_num_seqs, batch_size=scheduler_config.max_num_seqs,
max_context_length=scheduler_config.max_model_len, max_context_length=scheduler_config.max_model_len,
seq_len=scheduler_config.max_model_len, seq_len=scheduler_config.max_model_len,
...@@ -527,6 +561,7 @@ def _get_default_speculation_config(model_config: ModelConfig, ...@@ -527,6 +561,7 @@ def _get_default_speculation_config(model_config: ModelConfig,
trace_tokengen_model=False, trace_tokengen_model=False,
enable_fused_speculation=True, enable_fused_speculation=True,
enable_bucketing=True, enable_bucketing=True,
is_continuous_batching=True,
quantized=False, quantized=False,
torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
on_device_sampling_config=dict( on_device_sampling_config=dict(
......
...@@ -100,11 +100,10 @@ class RunaiModelStreamerLoader(BaseModelLoader): ...@@ -100,11 +100,10 @@ class RunaiModelStreamerLoader(BaseModelLoader):
"""Download model if necessary""" """Download model if necessary"""
self._prepare_weights(model_config.model, model_config.revision) self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, vllm_config: VllmConfig) -> nn.Module: def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
"""Perform streaming of the model to destination""" """Perform streaming of the model to destination"""
device_config = vllm_config.device_config device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device) target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with target_device: with target_device:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment