Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fcfc474d
Commit
fcfc474d
authored
Apr 09, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.3' into v0.8.3-dev
parents
bb94d2e5
296c6572
Changes
503
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
980 additions
and
211 deletions
+980
-211
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+5
-0
vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
...xecutor/layers/quantization/kernels/scaled_mm/__init__.py
+3
-1
vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
...l_executor/layers/quantization/kernels/scaled_mm/aiter.py
+119
-0
vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
...del_executor/layers/quantization/kernels/scaled_mm/xla.py
+2
-1
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+19
-17
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+17
-13
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
...cutor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
+2
-0
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+28
-9
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+68
-0
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+1
-1
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+2
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+7
-11
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+5
-4
vllm/model_executor/models/aya_vision.py
vllm/model_executor/models/aya_vision.py
+527
-0
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+104
-95
vllm/model_executor/models/bamba.py
vllm/model_executor/models/bamba.py
+55
-50
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+6
-4
vllm/model_executor/models/blip.py
vllm/model_executor/models/blip.py
+4
-1
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+4
-2
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+2
-2
No files found.
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
fcfc474d
...
@@ -592,9 +592,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -592,9 +592,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
if
apply_router_weight_on_input
is
not
None
:
raise
NotImplementedError
(
"Apply router weight on input is not supported for"
"fused Marlin MoE method."
)
# The input must currently be float16
# The input must currently be float16
orig_dtype
=
x
.
dtype
orig_dtype
=
x
.
dtype
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
View file @
fcfc474d
...
@@ -3,6 +3,8 @@
...
@@ -3,6 +3,8 @@
import
os
import
os
from
typing
import
Dict
,
List
,
Optional
,
Type
from
typing
import
Dict
,
List
,
Optional
,
Type
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter
import
(
AiterScaledMMLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass
import
(
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass
import
(
CutlassScaledMMLinearKernel
)
CutlassScaledMMLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel
import
(
# noqa: E501
...
@@ -17,7 +19,7 @@ from vllm.platforms import PlatformEnum, current_platform
...
@@ -17,7 +19,7 @@ from vllm.platforms import PlatformEnum, current_platform
_POSSIBLE_KERNELS
:
Dict
[
PlatformEnum
,
List
[
Type
[
ScaledMMLinearKernel
]]]
=
{
_POSSIBLE_KERNELS
:
Dict
[
PlatformEnum
,
List
[
Type
[
ScaledMMLinearKernel
]]]
=
{
PlatformEnum
.
CPU
:
[
CutlassScaledMMLinearKernel
],
PlatformEnum
.
CPU
:
[
CutlassScaledMMLinearKernel
],
PlatformEnum
.
CUDA
:
[
CutlassScaledMMLinearKernel
],
PlatformEnum
.
CUDA
:
[
CutlassScaledMMLinearKernel
],
PlatformEnum
.
ROCM
:
[
TritonScaledMMLinearKernel
],
PlatformEnum
.
ROCM
:
[
AiterScaledMMLinearKernel
,
TritonScaledMMLinearKernel
],
PlatformEnum
.
TPU
:
[
XLAScaledMMLinearKernel
],
PlatformEnum
.
TPU
:
[
XLAScaledMMLinearKernel
],
}
}
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
,
Tuple
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
.cutlass
import
CutlassScaledMMLinearKernel
from
.ScaledMMLinearKernel
import
ScaledMMLinearLayerConfig
class
AiterScaledMMLinearKernel
(
CutlassScaledMMLinearKernel
):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
90
@
classmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayerConfig
)
->
Tuple
[
bool
,
Optional
[
str
]]:
if
not
current_platform
.
is_rocm
():
return
(
False
,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+
"currently supported on non-ROCm platform."
)
try
:
import
aiter
# noqa: F401 # deliberately attempt to import aiter
except
Exception
:
return
(
False
,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+
"installed on ROCm."
)
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
if
not
(
envs
.
VLLM_ROCM_USE_AITER_LINEAR
\
and
envs
.
VLLM_ROCM_USE_AITER
):
return
(
False
,
"AiterScaledMMLinearKernel is disabled. "
+
"Enable by setting `VLLM_ROCM_USE_AITER=1` "
+
"and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
+
"`VLLM_ROCM_USE_AITER_LINEAR` default is True."
)
if
not
c
.
input_symmetric
:
return
(
False
,
"AiterScaledMMLinearKernel only supports symmetric "
+
"quantization."
)
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
super
().
process_weights_after_loading
(
layer
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""
`AiterScaledMMLinearKernel` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
Currently only support per-tensor-per-tensor GEMM
and per-token-per-channel GEMM through AITER
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
ATIER block scaled GEMM and mix-precision GEMM.
"""
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
_get_weight_params
(
layer
)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
symmetric
=
azp_adj
is
None
assert
symmetric
,
(
"AiterScaledMMLinearKernel only supports"
" symmetric quantization."
)
x_q
,
x_s
,
x_zp
=
ops
.
scaled_int8_quant
(
x
,
i_s
,
i_zp
,
symmetric
=
symmetric
)
assert
x_zp
is
None
,
(
"AiterScaledMMLinearKernel only supports"
" symmetric quantization."
)
out_dtype
=
x
.
dtype
assert
(
w_q
.
shape
[
0
]
%
16
==
0
and
w_q
.
shape
[
1
]
%
16
==
0
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
assert
bias
is
None
or
bias
.
shape
[
0
]
==
w_q
.
shape
[
1
]
and
bias
.
dtype
==
out_dtype
m
=
x_q
.
shape
[
0
]
# a
n
=
w_q
.
shape
[
1
]
# b
per_tensor_scale_a
=
(
x_s
.
numel
()
==
1
)
per_tensor_scale_b
=
(
w_s
.
numel
()
==
1
)
per_token_scale_a
=
(
x_s
.
numel
()
==
m
)
per_channel_scale_b
=
(
w_s
.
numel
()
==
n
)
# @TODO:
# Maybe broadcast the per-tensor-scale into per-channel-scale
# if one of the scale is a per-channel-scale.
# For now, it only supports:
# - per-tensor-per-tensor a8w8 scaled GEMM, and
# - per-token-per-channel a8w8 scaled GEMM
assert
((
per_tensor_scale_a
and
per_tensor_scale_b
)
or
(
per_token_scale_a
and
per_channel_scale_b
)),
(
"Currently only support per-tensor-per-tensor GEMM "
+
" and per-token-per-channel GEMM through AITER"
" w8a8 scaled gemm. `AiterScaledMMLinearKernel` "
+
"does not support AITER block scaled GEMM."
)
from
aiter
import
gemm_a8w8_CK
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return
gemm_a8w8_CK
(
x_q
,
w_q
.
t
(),
x_s
,
w_s
,
bias
).
to
(
out_dtype
)
vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
View file @
fcfc474d
...
@@ -97,7 +97,8 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
...
@@ -97,7 +97,8 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
block_size
=-
1
,
block_size
=-
1
,
int4_weight
=
False
,
int4_weight
=
False
,
quantize_activation
=
True
)
quantize_activation
=
True
)
# `quantized_matmul` output is fp32, cast it down to bf16 for perf
out
=
out
.
to
(
x
.
dtype
)
# Explicitly capture control flow to make dynamo happy.
# Explicitly capture control flow to make dynamo happy.
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
return
cond
(
bias
is
None
,
self
.
no_add_bias
,
self
.
add_bias
,
[
out
,
bias
])
return
cond
(
bias
is
None
,
self
.
no_add_bias
,
self
.
add_bias
,
[
out
,
bias
])
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
fcfc474d
...
@@ -293,6 +293,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
...
@@ -293,6 +293,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
moe_ep_size
:
Optional
[
int
]
=
None
,
moe_ep_size
:
Optional
[
int
]
=
None
,
...
@@ -316,23 +317,24 @@ class MoeWNA16Method(FusedMoEMethodBase):
...
@@ -316,23 +317,24 @@ class MoeWNA16Method(FusedMoEMethodBase):
weight_bits
=
self
.
quant_config
.
weight_bits
weight_bits
=
self
.
quant_config
.
weight_bits
has_zp
=
self
.
quant_config
.
has_zp
has_zp
=
self
.
quant_config
.
has_zp
return
fused_experts
(
x
,
return
fused_experts
(
layer
.
w13_qweight
,
x
,
layer
.
w2_qweight
,
layer
.
w13_qweight
,
topk_weights
=
topk_weights
,
layer
.
w2_qweight
,
topk_ids
=
topk_ids
,
topk_weights
=
topk_weights
,
inplace
=
True
,
topk_ids
=
topk_ids
,
use_int4_w4a16
=
weight_bits
==
4
,
inplace
=
True
,
use_int8_w8a16
=
weight_bits
==
8
,
use_int4_w4a16
=
weight_bits
==
4
,
global_num_experts
=
global_num_experts
,
use_int8_w8a16
=
weight_bits
==
8
,
expert_map
=
expert_map
,
global_num_experts
=
global_num_experts
,
w1_scale
=
layer
.
w13_scales
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
w2_scale
=
layer
.
w2_scales
,
expert_map
=
expert_map
,
w1_zp
=
layer
.
w13_qzeros
if
has_zp
else
None
,
w1_scale
=
layer
.
w13_scales
,
w2_zp
=
layer
.
w2_qzeros
if
has_zp
else
None
,
w2_scale
=
layer
.
w2_scales
,
block_shape
=
[
0
,
layer
.
group_size
],
w1_zp
=
layer
.
w13_qzeros
if
has_zp
else
None
,
use_nn_moe
=
False
,
w2_zp
=
layer
.
w2_qzeros
if
has_zp
else
None
,
)
block_shape
=
[
0
,
layer
.
group_size
],
use_nn_moe
=
False
)
@
staticmethod
@
staticmethod
def
get_weight_loader
(
layer
,
weight_loader
):
def
get_weight_loader
(
layer
,
weight_loader
):
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
fcfc474d
...
@@ -202,6 +202,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
...
@@ -202,6 +202,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
@@ -217,16 +219,18 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
...
@@ -217,16 +219,18 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
x
,
return
fused_experts
(
layer
.
w13_weight
,
x
,
layer
.
w2_weight
,
layer
.
w13_weight
,
topk_weights
=
topk_weights
,
layer
.
w2_weight
,
topk_ids
=
topk_ids
,
topk_weights
=
topk_weights
,
inplace
=
True
,
topk_ids
=
topk_ids
,
use_fp8_w8a8
=
True
,
inplace
=
True
,
global_num_experts
=
global_num_experts
,
use_fp8_w8a8
=
True
,
expert_map
=
expert_map
,
global_num_experts
=
global_num_experts
,
w1_scale
=
layer
.
w13_weight_scale
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
w2_scale
=
layer
.
w2_weight_scale
,
expert_map
=
expert_map
,
a1_scale
=
layer
.
w13_input_scale
,
w1_scale
=
layer
.
w13_weight_scale
,
a2_scale
=
layer
.
w2_input_scale
)
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
)
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
View file @
fcfc474d
...
@@ -22,6 +22,7 @@ class QuarkW8A8Fp8(QuarkScheme):
...
@@ -22,6 +22,7 @@ class QuarkW8A8Fp8(QuarkScheme):
self
.
qscheme
=
qscheme
self
.
qscheme
=
qscheme
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
fp8_linear
=
Fp8LinearOp
(
use_per_token_if_dynamic
=
True
)
self
.
fp8_linear
=
Fp8LinearOp
(
use_per_token_if_dynamic
=
True
)
self
.
out_dtype
=
torch
.
get_default_dtype
()
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
...
@@ -134,5 +135,6 @@ class QuarkW8A8Fp8(QuarkScheme):
...
@@ -134,5 +135,6 @@ class QuarkW8A8Fp8(QuarkScheme):
return
self
.
fp8_linear
.
apply
(
input
=
x
,
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
layer
.
weight_scale
,
out_dtype
=
self
.
out_dtype
,
input_scale
=
layer
.
input_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
)
bias
=
bias
)
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
fcfc474d
...
@@ -51,6 +51,16 @@ def cutlass_block_fp8_supported() -> bool:
...
@@ -51,6 +51,16 @@ def cutlass_block_fp8_supported() -> bool:
return
ops
.
cutlass_scaled_mm_supports_block_fp8
(
capability
)
return
ops
.
cutlass_scaled_mm_supports_block_fp8
(
capability
)
def
cutlass_group_gemm_supported
()
->
bool
:
if
not
current_platform
.
is_cuda
():
return
False
capability_tuple
=
current_platform
.
get_device_capability
()
capability
=
-
1
if
capability_tuple
is
None
else
capability_tuple
.
to_int
()
return
ops
.
cutlass_group_gemm_supported
(
capability
)
CUTLASS_FP8_SUPPORTED
=
cutlass_fp8_supported
()
CUTLASS_FP8_SUPPORTED
=
cutlass_fp8_supported
()
CUTLASS_BLOCK_FP8_SUPPORTED
=
cutlass_block_fp8_supported
()
CUTLASS_BLOCK_FP8_SUPPORTED
=
cutlass_block_fp8_supported
()
...
@@ -154,6 +164,7 @@ class Fp8LinearOp:
...
@@ -154,6 +164,7 @@ class Fp8LinearOp:
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
out_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
input_scale_ub
:
Optional
[
torch
.
Tensor
]
=
None
,
input_scale_ub
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -173,8 +184,13 @@ class Fp8LinearOp:
...
@@ -173,8 +184,13 @@ class Fp8LinearOp:
if
use_per_token_if_dynamic
is
None
:
if
use_per_token_if_dynamic
is
None
:
use_per_token_if_dynamic
=
self
.
use_per_token_if_dynamic
use_per_token_if_dynamic
=
self
.
use_per_token_if_dynamic
if
out_dtype
is
None
:
out_dtype
=
input
.
dtype
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if
self
.
cutlass_fp8_supported
:
if
self
.
cutlass_fp8_supported
:
assert
input
.
dtype
!=
current_platform
.
fp8_dtype
(
),
"FP8 input to cutlass is not currently implemented"
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input_2d
,
input_2d
,
input_scale
,
input_scale
,
...
@@ -184,7 +200,7 @@ class Fp8LinearOp:
...
@@ -184,7 +200,7 @@ class Fp8LinearOp:
# Fused GEMM_DQ
# Fused GEMM_DQ
output
=
ops
.
cutlass_scaled_mm
(
qinput
,
output
=
ops
.
cutlass_scaled_mm
(
qinput
,
weight
,
weight
,
out_dtype
=
inp
ut
.
dtype
,
out_dtype
=
o
ut
_
dtype
,
scale_a
=
x_scale
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
bias
=
bias
)
...
@@ -193,12 +209,15 @@ class Fp8LinearOp:
...
@@ -193,12 +209,15 @@ class Fp8LinearOp:
# torch.scaled_mm supports per tensor weights + activations only
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
# so fallback to naive if per channel or per token
else
:
else
:
# Maybe apply padding to output, see comment in __init__
if
input
.
dtype
!=
current_platform
.
fp8_dtype
():
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
# Maybe apply padding to output, see comment in __init__
input_2d
,
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input_scale
,
input_2d
,
num_token_padding
=
self
.
output_padding
,
input_scale
,
use_per_token_if_dynamic
=
use_per_token_if_dynamic
)
num_token_padding
=
self
.
output_padding
,
use_per_token_if_dynamic
=
use_per_token_if_dynamic
)
else
:
qinput
,
x_scale
=
input_2d
,
input_scale
per_tensor_weights
=
(
weight_scale
.
numel
()
==
1
)
per_tensor_weights
=
(
weight_scale
.
numel
()
==
1
)
per_tensor_activations
=
(
x_scale
.
numel
()
==
1
)
per_tensor_activations
=
(
x_scale
.
numel
()
==
1
)
...
@@ -207,7 +226,7 @@ class Fp8LinearOp:
...
@@ -207,7 +226,7 @@ class Fp8LinearOp:
# Fused GEMM_DQ
# Fused GEMM_DQ
output
=
torch
.
_scaled_mm
(
qinput
,
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
weight
,
out_dtype
=
inp
ut
.
dtype
,
out_dtype
=
o
ut
_
dtype
,
scale_a
=
x_scale
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
bias
=
bias
)
...
@@ -231,7 +250,7 @@ class Fp8LinearOp:
...
@@ -231,7 +250,7 @@ class Fp8LinearOp:
# Fused GEMM_DQ Rowwise GEMM
# Fused GEMM_DQ Rowwise GEMM
output
=
torch
.
_scaled_mm
(
qinput
,
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
weight
,
out_dtype
=
inp
ut
.
dtype
,
out_dtype
=
o
ut
_
dtype
,
scale_a
=
x_scale
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
.
t
(),
scale_b
=
weight_scale
.
t
(),
bias
=
bias
)
bias
=
bias
)
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
fcfc474d
...
@@ -851,6 +851,70 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
...
@@ -851,6 +851,70 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
return
new_freqs
return
new_freqs
class
Llama4VisionRotaryEmbedding
(
RotaryEmbedding
):
def
__init__
(
self
,
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
):
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
,
dtype
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
inv_freqs
=
super
().
_compute_inv_freq
(
base
)
inv_freqs
=
inv_freqs
[:(
self
.
rotary_dim
//
2
)]
return
inv_freqs
def
_compute_cos_sin_cache
(
self
)
->
torch
.
Tensor
:
inv_freq
=
self
.
_compute_inv_freq
(
self
.
base
)
# self.max_position_embeddings here is number of image patches
# i.e. (image_size // patch_size) ** 2
num_patches
=
self
.
max_position_embeddings
img_idx
=
torch
.
arange
(
num_patches
,
dtype
=
torch
.
int32
)
\
.
reshape
(
num_patches
,
1
)
img_idx
=
torch
.
cat
([
img_idx
,
img_idx
[:
1
]],
dim
=
0
)
img_idx
[
-
1
,
-
1
]
=
-
2
# set to ID_CLS_TOKEN
num_patches_single_dim
=
int
(
math
.
sqrt
(
num_patches
))
frequencies_x
=
img_idx
%
num_patches_single_dim
frequencies_y
=
img_idx
//
num_patches_single_dim
freqs_x
=
((
frequencies_x
+
1
)[...,
None
]
*
inv_freq
[
None
,
None
,
:]).
repeat_interleave
(
2
,
dim
=-
1
)
freqs_y
=
((
frequencies_y
+
1
)[...,
None
]
*
inv_freq
[
None
,
None
,
:]).
repeat_interleave
(
2
,
dim
=-
1
)
freqs
=
torch
.
cat
([
freqs_x
,
freqs_y
],
dim
=-
1
).
float
().
contiguous
()[...,
::
2
]
freqs
=
freqs
.
masked_fill
(
img_idx
.
reshape
(
-
1
,
1
,
1
)
<
0
,
0
)
cache
=
torch
.
view_as_complex
(
torch
.
stack
([
torch
.
cos
(
freqs
),
torch
.
sin
(
freqs
)],
dim
=-
1
))
return
cache
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
self
.
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
query
.
device
)
query_
=
torch
.
view_as_complex
(
query
.
float
().
reshape
(
*
query
.
shape
[:
-
1
],
-
1
,
2
))
key_
=
torch
.
view_as_complex
(
key
.
float
().
reshape
(
*
key
.
shape
[:
-
1
],
-
1
,
2
))
broadcast_shape
=
[
d
if
i
==
1
or
i
==
(
query_
.
ndim
-
1
)
else
1
for
i
,
d
in
enumerate
(
query_
.
shape
)
]
freqs_ci
=
self
.
cos_sin_cache
.
view
(
*
broadcast_shape
)
query_out
=
torch
.
view_as_real
(
query_
*
freqs_ci
).
flatten
(
3
)
key_out
=
torch
.
view_as_real
(
key_
*
freqs_ci
).
flatten
(
3
)
return
query_out
.
type_as
(
query
),
key_out
.
type_as
(
key
)
class
MRotaryEmbedding
(
RotaryEmbedding
):
class
MRotaryEmbedding
(
RotaryEmbedding
):
"""Rotary Embedding with Multimodal Sections."""
"""Rotary Embedding with Multimodal Sections."""
...
@@ -1130,6 +1194,10 @@ def get_rope(
...
@@ -1130,6 +1194,10 @@ def get_rope(
scaling_factor
,
low_freq_factor
,
scaling_factor
,
low_freq_factor
,
high_freq_factor
,
high_freq_factor
,
original_max_position
)
original_max_position
)
elif
scaling_type
==
"mllama4"
:
rotary_emb
=
Llama4VisionRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
)
elif
scaling_type
==
"default"
:
elif
scaling_type
==
"default"
:
if
"mrope_section"
in
rope_scaling
:
if
"mrope_section"
in
rope_scaling
:
rotary_emb
=
MRotaryEmbedding
(
rotary_emb
=
MRotaryEmbedding
(
...
...
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
fcfc474d
...
@@ -250,7 +250,7 @@ class VocabParallelEmbedding(torch.nn.Module):
...
@@ -250,7 +250,7 @@ class VocabParallelEmbedding(torch.nn.Module):
# If we are making an embedding layer, then our quantization linear
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
# layer type like ParallelLMHead, this is not important.
is_embedding_layer
=
type
(
self
.
__class__
)
is
VocabParallelEmbedding
is_embedding_layer
=
type
(
self
)
is
VocabParallelEmbedding
quant_method_implements_embedding
=
method_has_implemented_embedding
(
quant_method_implements_embedding
=
method_has_implemented_embedding
(
type
(
quant_method
))
type
(
quant_method
))
if
is_embedding_layer
and
not
quant_method_implements_embedding
:
if
is_embedding_layer
and
not
quant_method_implements_embedding
:
...
...
vllm/model_executor/model_loader/loader.py
View file @
fcfc474d
...
@@ -1261,6 +1261,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -1261,6 +1261,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
pack_ratio
)
pack_ratio
)
offsets
=
np
.
concatenate
(([
0
],
np
.
cumsum
(
num_elements
)))
offsets
=
np
.
concatenate
(([
0
],
np
.
cumsum
(
num_elements
)))
# Make torch infer_schema happy
offsets
=
torch
.
tensor
(
offsets
).
cpu
()
set_weight_attrs
(
param
,
{
"bnb_shard_offsets"
:
offsets
})
set_weight_attrs
(
param
,
{
"bnb_shard_offsets"
:
offsets
})
if
load_8bit
:
if
load_8bit
:
...
...
vllm/model_executor/model_loader/utils.py
View file @
fcfc474d
...
@@ -37,16 +37,13 @@ def is_transformers_impl_compatible(
...
@@ -37,16 +37,13 @@ def is_transformers_impl_compatible(
mod
=
module
or
getattr
(
transformers
,
arch
,
None
)
mod
=
module
or
getattr
(
transformers
,
arch
,
None
)
if
mod
is
None
:
if
mod
is
None
:
return
False
return
False
if
hasattr
(
mod
,
"supports_backend"
):
return
mod
.
is_backend_compatible
()
return
mod
.
is_backend_compatible
()
else
:
return
mod
.
_supports_flex_attn
def
resolve_transformers_
fallback
(
model_config
:
ModelConfig
,
def
resolve_transformers_
arch
(
model_config
:
ModelConfig
,
architectures
:
list
[
str
]):
architectures
:
list
[
str
]):
for
i
,
arch
in
enumerate
(
architectures
):
for
i
,
arch
in
enumerate
(
architectures
):
if
arch
==
"Transformers
Model
"
:
if
arch
==
"Transformers
ForCausalLM
"
:
continue
continue
auto_map
:
dict
[
str
,
str
]
=
getattr
(
model_config
.
hf_config
,
"auto_map"
,
auto_map
:
dict
[
str
,
str
]
=
getattr
(
model_config
.
hf_config
,
"auto_map"
,
None
)
or
dict
()
None
)
or
dict
()
...
@@ -70,7 +67,7 @@ def resolve_transformers_fallback(model_config: ModelConfig,
...
@@ -70,7 +67,7 @@ def resolve_transformers_fallback(model_config: ModelConfig,
raise
ValueError
(
raise
ValueError
(
f
"The Transformers implementation of
{
arch
}
is not "
f
"The Transformers implementation of
{
arch
}
is not "
"compatible with vLLM."
)
"compatible with vLLM."
)
architectures
[
i
]
=
"Transformers
Model
"
architectures
[
i
]
=
"Transformers
ForCausalLM
"
if
model_config
.
model_impl
==
ModelImpl
.
AUTO
:
if
model_config
.
model_impl
==
ModelImpl
.
AUTO
:
if
not
is_transformers_impl_compatible
(
arch
,
custom_model_module
):
if
not
is_transformers_impl_compatible
(
arch
,
custom_model_module
):
raise
ValueError
(
raise
ValueError
(
...
@@ -81,7 +78,7 @@ def resolve_transformers_fallback(model_config: ModelConfig,
...
@@ -81,7 +78,7 @@ def resolve_transformers_fallback(model_config: ModelConfig,
"%s has no vLLM implementation, falling back to Transformers "
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"implementation. Some features may not be supported and "
"performance may not be optimal."
,
arch
)
"performance may not be optimal."
,
arch
)
architectures
[
i
]
=
"Transformers
Model
"
architectures
[
i
]
=
"Transformers
ForCausalLM
"
return
architectures
return
architectures
...
@@ -140,8 +137,7 @@ def get_model_architecture(
...
@@ -140,8 +137,7 @@ def get_model_architecture(
for
arch
in
architectures
)
for
arch
in
architectures
)
if
(
not
is_vllm_supported
if
(
not
is_vllm_supported
or
model_config
.
model_impl
==
ModelImpl
.
TRANSFORMERS
):
or
model_config
.
model_impl
==
ModelImpl
.
TRANSFORMERS
):
architectures
=
resolve_transformers_fallback
(
model_config
,
architectures
=
resolve_transformers_arch
(
model_config
,
architectures
)
architectures
)
model_cls
,
arch
=
ModelRegistry
.
resolve_model_cls
(
architectures
)
model_cls
,
arch
=
ModelRegistry
.
resolve_model_cls
(
architectures
)
if
model_config
.
task
==
"embed"
:
if
model_config
.
task
==
"embed"
:
...
...
vllm/model_executor/models/adapters.py
View file @
fcfc474d
...
@@ -99,16 +99,17 @@ def _create_pooling_model_cls(
...
@@ -99,16 +99,17 @@ def _create_pooling_model_cls(
mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
weights
=
mapper
.
apply
(
weights
)
weights
=
mapper
.
apply
(
weights
)
self
.
model
.
load_weights
(
weights
)
loaded_params
=
self
.
model
.
load_weights
(
weights
)
return
loaded_params
=
{
f
"model.
{
name
}
"
for
name
in
loaded_params
}
return
loaded_params
# For most other models
# For most other models
if
hasattr
(
orig_cls
,
"load_weights"
):
if
hasattr
(
orig_cls
,
"load_weights"
):
orig_cls
.
load_weights
(
self
,
weights
)
# type: ignore
return
orig_cls
.
load_weights
(
self
,
weights
)
# type: ignore
# Fallback
# Fallback
else
:
else
:
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
return
ModelForPooling
# type: ignore
return
ModelForPooling
# type: ignore
...
...
vllm/model_executor/models/aya_vision.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0 Adapted from
# https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
from
functools
import
cached_property
from
typing
import
(
Iterable
,
Literal
,
Mapping
,
Optional
,
Sequence
,
Set
,
Tuple
,
TypedDict
,
Union
,
cast
)
import
torch
from
torch
import
nn
from
transformers
import
BatchFeature
,
GotOcr2ImageProcessor
from
transformers.activations
import
ACT2FN
from
transformers.image_processing_utils
import
get_size_dict
from
transformers.models.aya_vision
import
AyaVisionConfig
from
transformers.models.aya_vision.processing_aya_vision
import
(
AyaVisionProcessor
)
from
transformers.models.got_ocr2.image_processing_got_ocr2
import
(
get_optimal_tiled_canvas
)
from
vllm.config
import
VllmConfig
from
vllm.jsontree
import
json_map_leaves
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
MultiModalFieldConfig
,
PromptReplacement
,
PromptUpdate
,
encode_tokens
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.siglip
import
SiglipVisionModel
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
class
AyaVisionImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
pixel_values
:
torch
.
Tensor
"""
Shape: `(num_patches_total, num_channels, height, width)`
`num_patches_total` is the total number of patches over each image over each
prompt in the batch.
"""
num_patches
:
torch
.
Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class AyaVisionMultiModalProjector(nn.Module):
    """Map vision features into the text model's hidden size.

    The forward pass is: pixel shuffle -> LayerNorm -> linear -> SwiGLU
    (SiLU-gated product of the two halves) -> linear.
    """

    def __init__(self, config: AyaVisionConfig):
        super().__init__()
        self.config = config
        self.downsample_factor = config.downsample_factor
        # Falls back to the text hidden size when the config does not define
        # `alignment_intermediate_size`.
        self.alignment_intermediate_size = getattr(
            config, "alignment_intermediate_size",
            config.text_config.hidden_size)

        self.layernorm = nn.LayerNorm(
            config.vision_config.hidden_size * (config.downsample_factor**2),
            eps=config.adapter_layer_norm_eps)
        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size * (config.downsample_factor**2),
            self.alignment_intermediate_size,
            bias=True,
        )

        # SwiGLU uses SiLU activation
        self.act = ACT2FN["silu"]
        # For SwiGLU, project down to half size since we split intermediate
        # dim
        self.linear_2 = nn.Linear(self.alignment_intermediate_size // 2,
                                  config.text_config.hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` and return the resulting hidden states."""
        shuffled = self.pixel_shuffle(image_features)
        normed = self.layernorm(shuffled)
        projected = self.linear_1(normed)

        # Split along last dimension and apply SwiGLU
        x, gate = projected.chunk(2, dim=-1)
        gated = self.act(gate) * x
        return self.linear_2(gated)

    def pixel_shuffle(self, image_features: torch.Tensor) -> torch.Tensor:
        """Fold spatial neighbours into the channel dimension.

        Takes a `(B, S, D)` tensor, treats `S` as a square grid of side
        `int(S ** 0.5)`, and returns a 4-D tensor whose two grid dimensions
        are divided by `downsample_factor`; the last dimension absorbs the
        remaining elements.
        """
        # B, S, D
        batch_size, seq_length, _ = image_features.shape
        side = int(seq_length**0.5)
        height, width = side, side
        factor = self.downsample_factor

        grid = image_features.reshape(image_features.shape[0], width, height,
                                      -1)
        channels = grid.shape[-1]

        grid = grid.reshape(batch_size, width, int(height / factor),
                            int(channels * factor))
        grid = grid.permute(0, 2, 1, 3)
        grid = grid.reshape(batch_size, int(height / factor),
                            int(width / factor), -1)
        return grid.permute(0, 2, 1, 3)
class AyaVisionProcessingInfo(BaseProcessingInfo):
    """Access to the HF config/processor plus image patch/token accounting."""

    def get_hf_config(self) -> AyaVisionConfig:
        """Return the HF config, requested from the context as
        `AyaVisionConfig`."""
        return self.ctx.get_hf_config(AyaVisionConfig)

    def get_hf_processor(self, **kwargs: object) -> AyaVisionProcessor:
        """Return the HF processor, forwarding `kwargs` to the context."""
        return self.ctx.get_hf_processor(AyaVisionProcessor, **kwargs)

    def get_image_processor(self) -> GotOcr2ImageProcessor:
        """Return the `image_processor` attribute of the HF processor."""
        processor = self.get_hf_processor()
        return processor.image_processor

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the per-image token upper bound, keyed by `"image"`.

        `seq_len` and `mm_counts` are not used by this implementation.
        """
        return {"image": self.get_max_image_tokens()}

    def get_max_image_tokens(self) -> int:
        """Count the tokens in the placeholder text of the largest image.

        The image size comes from `get_image_size_with_most_features`; the
        placeholder text from the HF processor's `_prompt_split_image`, and
        it is tokenized without special tokens.
        """
        processor = self.get_hf_processor()
        img_proc = processor.image_processor
        largest = self.get_image_size_with_most_features()
        tokenizer = processor.tokenizer

        patch_count = self.get_num_patches(
            image_width=largest.width,
            image_height=largest.height,
            size=img_proc.size,
            min_patches=img_proc.min_patches,
            max_patches=img_proc.max_patches)
        placeholder = processor._prompt_split_image(patch_count)

        token_ids = encode_tokens(
            tokenizer,
            placeholder,
            add_special_tokens=False,
        )
        return len(token_ids)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the supported modalities; `"image"` maps to `None`."""
        return {"image": None}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image processor's tile size scaled by `max_patches`
        in both dimensions."""
        img_proc = self.get_image_processor()
        tile_height = img_proc.size['height']
        tile_width = img_proc.size['width']
        scale = img_proc.max_patches
        return ImageSize(height=tile_height * scale, width=tile_width * scale)

    def get_num_patches(self, *, image_width: int, image_height: int,
                        size: dict, min_patches: int,
                        max_patches: int) -> int:
        """
        Calculate the number of patches needed for a given image based on size
        constraints.  This method replicates and adjusts the logic from:
        transformers/models/got_ocr2/image_processing_got_ocr2

        Returns the tile count of the chosen canvas, plus one whenever that
        count is not exactly one.
        """
        tile = get_size_dict(size, default_to_square=False)
        num_columns, num_rows = get_optimal_tiled_canvas(
            (image_height, image_width), (tile["height"], tile["width"]),
            min_patches, max_patches)
        num_blocks = num_columns * num_rows

        if num_blocks == 1:
            return num_blocks
        # NOTE(review): the extra entry presumably accounts for a global
        # thumbnail, as in the GOT-OCR2 image processor -- confirm.
        return num_blocks + 1
class AyaVisionDummyInputsBuilder(
        BaseDummyInputsBuilder[AyaVisionProcessingInfo]):
    """Build dummy processor inputs from the largest supported image size."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return a prompt of repeated image tokens plus matching dummy
        images.

        The number of images is `mm_counts["image"]` (0 when absent);
        `seq_len` is not used by this implementation.
        """
        image_token = self.info.get_hf_processor().image_token
        num_images = mm_counts.get("image", 0)
        target_size = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(width=target_size.width,
                                              height=target_size.height,
                                              num_images=num_images)
        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data={"image": dummy_images},
        )
class AyaVisionMultiModalProcessor(
        BaseMultiModalProcessor[AyaVisionProcessingInfo]):
    """Multi-modal processor for Aya Vision prompts and images."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the parent HF processing, then add `num_patches` and
        `embed_is_patch` when images and an `<image>` marker are present."""
        outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            mm_kwargs,
        )
        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        image_processor = hf_processor.image_processor
        hf_config = self.info.get_hf_config()

        # HF processor pops the `num_patches` kwarg, which is needed by vLLM
        images = mm_data.get("images")
        if images is None or '<image>' not in prompt:
            return outputs

        assert isinstance(images, list)
        parsed_images = (self._get_data_parser().parse_mm_data({
            "image": images
        }).get_items("image", ImageProcessorItems))

        def count_patches(image_size) -> int:
            # Patch count for one image under the image processor's limits.
            return self.info.get_num_patches(
                image_width=image_size.width,
                image_height=image_size.height,
                size=image_processor.size,
                min_patches=image_processor.min_patches,
                max_patches=image_processor.max_patches)

        image_sizes = [
            parsed_images.get_image_size(idx)
            for idx in range(len(parsed_images))
        ]
        num_patches = [count_patches(image_size) for image_size in image_sizes]

        # Placeholder text for every image, then its token ids.
        placeholder_texts = [
            hf_processor._prompt_split_image(count) for count in num_patches
        ]
        tokenizer = self.info.get_tokenizer()
        placeholder_ids = [
            tokenizer.encode(text, add_special_tokens=False)
            for text in placeholder_texts
        ]

        # True exactly where a placeholder token is the image token.
        outputs["embed_is_patch"] = [
            torch.tensor(ids) == hf_config.image_token_index
            for ids in placeholder_ids
        ]
        outputs["num_patches"] = torch.tensor(num_patches)
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processed field maps onto image items.

        `pixel_values` is flat and sliced by `num_patches`; the remaining
        fields are batched per image.
        """
        num_patches = hf_inputs.get("num_patches", torch.empty(0))
        return {
            "pixel_values":
            MultiModalFieldConfig.flat_from_sizes("image", num_patches),
            "num_patches":
            MultiModalFieldConfig.batched("image"),
            "embed_is_patch":
            MultiModalFieldConfig.batched("image"),
            "image_embeds":
            MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Return one replacement that expands each image token into the
        HF processor's per-image placeholder text."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = hf_processor.image_token
        image_processor = hf_processor.image_processor

        def get_replacement(item_idx: int):
            # Placeholder text sized for the item_idx-th image.
            image_items: ImageProcessorItems = mm_items.get(
                "image", ImageProcessorItems)
            item_size: ImageSize = image_items.get_image_size(item_idx)
            patch_count = self.info.get_num_patches(
                image_width=item_size.width,
                image_height=item_size.height,
                size=image_processor.size,
                min_patches=image_processor.min_patches,
                max_patches=image_processor.max_patches)
            return hf_processor._prompt_split_image(num_patches=patch_count)

        replacement = PromptReplacement(
            modality="image",
            target=image_token,
            replacement=get_replacement,
        )
        return [replacement]
def
_get_num_hidden_layers
(
hf_config
:
AyaVisionConfig
)
->
int
:
feature_layers
=
hf_config
.
vision_feature_layer
num_hidden_layers
=
hf_config
.
vision_config
.
num_hidden_layers
# If we have one feature layer, initialize up to that layer
if
isinstance
(
feature_layers
,
int
):
return
_get_layer_index
(
feature_layers
,
num_hidden_layers
)
# If we have multiple feature layers, initialize up to the deepest m
elif
isinstance
(
feature_layers
,
(
list
,
tuple
)):
return
max
(
_get_layer_index
(
idx
,
num_hidden_layers
)
for
idx
in
feature_layers
)
raise
TypeError
(
f
"vision_layer_feature type:
{
type
(
feature_layers
)
}
"
" is not supported"
)
def
_get_layer_index
(
feature_layer_index
:
int
,
num_hidden_layers
:
int
)
->
int
:
if
feature_layer_index
<
0
:
return
num_hidden_layers
+
feature_layer_index
+
1
return
feature_layer_index
@MULTIMODAL_REGISTRY.register_processor(
    AyaVisionMultiModalProcessor,
    info=AyaVisionProcessingInfo,
    dummy_inputs=AyaVisionDummyInputsBuilder)
class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """Aya Vision: a SigLIP vision tower, a projector and a language model."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower, the projector and the language model.

        Args:
            vllm_config: source of the HF config, the quantization config
                and the multimodal config.
            prefix: prefix passed through `maybe_prefix` to the submodules.
        """
        super().__init__()
        config: AyaVisionConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        num_hidden_layers = _get_num_hidden_layers(config)
        self.config = config
        self.quant_config = quant_config
        self.multimodal_config = multimodal_config

        self.vision_tower = SiglipVisionModel(
            config.vision_config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers,
            prefix=maybe_prefix(prefix, "vision_model"))
        self.vocab_size = config.text_config.vocab_size
        self.multi_modal_projector = AyaVisionMultiModalProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "model"),
            # Cohere2ForCausalLM and CohereForCausalLM are the same on vllm
            architectures=["Cohere2ForCausalLM"])

    @property
    def dtype(self):
        """dtype of the first parameter yielded by `self.parameters()`."""
        return next(self.parameters()).dtype

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load `weights` through `AutoWeightsLoader` and return its result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
                                  pixel_values: torch.Tensor,
                                  **kwargs) -> torch.Tensor:
        """Run the vision tower and apply the configured feature selection.

        `pixel_values` is cast to the dtype of the tower's input-embedding
        weight; `kwargs` are forwarded to the tower call.
        """
        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_features = vision_tower(pixel_values.to(dtype=target_dtype),
                                      **kwargs)

        def select_features(leaf: torch.Tensor):
            return self._select_image_features(
                leaf,
                strategy=self.config.vision_feature_select_strategy,
            )

        # The selection is mapped over every leaf of the tower's output.
        return cast(
            Union[torch.Tensor, tuple[torch.Tensor, ...]],
            json_map_leaves(select_features, image_features),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Select features according to `strategy`.

        "default" drops index 0 along dimension 1; "full" returns the input
        unchanged.

        Raises:
            ValueError: for any other strategy.
        """
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _process_image_input(self, image_input: AyaVisionImagePixelInputs,
                             **kwargs) -> list[torch.Tensor]:
        """Encode and project pixel values, returning one tensor per image.

        The projected embeddings are split along dimension 0 by
        `num_patches`, and each piece has its first three dimensions
        flattened. `kwargs` are accepted but not used.
        """
        assert self.vision_tower is not None
        pixel_values = image_input["pixel_values"]
        num_patches = image_input["num_patches"]
        image_features = self._image_pixels_to_features(
            self.vision_tower, pixel_values=pixel_values)
        image_embeds = self.multi_modal_projector(image_features)
        return [
            e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())
        ]

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check every entry of `data` has shape `(3, image_size, image_size)`.

        Returns `data` unchanged.

        Raises:
            ValueError: if any entry has a different shape.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            if d.shape != expected_dims:
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_dims}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]:
        """Extract and validate the image inputs from `kwargs`.

        Returns:
            `None` when no `pixel_values` were supplied (text-only input),
            otherwise the flattened and validated inputs.

        Raises:
            ValueError: if a supplied input has an unexpected type or the
                pixel values have an unexpected shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        embed_is_patch = kwargs.pop("embed_is_patch", None)
        image_embeds = kwargs.pop("image_embeds", None)
        assert image_embeds is None, "Aya Vision does not support image_embeds."

        # Text-only input: the caller treats `None` as "no image" and skips
        # the vision path, so absence must not be reported as a type error.
        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")
        if num_patches is not None and not isinstance(num_patches,
                                                      (torch.Tensor, list)):
            raise ValueError("Incorrect type of num_patches. "
                             f"Got type: {type(num_patches)}")
        if not isinstance(embed_is_patch, (torch.Tensor, list)):
            raise ValueError("Incorrect type of embed_is_patch. "
                             f"Got type: {type(embed_is_patch)}")

        pixel_values = flatten_bn(pixel_values, concat=True)
        num_patches = flatten_bn(num_patches, concat=True)
        embed_is_patch = flatten_bn(embed_is_patch)
        return AyaVisionImagePixelInputs(
            type="pixel_values",
            pixel_values=self._validate_pixel_values(pixel_values),
            num_patches=num_patches,
            embed_is_patch=embed_is_patch,
        )

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Return the image embeddings for `kwargs`, or `None` without images.

        The per-image features are passed through `scatter_patch_features`
        together with the `embed_is_patch` masks.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        image_features = self._process_image_input(image_input, **kwargs)

        return scatter_patch_features(
            image_features,
            image_input["embed_is_patch"],
        )

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed `input_ids`, merging image embeddings at the positions of
        `config.image_token_index` when `multimodal_embeddings` is given."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                multimodal_embeddings=select_patch_features(
                    multimodal_embeddings),
                placeholder_token_id=self.config.image_token_index)

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model's inner `model` and return its output.

        When `intermediate_tensors` is given, `inputs_embeds` is discarded.
        Otherwise, when `inputs_embeds` is missing, it is computed here from
        `input_ids` and the multimodal `kwargs`, and `input_ids` is cleared.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    @cached_property
    def sampler(self):
        """The language model's sampler if it has one, else `get_sampler()`."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)
vllm/model_executor/models/baichuan.py
View file @
fcfc474d
...
@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
,
SupportsQuant
from
.interfaces
import
SupportsLoRA
,
SupportsPP
,
SupportsQuant
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
)
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
...
@@ -301,6 +301,16 @@ class BaiChuanModel(nn.Module):
...
@@ -301,6 +301,16 @@ class BaiChuanModel(nn.Module):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
return
self
.
embed_tokens
(
input_ids
)
...
@@ -336,86 +346,6 @@ class BaiChuanModel(nn.Module):
...
@@ -336,86 +346,6 @@ class BaiChuanModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
class
BaiChuanBaseForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
,
SupportsQuant
):
packed_modules_mapping
=
{
"W_pack"
:
[
"W_pack"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
position_embedding
:
str
=
"ROPE"
,
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
BaiChuanModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
position_embedding
=
position_embedding
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
@@ -428,17 +358,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
...
@@ -428,17 +358,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
if
name
==
"lm_head.weight"
:
# Unlike Baichuan, Baichuan2 normalizes the head weights.
# Refer to:
# https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
# Distinguish between Baichuan and Baichuan2 by checking the
# vocab size. This is suggested by
# https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
is_baichuan2
=
self
.
config
.
vocab_size
==
125696
if
is_baichuan2
:
loaded_weight
=
torch
.
nn
.
functional
.
normalize
(
loaded_weight
)
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
...
@@ -464,7 +383,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
...
@@ -464,7 +383,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
loaded_params
.
add
(
name
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
lay_key_words
=
[
lay_key_words
=
[
"self_attn.W_pack.weight"
,
"self_attn.W_pack.weight"
,
...
@@ -540,11 +459,101 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
...
@@ -540,11 +459,101 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
zeros_and_scalse_pad
=
torch
.
zeros
(
dim_n
,
pad_group
,
dtype
=
torch
.
int32
).
cuda
()
zeros_and_scalse_pad
=
torch
.
zeros
(
dim_n
,
pad_group
,
dtype
=
torch
.
int32
).
cuda
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight
.
data
=
torch
.
cat
((
qweight
.
data
,
qweight_pad
),
dim
=
1
).
contiguous
()
qweight
.
data
=
torch
.
cat
((
qweight
.
data
,
qweight_pad
),
dim
=
1
).
contiguous
()
return
loaded_params
return
loaded_params
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
                              SupportsQuant):
    """Causal-LM wrapper around `BaiChuanModel` with an LM head."""

    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ):
        """Build the backbone, LM head, logits processor and sampler.

        `position_embedding` is forwarded unchanged to `BaiChuanModel`.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quantization = vllm_config.quant_config

        self.config = hf_config
        self.lora_config = vllm_config.lora_config
        self.quant_config = quantization

        self.model = BaiChuanModel(vllm_config=vllm_config,
                                   prefix=prefix,
                                   position_embedding=position_embedding)
        self.lm_head = ParallelLMHead(hf_config.vocab_size,
                                      hf_config.hidden_size,
                                      quant_config=quantization)
        # Route head-weight loading through the Baichuan2-aware loader.
        self.lm_head.weight.weight_loader = self.lm_head_weight_loader
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the backbone model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone model and return its output."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from `hidden_states` with the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from `logits`."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load `weights` through `AutoWeightsLoader` and return its result."""
        return AutoWeightsLoader(self).load_weights(weights)

    def lm_head_weight_loader(self, param: nn.Parameter,
                              loaded_weight: torch.Tensor):
        """Load the LM head weight, normalizing it first for Baichuan2."""
        # Unlike Baichuan, Baichuan2 normalizes the head weights.
        # Refer to:
        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
        # Distinguish between Baichuan and Baichuan2 by checking the
        # vocab size. This is suggested by
        # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
        if self.config.vocab_size == 125696:
            loaded_weight = torch.nn.functional.normalize(loaded_weight)

        default_weight_loader(param, loaded_weight)
class
BaichuanForCausalLM
(
BaiChuanBaseForCausalLM
):
class
BaichuanForCausalLM
(
BaiChuanBaseForCausalLM
):
"""Baichuan 13B and Baichuan2 7B/13B.
"""Baichuan 13B and Baichuan2 7B/13B.
NOTE: the class name has a lower case 'c'.
NOTE: the class name has a lower case 'c'.
...
...
vllm/model_executor/models/bamba.py
View file @
fcfc474d
...
@@ -34,7 +34,7 @@ from vllm.utils import LayerBlockType
...
@@ -34,7 +34,7 @@ from vllm.utils import LayerBlockType
from
.interfaces
import
(
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
,
from
.interfaces
import
(
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
,
SupportsQuant
,
SupportsV0Only
)
SupportsQuant
,
SupportsV0Only
)
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -363,6 +363,58 @@ class BambaModel(nn.Module):
...
@@ -363,6 +363,58 @@ class BambaModel(nn.Module):
hidden_states
,
_
=
self
.
final_layernorm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
final_layernorm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Load checkpoint weights into this module's parameters.

    Checkpoint names are rewritten before lookup ("A_log" -> "A",
    ".self_attn" removed), and q/k/v and gate/up projections are loaded as
    shards of the fused `qkv_proj` / `gate_up_proj` parameters.

    Returns:
        The (rewritten) names that reached the end of the loop body.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    params = dict(self.named_parameters())
    loaded: Set[str] = set()
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue

        # str.replace leaves the name untouched when "A_log" is absent.
        name = name.replace("A_log", "A")

        if ".self_attn." in name:
            name = name.replace(".self_attn", "")

        for fused_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue

            # The rewritten name is kept even if this entry is skipped below.
            name = name.replace(shard_name, fused_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params:
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue
            target = params[name]
            target.weight_loader(target, loaded_weight, shard_id)
            break
        else:
            # Reached only when no stacked entry loaded the weight.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            target = params[name]
            load_fn = getattr(target, "weight_loader", default_weight_loader)
            load_fn(target, loaded_weight)
        loaded.add(name)
    return loaded
class
BambaForCausalLM
(
nn
.
Module
,
HasInnerState
,
SupportsLoRA
,
SupportsPP
,
class
BambaForCausalLM
(
nn
.
Module
,
HasInnerState
,
SupportsLoRA
,
SupportsPP
,
IsHybrid
,
SupportsV0Only
,
SupportsQuant
):
IsHybrid
,
SupportsV0Only
,
SupportsQuant
):
...
@@ -502,52 +554,5 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -502,52 +554,5 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
loader
=
AutoWeightsLoader
(
self
)
# (param_name, shard_name, shard_id)
return
loader
.
load_weights
(
weights
)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
\ No newline at end of file
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
"A_log"
in
name
:
name
=
name
.
replace
(
"A_log"
,
"A"
)
if
".self_attn."
in
name
:
name
=
name
.
replace
(
".self_attn"
,
""
)
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/bert.py
View file @
fcfc474d
...
@@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
...
@@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from
vllm.transformers_utils.config
import
(
from
vllm.transformers_utils.config
import
(
get_cross_encoder_activation_function
)
get_cross_encoder_activation_function
)
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
,
SupportsV0Only
from
.utils
import
WeightsMapper
,
maybe_prefix
from
.utils
import
WeightsMapper
,
maybe_prefix
...
@@ -313,7 +313,8 @@ class BertOutput(nn.Module):
...
@@ -313,7 +313,8 @@ class BertOutput(nn.Module):
return
hidden_states
return
hidden_states
class
BertModel
(
nn
.
Module
):
class
BertModel
(
nn
.
Module
,
SupportsQuant
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"query"
,
"key"
,
"value"
]}
def
__init__
(
self
,
def
__init__
(
self
,
*
,
*
,
...
@@ -385,7 +386,7 @@ class BertModel(nn.Module):
...
@@ -385,7 +386,7 @@ class BertModel(nn.Module):
return
loaded_params
return
loaded_params
class
BertEmbeddingModel
(
nn
.
Module
,
SupportsV0Only
):
class
BertEmbeddingModel
(
nn
.
Module
,
SupportsV0Only
,
SupportsQuant
):
"""A model that uses Bert to provide embedding functionalities.
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
This class encapsulates the BertModel and provides an interface for
...
@@ -443,7 +444,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only):
...
@@ -443,7 +444,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only):
softmax
=
False
)
softmax
=
False
)
class
BertForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
):
class
BertForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
,
SupportsQuant
):
"""A model that uses Bert to provide embedding functionalities.
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
This class encapsulates the BertModel and provides an interface for
...
...
vllm/model_executor/models/blip.py
View file @
fcfc474d
...
@@ -16,6 +16,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -16,6 +16,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
.interfaces
import
SupportsQuant
def
get_blip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_blip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
assert
image_size
%
patch_size
==
0
assert
image_size
%
patch_size
==
0
...
@@ -243,9 +245,10 @@ class BlipEncoder(nn.Module):
...
@@ -243,9 +245,10 @@ class BlipEncoder(nn.Module):
return
hidden_states
return
hidden_states
class
BlipVisionModel
(
nn
.
Module
):
class
BlipVisionModel
(
nn
.
Module
,
SupportsQuant
):
config_class
=
BlipVisionConfig
config_class
=
BlipVisionConfig
main_input_name
=
"pixel_values"
main_input_name
=
"pixel_values"
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
]}
def
__init__
(
def
__init__
(
self
,
self
,
...
...
vllm/model_executor/models/blip2.py
View file @
fcfc474d
...
@@ -24,7 +24,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
...
@@ -24,7 +24,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.blip
import
BlipVisionModel
from
.blip
import
BlipVisionModel
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
,
SupportsQuant
)
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
...
@@ -498,7 +499,8 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
...
@@ -498,7 +499,8 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
@
MULTIMODAL_REGISTRY
.
register_processor
(
Blip2MultiModalProcessor
,
@
MULTIMODAL_REGISTRY
.
register_processor
(
Blip2MultiModalProcessor
,
info
=
Blip2ProcessingInfo
,
info
=
Blip2ProcessingInfo
,
dummy_inputs
=
Blip2DummyInputsBuilder
)
dummy_inputs
=
Blip2DummyInputsBuilder
)
class
Blip2ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
class
Blip2ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsQuant
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
vllm/model_executor/models/bloom.py
View file @
fcfc474d
...
@@ -46,7 +46,7 @@ from vllm.sequence import IntermediateTensors
...
@@ -46,7 +46,7 @@ from vllm.sequence import IntermediateTensors
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
.interfaces
import
SupportsPP
,
SupportsV0Only
from
.interfaces
import
SupportsPP
,
SupportsQuant
,
SupportsV0Only
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -290,7 +290,7 @@ class BloomModel(nn.Module):
...
@@ -290,7 +290,7 @@ class BloomModel(nn.Module):
return
hidden_states
return
hidden_states
class
BloomForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsV0Only
):
class
BloomForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsV0Only
,
SupportsQuant
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
...
...
Prev
1
…
15
16
17
18
19
20
21
22
23
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment