Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
081057de
Commit
081057de
authored
Apr 29, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.5' into v0.8.5-ori
parents
7cf5d5c4
ba41cc90
Changes
554
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1136 additions
and
327 deletions
+1136
-327
vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
...tor/layers/quantization/kernels/mixed_precision/marlin.py
+24
-27
vllm/model_executor/layers/quantization/kv_cache.py
vllm/model_executor/layers/quantization/kv_cache.py
+36
-0
vllm/model_executor/layers/quantization/quark/quark.py
vllm/model_executor/layers/quantization/quark/quark.py
+15
-20
vllm/model_executor/layers/quantization/utils/bitblas_utils.py
...model_executor/layers/quantization/utils/bitblas_utils.py
+198
-0
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+15
-1
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+174
-92
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+339
-19
vllm/model_executor/layers/utils.py
vllm/model_executor/layers/utils.py
+46
-5
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+2
-1
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+52
-28
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+20
-17
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+0
-10
vllm/model_executor/models/aria.py
vllm/model_executor/models/aria.py
+1
-11
vllm/model_executor/models/aya_vision.py
vllm/model_executor/models/aya_vision.py
+0
-16
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+0
-10
vllm/model_executor/models/bamba.py
vllm/model_executor/models/bamba.py
+0
-10
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+0
-10
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+189
-12
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+25
-28
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+0
-10
No files found.
Too many changes to show.
To preserve performance only
554 of 554+
files are displayed.
Plain diff
Email patch
vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
View file @
081057de
...
...
@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES
,
apply_gptq_marlin_linear
,
check_marlin_supports_shape
,
marlin_is_k_full
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_sort_g_idx
,
query_marlin_supported_quant_types
)
marlin_zero_points
,
query_marlin_supported_quant_types
,
unpack_cols
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
permute_param_layout_
)
...
...
@@ -25,10 +25,6 @@ class MarlinLinearKernel(MPLinearKernel):
@
classmethod
def
can_implement
(
cls
,
c
:
MPLinearLayerConfig
)
->
Tuple
[
bool
,
Optional
[
str
]]:
if
c
.
zero_points
:
return
False
,
"Zero points currently not supported by "
\
" MarlinLinearKernel. Will be added when AWQMarlin "
\
"is migrated over to using MPLinearKernel backend"
quant_types
=
query_marlin_supported_quant_types
(
c
.
zero_points
)
if
c
.
weight_type
not
in
quant_types
:
...
...
@@ -67,28 +63,6 @@ class MarlinLinearKernel(MPLinearKernel):
if
self
.
w_zp_name
is
None
:
self
.
w_zp_name
=
"w_zp"
if
c
.
has_g_idx
:
g_idx
,
g_idx_sort_indices
=
marlin_sort_g_idx
(
getattr
(
layer
,
self
.
w_gidx_name
))
self
.
_transform_param
(
layer
,
self
.
w_gidx_name
,
lambda
_
:
g_idx
)
layer
.
g_idx_sort_indices
=
g_idx_sort_indices
else
:
setattr
(
layer
,
self
.
w_gidx_name
,
marlin_make_empty_g_idx
(
device
))
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
if
c
.
zero_points
:
pass
# TODO (lucas): add the following when AWQMarlin is migrated over to
# using MPLinearKernel backend
# self._transform_param(layer, self.w_zp_name, lambda x: \
# marlin_zero_points(
# x,
# size_k=c.partition_weight_shape[0],
# size_n=c.partition_weight_shape[1],
# num_bits=c.weight_type.size_bits))
else
:
setattr
(
layer
,
self
.
w_zp_name
,
marlin_make_empty_g_idx
(
device
))
def
transform_w_q
(
x
):
assert
isinstance
(
x
,
BasevLLMParameter
)
permute_param_layout_
(
x
,
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
0
)
...
...
@@ -108,6 +82,28 @@ class MarlinLinearKernel(MPLinearKernel):
group_size
=
c
.
group_size
)
return
x
if
c
.
has_g_idx
:
g_idx
,
g_idx_sort_indices
=
marlin_sort_g_idx
(
getattr
(
layer
,
self
.
w_gidx_name
))
self
.
_transform_param
(
layer
,
self
.
w_gidx_name
,
lambda
_
:
g_idx
)
layer
.
g_idx_sort_indices
=
g_idx_sort_indices
else
:
setattr
(
layer
,
self
.
w_gidx_name
,
marlin_make_empty_g_idx
(
device
))
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
if
c
.
zero_points
:
grouped_k
=
(
c
.
partition_weight_shape
[
0
]
//
c
.
group_size
if
c
.
group_size
!=
-
1
else
1
)
self
.
_transform_param
(
layer
,
self
.
w_zp_name
,
lambda
x
:
\
marlin_zero_points
(
unpack_cols
(
x
.
t
(),
c
.
weight_type
.
size_bits
,
grouped_k
,
c
.
partition_weight_shape
[
1
]),
size_k
=
grouped_k
,
size_n
=
c
.
partition_weight_shape
[
1
],
num_bits
=
c
.
weight_type
.
size_bits
))
else
:
setattr
(
layer
,
self
.
w_zp_name
,
marlin_make_empty_g_idx
(
device
))
self
.
_transform_param
(
layer
,
self
.
w_q_name
,
transform_w_q
)
self
.
_transform_param
(
layer
,
self
.
w_s_name
,
transform_w_s
)
...
...
@@ -131,5 +127,6 @@ class MarlinLinearKernel(MPLinearKernel):
wtype
=
c
.
weight_type
,
input_size_per_partition
=
c
.
partition_weight_shape
[
0
],
output_size_per_partition
=
c
.
partition_weight_shape
[
1
],
has_zp
=
self
.
config
.
zero_points
,
is_k_full
=
self
.
is_k_full
,
bias
=
bias
)
vllm/model_executor/layers/quantization/kv_cache.py
View file @
081057de
...
...
@@ -38,6 +38,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
requires_grad
=
False
)
layer
.
v_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
-
1.0
),
requires_grad
=
False
)
# Initialize P = softmax(QK^T) scales
layer
.
prob_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
-
1.0
),
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
)
->
torch
.
Tensor
:
raise
RuntimeError
(
...
...
@@ -97,5 +100,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
"may cause accuracy issues. Please make sure k/v_scale "
"scaling factors are available in the fp8 checkpoint."
)
if
layer
.
q_scale
>
0.0
:
q_scale
=
layer
.
q_scale
if
current_platform
.
is_fp8_fnuz
():
q_scale
*=
2
layer
.
calculate_kv_scales
=
False
else
:
q_scale
=
1.0
if
layer
.
prob_scale
>
0.0
:
prob_scale
=
layer
.
prob_scale
if
current_platform
.
is_fp8_fnuz
():
prob_scale
*=
2
else
:
prob_scale
=
1.0
is_singleton_float
=
lambda
x
:
isinstance
(
x
,
float
)
or
isinstance
(
x
,
torch
.
Tensor
)
and
x
.
numel
()
==
1
and
x
.
is_floating_point
()
if
not
is_singleton_float
(
q_scale
)
or
not
is_singleton_float
(
prob_scale
):
raise
ValueError
(
"Only support per-tensor scaling factor"
"for fp8-quantized Q/prob"
)
# These are used in the final Attention.forward()
layer
.
_q_scale
.
copy_
(
q_scale
)
layer
.
_prob_scale
.
copy_
(
prob_scale
)
if
q_scale
==
1.0
or
prob_scale
==
1.0
:
logger
.
warning_once
(
f
"Using Q scale
{
q_scale
}
and prob scale
{
prob_scale
}
"
"with fp8 attention. This may cause accuracy issues. "
"Please make sure Q/prob scaling factors are "
"available in the fp8 checkpoint."
)
del
layer
.
k_scale
del
layer
.
v_scale
del
layer
.
q_scale
del
layer
.
prob_scale
vllm/model_executor/layers/quantization/quark/quark.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
import
fnmatch
import
re
from
typing
import
Any
,
Dict
,
List
,
Optional
,
cast
import
torch
...
...
@@ -125,6 +124,13 @@ class QuarkConfig(QuantizationConfig):
for
q_config
in
q_configs
:
q_config
[
"output_tensors"
]
=
None
# In case q_proj output is also quantized, remove the configuration
# to keep qkv consistency.
q_proj_q_config
=
cast
(
Dict
[
str
,
Any
],
layer_quant_config
.
get
(
"*q_proj"
))
if
q_proj_q_config
is
not
None
:
q_proj_q_config
[
"output_tensors"
]
=
None
return
cls
(
quant_config
=
config
,
kv_cache_group
=
kv_cache_group
,
kv_cache_config
=
kv_cache_config
,
...
...
@@ -289,25 +295,14 @@ class QuarkConfig(QuantizationConfig):
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if
self
.
kv_cache_group
is
None
or
len
(
self
.
kv_cache_group
)
==
0
:
return
None
kv_proj_names
=
[
re
.
split
(
r
"[*.]"
,
kv_cache
)[
-
1
]
for
kv_cache
in
self
.
kv_cache_group
]
if
name
.
endswith
(
".output_scale"
):
if
len
(
kv_proj_names
)
==
1
and
kv_proj_names
[
0
]
in
name
:
kv_output_scale_name
=
"."
+
kv_proj_names
[
0
]
+
".output_scale"
return
name
.
replace
(
kv_output_scale_name
,
".attn.k_scale"
)
elif
len
(
kv_proj_names
)
==
2
:
for
kv_proj_name
in
kv_proj_names
:
if
kv_proj_name
in
name
and
kv_proj_name
==
"k_proj"
:
return
name
.
replace
(
".k_proj.output_scale"
,
".attn.k_scale"
)
elif
kv_proj_name
in
name
and
kv_proj_name
==
"v_proj"
:
return
name
.
replace
(
".v_proj.output_scale"
,
".attn.v_scale"
)
if
name
.
endswith
(
".output_scale"
)
and
".k_proj"
in
name
:
return
name
.
replace
(
".k_proj.output_scale"
,
".attn.k_scale"
)
if
name
.
endswith
(
".output_scale"
)
and
".v_proj"
in
name
:
return
name
.
replace
(
".v_proj.output_scale"
,
".attn.v_scale"
)
if
name
.
endswith
(
".output_scale"
)
and
".q_proj"
in
name
:
return
name
.
replace
(
".q_proj.output_scale"
,
".attn.q_scale"
)
if
name
.
endswith
(
"self_attn.prob_output_scale"
):
return
name
.
replace
(
".prob_output_scale"
,
".attn.prob_scale"
)
# If no matches, return None
return
None
...
...
vllm/model_executor/layers/quantization/utils/bitblas_utils.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
,
Tuple
import
torch
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
MINIMUM_BITBLAS_VERSION
=
"0.1.0"
BITBLAS_MIN_WEIGHT_SIZE_N
=
16
BITBLAS_MIN_WEIGHT_SIZE_K
=
16
GPTQ_BITBLAS_MAX_PARALLEL
=
16
BITBLAS_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
# For dynamic shape code generation
BITBLAS_OPTIMIZE_FEATURES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
]
# If want to enable high performance for contiguous batching
# Please use the following values
BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS
=
[
16
,
32
,
64
,
128
,
256
,
512
,
1024
]
BITBLAS_SUPPORTED_NUM_BITS
=
[
1
,
2
,
4
,
8
]
BITBLAS_SUPPORTED_SYM
=
[
False
,
True
]
# Determines the supported quantization types for BitBLAS based on the
# device's capability and whether zero-point (zp) is used.
def query_bitblas_supported_quant_types(has_zp: bool,
                                        device_capability: Optional[int] = None
                                        ):
    """List the weight scalar types BitBLAS can handle on a device.

    Args:
        has_zp: Whether the quantization scheme carries runtime zero-points.
        device_capability: Integer device capability. When ``None`` it is
            read from ``current_platform``; a platform that reports no
            capability is treated as ``-1``.

    Returns:
        An empty list when the capability is below 70, otherwise the
        unsigned types (zero-point case) or the biased unsigned types
        (no zero-point case).
    """
    if device_capability is None:
        detected = current_platform.get_device_capability()
        if detected is None:
            device_capability = -1
        else:
            device_capability = detected.to_int()

    # NOTE(review): 70 presumably encodes compute capability 7.0 -- confirm
    # against the platform's ``to_int()`` convention.
    if device_capability < 70:
        return []

    # With zero-points: AWQ style, unsigned + runtime zero-point.
    # Without: GPTQ style, unsigned + symmetric bias.
    # TODO: once fp8_bitblas is merged into "gptq_bitblas" we should be able
    #  to add `scalar_types.float8_e4m3fn` to the no-zero-point list.
    return ([scalar_types.uint4, scalar_types.uint8]
            if has_zp else [scalar_types.uint4b8, scalar_types.uint8b128])
def _check_bitblas_supported(
        quant_type: ScalarType,
        group_size: Optional[int],
        has_zp: bool,
        device_capability: Optional[int] = None
) -> Tuple[bool, Optional[str]]:
    """Check a (type, group size) pair against BitBLAS and explain failures.

    Args:
        quant_type: Weight scalar type to check.
        group_size: Quantization group size; ``None`` is rejected.
        has_zp: Whether runtime zero-points are used.
        device_capability: Integer device capability; read from
            ``current_platform`` when ``None`` (``-1`` if unavailable).

    Returns:
        ``(True, None)`` when supported, otherwise ``(False, reason)``.
    """
    if device_capability is None:
        detected = current_platform.get_device_capability()
        device_capability = (detected.to_int()
                             if detected is not None else -1)

    supported_types = query_bitblas_supported_quant_types(
        has_zp, device_capability)

    if quant_type not in supported_types:
        type_reason = (
            f"BitBLAS does not support weight_bits = {quant_type}. "
            f"Only types = {supported_types} "
            f"are supported (for group_size = {group_size}, "
            f"device_capability = {device_capability}, zp = {has_zp}).")
        return (False, type_reason)

    group_size_ok = (group_size is not None
                     and group_size in BITBLAS_SUPPORTED_GROUP_SIZES)
    if not group_size_ok:
        group_reason = (
            f"BitBLAS does not support group_size = {group_size}. "
            f"Only group_sizes = {BITBLAS_SUPPORTED_GROUP_SIZES} "
            "are supported.")
        return (False, group_reason)

    return True, None
def check_bitblas_supported(quant_type: ScalarType,
                            group_size: int,
                            has_zp: bool = False,
                            device_capability: Optional[int] = None) -> bool:
    """Return whether BitBLAS supports the given quantization setup.

    Thin boolean wrapper around ``_check_bitblas_supported`` that discards
    the explanatory message.
    """
    supported, _reason = _check_bitblas_supported(quant_type, group_size,
                                                  has_zp, device_capability)
    return supported
def verify_bitblas_supported(quant_type: ScalarType,
                             group_size: int,
                             has_zp: bool = False) -> None:
    """Raise if BitBLAS does not support the given quantization setup.

    Raises:
        ValueError: With the reason reported by
            ``_check_bitblas_supported`` when the setup is unsupported.
    """
    supported, reason = _check_bitblas_supported(quant_type, group_size,
                                                 has_zp)
    if supported:
        return
    # A failed check is expected to always come with a message.
    assert reason is not None
    raise ValueError(reason)
def verify_bitblas_supports_shape(output_size_per_partition: int,
                                  input_size_per_partition: int,
                                  input_size: int, group_size: int) -> None:
    """Validate that a partitioned weight shape is usable by BitBLAS.

    Args:
        output_size_per_partition: Output dimension on this partition; must
            be a multiple of ``BITBLAS_MIN_WEIGHT_SIZE_N``.
        input_size_per_partition: Input dimension on this partition; must be
            a multiple of ``BITBLAS_MIN_WEIGHT_SIZE_K``.
        input_size: Full (unpartitioned) input dimension.
        group_size: Quantization group size.

    Raises:
        ValueError: If any of the divisibility requirements is violated.
    """
    # Validate output_size_per_partition
    if output_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_N != 0:
        raise ValueError(f"Weight output_size_per_partition = "
                         f"{output_size_per_partition} is not divisible by "
                         f"min_thread_n = {BITBLAS_MIN_WEIGHT_SIZE_N}. "
                         "Consider reducing tensor_parallel_size or running "
                         "with --quantization gptq.")

    # Validate input_size_per_partition
    if input_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_K != 0:
        raise ValueError(f"Weight input_size_per_partition = "
                         f"{input_size_per_partition} is not divisible "
                         f"by min_thread_k = {BITBLAS_MIN_WEIGHT_SIZE_K}. "
                         "Consider reducing tensor_parallel_size or running "
                         "with --quantization gptq.")

    # Only grouped quantization (group smaller than the full input dim)
    # needs the partition to hold a whole number of groups.
    if (group_size < input_size
            and input_size_per_partition % group_size != 0):
        raise ValueError(
            f"Weight input_size_per_partition = {input_size_per_partition}"
            f" is not divisible by group_size = {group_size}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq.")
def check_bitblas_supports_shape(output_size_per_partition: int,
                                 input_size_per_partition: int,
                                 input_size: int, group_size: int) \
                                    -> Tuple[bool, Optional[str]]:
    """Non-raising variant of ``verify_bitblas_supports_shape``.

    Returns:
        ``(True, None)`` when the shape is accepted, otherwise
        ``(False, message)`` with the text of the ``ValueError`` raised by
        the verifier.
    """
    try:
        verify_bitblas_supports_shape(output_size_per_partition,
                                      input_size_per_partition, input_size,
                                      group_size)
    except ValueError as shape_error:
        return False, str(shape_error)
    else:
        return True, None
def bitblas_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    """Return False only when act-order is combined with row parallelism.

    Every other combination of the two flags yields True.
    """
    return not (act_order and is_row_parallel)
def bitblas_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
                                       is_row_parallel: bool) -> bool:
    """Decide whether scales must be replicated on every rank.

    Replication is needed with act-ordering, or when the quantization is
    channelwise (``group_size == -1``) on a row-parallel layer.
    """
    if act_order:
        return act_order
    return group_size == -1 and is_row_parallel
def bitblas_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    """Create a zero-length, non-trainable int placeholder for ``g_idx``.

    Args:
        device: Device on which the placeholder is allocated.

    Returns:
        A ``torch.nn.Parameter`` wrapping an empty ``torch.int`` tensor with
        ``requires_grad=False``.
    """
    placeholder = torch.empty(0, dtype=torch.int, device=device)
    return torch.nn.Parameter(placeholder, requires_grad=False)
def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor:
    """Create a zero-length, non-trainable int placeholder for zero-points.

    Args:
        device: Device on which the placeholder is allocated.

    Returns:
        A ``torch.nn.Parameter`` wrapping an empty ``torch.int`` tensor with
        ``requires_grad=False``.
    """
    placeholder = torch.empty(0, dtype=torch.int, device=device)
    return torch.nn.Parameter(placeholder, requires_grad=False)
def bitblas_sort_g_idx(
        g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sort ``g_idx`` ascending and return the permutation used.

    Args:
        g_idx: Group-index tensor to sort.

    Returns:
        ``(sorted_g_idx, sort_indices)`` where ``sort_indices`` is the
        argsort of ``g_idx`` cast to ``torch.int``.
    """
    order = g_idx.argsort().to(torch.int)
    sorted_g_idx = g_idx[order]
    return sorted_g_idx, order
def unpack_gptq_qzeros(qzeros, bits, is_gptq_v2=False) -> torch.Tensor:
    """Unpack bit-packed GPTQ zero-points into one int8 value per column.

    Each int32 word of ``qzeros`` holds ``32 // bits`` values, lowest bits
    first; they are expanded along dim 1.

    Args:
        qzeros: 2-D packed zero-point tensor, reinterpreted as int32.
        bits: Bit width of each packed value.
        is_gptq_v2: When False, 1 is added to every unpacked value; when
            True the raw values are returned.

    Returns:
        An int8 tensor of shape
        ``(qzeros.shape[0], qzeros.shape[1] * (32 // bits))``.
    """
    qzeros = qzeros.view(torch.int32)
    elems_per_int32 = 32 // bits
    # The mask must follow the bit width. A fixed 0xF is only right for
    # 4-bit values: it leaks neighbouring bits for narrower widths and
    # truncates wider ones.
    mask = (1 << bits) - 1
    unpacked_zeros = torch.zeros(
        (qzeros.shape[0], qzeros.shape[1] * elems_per_int32),
        dtype=torch.int8,
        device=qzeros.device,
        requires_grad=False,
    )

    for col in range(unpacked_zeros.shape[1]):
        word, slot = divmod(col, elems_per_int32)
        # The right shift is arithmetic on int32; the mask drops any
        # sign-extended high bits.
        unpacked_zeros[:, col] = (qzeros[:, word] >> (bits * slot)) & mask

    # NOTE(review): the result is stored as int8, so 8-bit values above 127
    # wrap on assignment -- confirm whether 8-bit callers rely on that.
    if not is_gptq_v2:
        return unpacked_zeros + 1
    return unpacked_zeros
def unpack_gptq_qweight(qweight, bits):
    """Unpack bit-packed GPTQ weights into one int8 value per column.

    ``qweight`` is reinterpreted as int8; each byte holds ``8 // bits``
    values, lowest bits first, which are expanded along dim 1.

    Args:
        qweight: 2-D packed weight tensor.
        bits: Bit width of each packed value.

    Returns:
        An int8 tensor with ``8 // bits`` times as many columns as the int8
        view of ``qweight``, each entry masked to ``bits`` bits.
    """
    packed = qweight.view(torch.int8)
    per_byte = 8 // bits
    n_rows, n_bytes = packed.shape[0], packed.shape[1]
    expanded = torch.zeros(
        (n_rows, n_bytes * per_byte),
        dtype=torch.int8,
        device=packed.device,
        requires_grad=False,
    )

    for out_col in range(expanded.shape[1]):
        byte_idx, slot = divmod(out_col, per_byte)
        # Shift the wanted field down to bit 0; masking happens once below.
        expanded[:, out_col] = packed[:, byte_idx] >> (bits * slot)

    low_bits_mask = 2**bits - 1
    return torch.bitwise_and(expanded, low_bits_mask)
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
081057de
...
...
@@ -151,6 +151,19 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
group_size
=
group_size
)[
0
]
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
        -> bool:
    """Check whether the MoE Marlin kernel can serve this layer.

    Args:
        layer: Layer exposing ``hidden_size`` and
            ``intermediate_size_per_partition``.
        group_size: Quantization group size.

    Returns:
        True when the layer dimensions and group size meet the kernel's
        divisibility requirements.
    """
    hidden = layer.hidden_size
    intermediate = layer.intermediate_size_per_partition
    # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
    # down: (n, k) = (hidden_size, intermediate_size_per_partition)
    # moe marlin requires n % 128 == 0 and k % 64 == 0
    if hidden % 128 != 0:
        return False
    if intermediate % max(64, group_size) != 0:
        return False
    return group_size in [-1, 32, 64, 128]
def
marlin_make_workspace
(
output_size_per_partition
:
int
,
device
:
torch
.
device
)
->
torch
.
Tensor
:
max_workspace_size
=
(
output_size_per_partition
//
...
...
@@ -319,6 +332,7 @@ def apply_gptq_marlin_linear(
wtype
:
ScalarType
,
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
has_zp
:
bool
,
is_k_full
:
bool
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
use_fp32_reduce
:
bool
=
USE_FP32_REDUCE_DEFAULT
)
->
torch
.
Tensor
:
...
...
@@ -343,8 +357,8 @@ def apply_gptq_marlin_linear(
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
is_k_full
,
has_zp
=
False
,
use_atomic_add
=
use_atomic_add
,
has_zp
=
has_zp
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
False
)
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
from
vllm.config
import
CompilationLevel
,
get_current_vllm_config
from
vllm.platforms
import
current_platform
...
...
@@ -17,6 +18,7 @@ TORCH_DEVICE_IDENTITY = None
# The condition is determined once as the operations
# are time consuming.
USE_ROWWISE_TORCH_SCALED_MM
=
(
current_platform
.
is_rocm
()
and
torch
.
__version__
[
0
:
3
]
>=
"2.7"
and
current_platform
.
has_device_capability
(
94
))
...
...
@@ -131,6 +133,160 @@ def maybe_create_device_identity():
TORCH_DEVICE_IDENTITY
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
                           out_dtype: torch.dtype, scale_a: torch.Tensor,
                           scale_b: torch.Tensor, bias: torch.Tensor,
                           output_shape: List, **kwargs) -> torch.Tensor:
    """Run the fused CUTLASS scaled matmul and reshape the result.

    Args:
        qinput: Quantized activations.
        weight: Quantized weight.
        out_dtype: Output dtype requested from the kernel.
        scale_a: Activation scale.
        scale_b: Weight scale.
        bias: Bias forwarded to the kernel.
        output_shape: Shape the result is viewed as.
        **kwargs: Accepted and ignored, so every scaled-mm backend can be
            called with the same keyword set.

    Returns:
        The kernel output viewed as ``output_shape``.
    """
    # Fused GEMM_DQ
    fused = ops.cutlass_scaled_mm(qinput,
                                  weight,
                                  out_dtype=out_dtype,
                                  scale_a=scale_a,
                                  scale_b=scale_b,
                                  bias=bias)
    reshaped = fused.view(*output_shape)
    return reshaped
def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                   weight: torch.Tensor,
                                   out_dtype: torch.dtype,
                                   scale_a: torch.Tensor,
                                   scale_b: torch.Tensor,
                                   bias: Optional[torch.Tensor],
                                   input_2d: torch.Tensor,
                                   output_shape: List) -> torch.Tensor:
    """Per-tensor scaled matmul for ROCm with an optional skinny-GEMM path.

    Args:
        qinput: Quantized activations.
        weight: Quantized weight.
        out_dtype: Output dtype.
        scale_a: Activation scale.
        scale_b: Weight scale.
        bias: Optional bias; only the ``torch._scaled_mm`` path applies it.
        input_2d: Unpadded 2-D input; its row count is used to trim the
            output.
        output_shape: Shape the result is viewed as.

    Returns:
        The matmul result narrowed to ``input_2d.shape[0]`` rows and viewed
        as ``output_shape``.
    """
    from vllm.platforms.rocm import on_mi250_mi300

    # ``ops.wvSplitKQ`` is called without a bias argument, so it may only be
    # used when there is no bias to apply. A biased input falls through to
    # ``torch._scaled_mm`` instead of silently losing its bias.
    use_skinny_gemm = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300()
                       and qinput.shape[0] == 1
                       and qinput.shape[1] % 16 == 0 and bias is None)
    if use_skinny_gemm:
        output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a,
                               scale_b, current_platform.get_cu_count())
    else:
        output = torch._scaled_mm(qinput,
                                  weight,
                                  out_dtype=out_dtype,
                                  scale_a=scale_a,
                                  scale_b=scale_b,
                                  bias=bias)

    # Trim any padded rows before reshaping.
    return torch.narrow(output, 0, 0,
                        input_2d.shape[0]).view(*output_shape)
def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                    weight: torch.Tensor,
                                    out_dtype: torch.dtype,
                                    scale_a: torch.Tensor,
                                    scale_b: torch.Tensor, bias: torch.Tensor,
                                    input_2d: torch.Tensor,
                                    output_shape: List) -> torch.Tensor:
    """Per-tensor scaled matmul through ``torch._scaled_mm``.

    Args:
        qinput: Quantized activations.
        weight: Quantized weight.
        out_dtype: Output dtype.
        scale_a: Activation scale.
        scale_b: Weight scale.
        bias: Bias forwarded to ``torch._scaled_mm``.
        input_2d: Unpadded 2-D input; its row count is used to trim the
            output.
        output_shape: Shape the result is viewed as.

    Returns:
        The matmul result narrowed to ``input_2d.shape[0]`` rows and viewed
        as ``output_shape``.
    """
    result = torch._scaled_mm(qinput,
                              weight,
                              out_dtype=out_dtype,
                              scale_a=scale_a,
                              scale_b=scale_b,
                              bias=bias)
    # torch < 2.5 returns a 2-tuple from _scaled_mm, torch >= 2.5 returns a
    # single tensor; normalise to the tensor.
    if type(result) is tuple and len(result) == 2:
        result = result[0]

    num_rows = input_2d.shape[0]
    trimmed = result.narrow(0, 0, num_rows)
    return trimmed.view(*output_shape)
def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                   weight: torch.Tensor,
                                   out_dtype: torch.dtype,
                                   scale_a: torch.Tensor,
                                   scale_b: torch.Tensor, bias: torch.Tensor,
                                   input_2d: torch.Tensor,
                                   output_shape: List) -> torch.Tensor:
    """Row-wise (per-token) scaled matmul through ``torch._scaled_mm``.

    Callers should check ``USE_ROWWISE_TORCH_SCALED_MM`` before using this
    function. So far it has only been validated on the ROCm platform: fp8
    rowwise scaling in ``torch._scaled_mm`` was introduced in
    https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt and
    ROCm 6.3, which only exists in torch 2.7 and above. For CUDA, validate
    that ``torch._scaled_mm`` supports rowwise scaled GEMM before using it.

    Args:
        qinput: Quantized activations.
        weight: Quantized weight.
        out_dtype: Output dtype.
        scale_a: Activation scales.
        scale_b: Weight scales; passed transposed to the kernel.
        bias: Bias forwarded to ``torch._scaled_mm``.
        input_2d: Unpadded 2-D input; its row count is used to trim the
            output.
        output_shape: Shape the result is viewed as.

    Returns:
        The matmul result narrowed to ``input_2d.shape[0]`` rows and viewed
        as ``output_shape``.
    """
    # Fused GEMM_DQ Rowwise GEMM
    rowwise = torch._scaled_mm(qinput,
                               weight,
                               out_dtype=out_dtype,
                               scale_a=scale_a,
                               scale_b=scale_b.t(),
                               bias=bias)

    num_rows = input_2d.shape[0]
    return torch.narrow(rowwise, 0, 0, num_rows).view(*output_shape)
def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                     weight: torch.Tensor,
                                     out_dtype: torch.dtype,
                                     scale_a: torch.Tensor,
                                     scale_b: torch.Tensor, bias: torch.Tensor,
                                     input_2d: torch.Tensor,
                                     output_shape: List,
                                     **kwargs) -> torch.Tensor:
    """Channelwise fallback: unscaled GEMM followed by explicit dequant.

    Symmetric quantized GEMM by definition computes
    ``C = (s_x * X) (s_w * W) + bias``, which a quantized kernel rewrites as
    ``C = s_w * s_x * (X * W) + bias``. ``torch._scaled_mm`` does not
    support ``s_w`` being a vector, so the product is computed with
    identity scales and the real scales are applied afterwards.

    Args:
        qinput: Quantized activations.
        weight: Quantized weight.
        out_dtype: Dtype of the returned tensor.
        scale_a: Activation scales.
        scale_b: Weight scales; applied transposed.
        bias: Optional bias added after dequantization.
        input_2d: Unpadded 2-D input; its row count is used to trim padded
            rows.
        output_shape: Shape the result is viewed as.
        **kwargs: Accepted and ignored.

    Returns:
        The dequantized result cast to ``out_dtype`` and viewed as
        ``output_shape``.
    """
    # GEMM: C = (X * W), produced in fp32 to allow subsequent ops to happen
    # in-place.
    acc = torch._scaled_mm(qinput,
                           weight,
                           scale_a=TORCH_DEVICE_IDENTITY,
                           scale_b=TORCH_DEVICE_IDENTITY,
                           out_dtype=torch.float32)
    # torch < 2.5 returns a 2-tuple from _scaled_mm, torch >= 2.5 returns a
    # single tensor; normalise to the tensor.
    if type(acc) is tuple and len(acc) == 2:
        acc = acc[0]

    # Unpad (undo num_token_padding) on both the product and the scales.
    num_rows = input_2d.shape[0]
    acc = torch.narrow(acc, 0, 0, num_rows)
    x_scale = torch.narrow(scale_a, 0, 0, num_rows)

    # DQ: C = sw * sx * (X * W) + bias
    acc = acc * x_scale * scale_b.t()
    if bias is not None:
        acc = acc + bias
    return acc.to(out_dtype).view(*output_shape)
def dispatch_w8a8_scaled_mm(
        cutlass_fp8_supported: bool, per_tensor_weights: bool,
        per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool]
) -> Callable[..., torch.Tensor]:
    """Pick the w8a8 scaled-matmul implementation for a configuration.

    Args:
        cutlass_fp8_supported: When True the CUTLASS backend is always used.
        per_tensor_weights: Whether the weight scale is a single value.
        per_tensor_activations: Whether the activation scale is a single
            value.
        use_per_token_if_dynamic: Whether per-token scaling is requested.

    Returns:
        One of the ``*_w8a8_scaled_mm`` functions defined in this module.
    """
    if cutlass_fp8_supported:
        return cutlass_w8a8_scaled_mm

    if per_tensor_weights and per_tensor_activations:
        return (rocm_per_tensor_w8a8_scaled_mm
                if current_platform.is_rocm() else
                torch_per_tensor_w8a8_scaled_mm)

    # torch.scaled_mm supports per tensor weights + activations only, so
    # fall back to the naive path for per channel or per token, unless the
    # rowwise torch path is available and requested.
    wants_rowwise = (use_per_token_if_dynamic
                     and not (per_tensor_weights or per_tensor_activations)
                     and USE_ROWWISE_TORCH_SCALED_MM)
    if wants_rowwise:
        return torch_per_token_w8a8_scaled_mm
    return torch_channelwise_w8a8_scaled_mm
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397
class
Fp8LinearOp
:
...
...
@@ -156,7 +312,8 @@ class Fp8LinearOp:
if
pad_output
is
None
:
config
=
get_current_vllm_config
().
compilation_config
pad_output
=
config
.
level
<
CompilationLevel
.
PIECEWISE
self
.
output_padding
=
17
if
pad_output
else
None
self
.
output_padding
=
17
if
(
pad_output
and
not
current_platform
.
is_rocm
())
else
None
def
apply
(
self
,
...
...
@@ -195,18 +352,6 @@ class Fp8LinearOp:
input_scale
,
scale_ub
=
input_scale_ub
,
use_per_token_if_dynamic
=
use_per_token_if_dynamic
)
# Fused GEMM_DQ
output
=
ops
.
cutlass_scaled_mm
(
qinput
,
weight
,
out_dtype
=
out_dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
return
output
.
view
(
*
output_shape
)
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
else
:
if
input
.
dtype
!=
current_platform
.
fp8_dtype
():
# Maybe apply padding to output, see comment in __init__
...
...
@@ -218,84 +363,21 @@ class Fp8LinearOp:
else
:
qinput
,
x_scale
=
input_2d
,
input_scale
per_tensor_weights
=
(
weight_scale
.
numel
()
==
1
)
per_tensor_activations
=
(
x_scale
.
numel
()
==
1
)
if
per_tensor_weights
and
per_tensor_activations
:
# Fused GEMM_DQ
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
out_dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if
type
(
output
)
is
tuple
and
len
(
output
)
==
2
:
output
=
output
[
0
]
return
torch
.
narrow
(
output
,
0
,
0
,
input_2d
.
shape
[
0
]).
view
(
*
output_shape
)
elif
(
use_per_token_if_dynamic
and
not
per_tensor_weights
and
not
per_tensor_activations
and
USE_ROWWISE_TORCH_SCALED_MM
):
# For now validated on ROCm platform
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
# and ROCm 6.3, which only exists in torch 2.7 and above.
# For CUDA platform please validate if the
# torch._scaled_mm support rowwise scaled GEMM
# Fused GEMM_DQ Rowwise GEMM
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
out_dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
.
t
(),
bias
=
bias
)
output
=
torch
.
narrow
(
output
,
0
,
0
,
input_2d
.
shape
[
0
])
output
=
output
.
view
(
*
output_shape
)
return
output
else
:
# Fallback for channelwise case, where we use unfused DQ
# due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
scale_a
=
TORCH_DEVICE_IDENTITY
,
scale_b
=
TORCH_DEVICE_IDENTITY
,
out_dtype
=
torch
.
float32
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if
type
(
output
)
is
tuple
and
len
(
output
)
==
2
:
output
=
output
[
0
]
# Unpad (undo num_token_padding)
output
=
torch
.
narrow
(
output
,
0
,
0
,
input_2d
.
shape
[
0
])
x_scale
=
torch
.
narrow
(
x_scale
,
0
,
0
,
input_2d
.
shape
[
0
])
# DQ
# C = sw * sx * (X * W) + bias
output
=
output
*
x_scale
*
weight_scale
.
t
()
if
bias
is
not
None
:
output
=
output
+
bias
return
output
.
to
(
dtype
=
input
.
dtype
).
view
(
*
output_shape
)
per_tensor_weights
=
(
weight_scale
.
numel
()
==
1
)
per_tensor_activations
=
(
x_scale
.
numel
()
==
1
)
w8a8_scaled_mm_func
=
dispatch_w8a8_scaled_mm
(
self
.
cutlass_fp8_supported
,
per_tensor_weights
,
per_tensor_activations
,
use_per_token_if_dynamic
)
return
w8a8_scaled_mm_func
(
qinput
=
qinput
,
weight
=
weight
,
out_dtype
=
out_dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
bias
=
bias
,
input_2d
=
input_2d
,
output_shape
=
output_shape
)
def
normalize_e4m3fn_to_e4m3fnuz
(
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
081057de
...
...
@@ -46,20 +46,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return
x
.
flatten
(
-
2
)
def
_apply_rotary_emb
(
def
_apply_rotary_emb
_torch
(
x
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
,
is_neox_style
:
bool
,
)
->
torch
.
Tensor
:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
cos
=
cos
.
unsqueeze
(
-
2
).
to
(
x
.
dtype
)
sin
=
sin
.
unsqueeze
(
-
2
).
to
(
x
.
dtype
)
if
is_neox_style
:
...
...
@@ -75,6 +67,24 @@ def _apply_rotary_emb(
return
torch
.
stack
((
o1
,
o2
),
dim
=-
1
).
flatten
(
-
2
)
def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                      is_neox_style: bool) -> torch.Tensor:
    """Apply rotary embedding, using the flash-attn kernel when available.

    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    if not current_platform.is_cuda_alike():
        return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)

    from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

    # The kernel is called on a tensor with an extra leading dimension and
    # with the negated style flag; the extra dimension is dropped again.
    batched = x.unsqueeze(0)
    rotated = apply_rotary_emb(batched, cos, sin, not is_neox_style)
    return rotated.squeeze(0)
@
CustomOp
.
register
(
"rotary_embedding"
)
class
RotaryEmbedding
(
CustomOp
):
"""Original rotary positional embedding."""
...
...
@@ -141,14 +151,16 @@ class RotaryEmbedding(CustomOp):
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
_apply_rotary_emb
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query_rot
=
_apply_rotary_emb_torch
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
_apply_rotary_emb
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key_rot
=
_apply_rotary_emb_torch
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
...
...
@@ -988,8 +1000,9 @@ class MRotaryEmbedding(RotaryEmbedding):
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
@
static
method
@
class
method
def
get_input_positions
(
cls
,
input_tokens
:
List
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Optional
[
Union
[
List
[
List
[
int
]],
torch
.
Tensor
]],
...
...
@@ -997,6 +1010,8 @@ class MRotaryEmbedding(RotaryEmbedding):
second_per_grid_ts
:
Optional
[
List
[
float
]],
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
audio_feature_lengths
:
Optional
[
torch
.
Tensor
]
=
None
,
use_audio_in_video
:
bool
=
False
,
)
->
Tuple
[
List
[
List
[
int
]],
int
]:
"""Get mrope input positions and delta value."""
...
...
@@ -1006,7 +1021,7 @@ class MRotaryEmbedding(RotaryEmbedding):
second_per_grid_ts
llm_positions
,
mrope_position_delta
=
\
MRotaryEmbedding
.
get_input_positions_tensor
(
cls
.
get_input_positions_tensor
(
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
image_grid_thw
=
image_grid_thw
,
...
...
@@ -1014,12 +1029,52 @@ class MRotaryEmbedding(RotaryEmbedding):
second_per_grid_ts
=
second_per_grid_ts
,
context_len
=
context_len
,
seq_len
=
seq_len
,
audio_feature_lengths
=
audio_feature_lengths
,
use_audio_in_video
=
use_audio_in_video
,
)
return
llm_positions
.
tolist
(),
mrope_position_delta
@
static
method
@classmethod
def get_input_positions_tensor(
    cls,
    input_tokens: List[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[List[List[int]], torch.Tensor],
    video_grid_thw: Union[List[List[int]], torch.Tensor],
    second_per_grid_ts: List[float],
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
) -> Tuple[torch.Tensor, int]:
    """Compute mrope positions by delegating to the matching builder.

    When ``thinker_uses_mrope(hf_config)`` is true the call is forwarded
    to ``_omni_get_input_positions_tensor`` (which additionally receives
    ``audio_feature_lengths`` and ``use_audio_in_video``); otherwise it
    is forwarded to ``_vl_get_input_positions_tensor`` and the two audio
    arguments are not passed on.

    Returns:
        Whatever the selected builder returns (positions tensor and the
        mrope position delta).
    """
    # Imported lazily, as in the original implementation.
    from vllm.transformers_utils.config import thinker_uses_mrope

    # Arguments understood by both builders.
    shared_kwargs = dict(
        input_tokens=input_tokens,
        hf_config=hf_config,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        second_per_grid_ts=second_per_grid_ts,
        context_len=context_len,
        seq_len=seq_len,
    )

    if not thinker_uses_mrope(hf_config):
        return cls._vl_get_input_positions_tensor(**shared_kwargs)

    return cls._omni_get_input_positions_tensor(
        **shared_kwargs,
        audio_feature_lengths=audio_feature_lengths,
        use_audio_in_video=use_audio_in_video,
    )
@
classmethod
def
_vl_get_input_positions_tensor
(
cls
,
input_tokens
:
List
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Union
[
List
[
List
[
int
]],
torch
.
Tensor
],
...
...
@@ -1037,11 +1092,6 @@ class MRotaryEmbedding(RotaryEmbedding):
tokens_per_second
=
getattr
(
hf_config
.
vision_config
,
"tokens_per_second"
,
1.0
)
if
isinstance
(
image_grid_thw
,
torch
.
Tensor
):
image_grid_thw
=
image_grid_thw
.
tolist
()
if
isinstance
(
video_grid_thw
,
torch
.
Tensor
):
video_grid_thw
=
video_grid_thw
.
tolist
()
input_tokens_tensor
=
torch
.
tensor
(
input_tokens
)
vision_start_indices
=
torch
.
argwhere
(
input_tokens_tensor
==
vision_start_token_id
).
squeeze
(
1
)
...
...
@@ -1121,6 +1171,224 @@ class MRotaryEmbedding(RotaryEmbedding):
return
llm_positions
,
mrope_position_delta
@
classmethod
def
_omni_get_input_positions_tensor
(
cls
,
input_tokens
:
List
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Union
[
List
[
List
[
int
]],
torch
.
Tensor
],
video_grid_thw
:
Union
[
List
[
List
[
int
]],
torch
.
Tensor
],
second_per_grid_ts
:
Optional
[
List
[
float
]]
=
None
,
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
audio_feature_lengths
:
Optional
[
torch
.
Tensor
]
=
None
,
use_audio_in_video
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
int
]:
"""Get mrope input positions and delta value (Qwen2.5-Omni version).
Differences from MRotaryEmbedding:
1. Add audio support (and related `audio_feature_lengths`).
2. Add `use_audio_in_video` option to read audio from video inputs.
In this case, audio and vision position ids will be split into
chunks and interleaved.
Example:
(V_i are vision position ids, A_i are audio position ids)
|V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
|vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
"""
# TODO(fyabc): refactor and share more code with
# _vl_get_input_positions_tensor.
thinker_config
=
hf_config
.
thinker_config
audio_token_id
=
thinker_config
.
audio_token_index
image_token_id
=
thinker_config
.
image_token_index
video_token_id
=
thinker_config
.
video_token_index
audio_start_token_id
=
thinker_config
.
audio_start_token_id
audio_end_token_id
=
thinker_config
.
audio_end_token_id
vision_start_token_id
=
thinker_config
.
vision_start_token_id
vision_end_token_id
=
thinker_config
.
vision_end_token_id
seconds_per_chunk
=
thinker_config
.
seconds_per_chunk
spatial_merge_size
=
thinker_config
.
vision_config
.
spatial_merge_size
tokens_per_second
=
getattr
(
thinker_config
.
vision_config
,
"tokens_per_second"
,
25
)
if
isinstance
(
image_grid_thw
,
list
):
image_grid_thw
=
torch
.
tensor
(
image_grid_thw
)
if
isinstance
(
video_grid_thw
,
list
):
video_grid_thw
=
torch
.
tensor
(
video_grid_thw
)
src_item
=
input_tokens
audio_seqlens
=
audio_feature_lengths
if
not
second_per_grid_ts
:
second_per_grid_ts
=
[
1
]
*
video_grid_thw
.
shape
[
0
]
audio_idx
=
0
video_idx
=
0
image_idx
=
0
new_src_item
:
list
[
int
]
=
[]
llm_pos_ids_list
:
list
[
torch
.
Tensor
]
=
[]
idx
=
0
while
idx
<
len
(
src_item
):
new_src_item_len
=
len
(
new_src_item
)
start_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
if
src_item
[
idx
]
not
in
[
audio_token_id
,
video_token_id
,
image_token_id
]:
if
use_audio_in_video
and
idx
>
0
:
if
src_item
[
idx
]
==
vision_end_token_id
and
\
src_item
[
idx
-
1
]
==
audio_end_token_id
:
# processing the <|audio_eos|> before <|vision_eos|>
start_idx
-=
1
elif
src_item
[
idx
]
==
audio_start_token_id
and
\
src_item
[
idx
-
1
]
==
vision_start_token_id
:
# processing the <|audio_bos|> after <|vision_eos|>
start_idx
-=
1
new_src_item
.
append
(
src_item
[
idx
])
llm_pos_ids
=
torch
.
tensor
([
start_idx
],
dtype
=
torch
.
long
).
expand
(
3
,
-
1
)
llm_pos_ids_list
.
append
(
llm_pos_ids
)
elif
src_item
[
idx
]
==
audio_token_id
:
assert
audio_seqlens
is
not
None
audio_seqlen
=
audio_seqlens
[
audio_idx
]
place_num
=
(((
audio_seqlen
-
1
)
//
2
+
1
-
2
)
//
2
+
1
)
new_src_item
.
extend
([
audio_token_id
]
*
place_num
)
llm_pos_ids
=
torch
.
arange
(
place_num
).
expand
(
3
,
-
1
)
+
start_idx
llm_pos_ids_list
.
append
(
llm_pos_ids
)
audio_idx
+=
1
elif
src_item
[
idx
]
==
image_token_id
:
grid_t
=
image_grid_thw
[
image_idx
][
0
]
grid_hs
=
image_grid_thw
[:,
1
]
grid_ws
=
image_grid_thw
[:,
2
]
t_index
=
(
torch
.
arange
(
grid_t
)
*
1
*
tokens_per_second
).
long
()
llm_pos_ids
=
cls
.
_get_llm_pos_ids_for_vision
(
start_idx
,
image_idx
,
spatial_merge_size
,
t_index
,
grid_hs
,
grid_ws
)
llm_pos_ids_list
.
append
(
llm_pos_ids
)
vision_seqlen
=
image_grid_thw
[
image_idx
].
prod
()
//
(
spatial_merge_size
**
2
)
new_src_item
.
extend
([
image_token_id
]
*
vision_seqlen
)
image_idx
+=
1
elif
src_item
[
idx
]
==
video_token_id
and
not
use_audio_in_video
:
grid_t
=
video_grid_thw
[
video_idx
][
0
]
grid_hs
=
video_grid_thw
[:,
1
]
grid_ws
=
video_grid_thw
[:,
2
]
t_index
=
(
torch
.
arange
(
grid_t
)
*
second_per_grid_ts
[
video_idx
]
*
tokens_per_second
).
long
()
llm_pos_ids
=
cls
.
_get_llm_pos_ids_for_vision
(
start_idx
,
video_idx
,
spatial_merge_size
,
t_index
,
grid_hs
,
grid_ws
)
llm_pos_ids_list
.
append
(
llm_pos_ids
)
vision_seqlen
=
video_grid_thw
[
video_idx
].
prod
()
//
(
spatial_merge_size
**
2
)
new_src_item
.
extend
([
video_token_id
]
*
vision_seqlen
)
video_idx
+=
1
else
:
# read audio from video
assert
audio_seqlens
is
not
None
audio_seqlen
=
audio_seqlens
[
audio_idx
]
vision_seqlen
=
video_grid_thw
[
video_idx
].
prod
()
//
(
spatial_merge_size
**
2
)
grid_t
=
video_grid_thw
[
video_idx
][
0
]
grid_h
=
video_grid_thw
[
video_idx
][
1
]
grid_w
=
video_grid_thw
[
video_idx
][
2
]
grid_hs
=
video_grid_thw
[:,
1
]
grid_ws
=
video_grid_thw
[:,
2
]
t_ntoken_per_chunk
=
int
(
tokens_per_second
*
seconds_per_chunk
)
t_index
=
(
torch
.
arange
(
grid_t
)
*
second_per_grid_ts
[
video_idx
]
*
tokens_per_second
).
long
()
t_index_split_chunk
=
cls
.
_split_list_into_ranges
(
t_index
,
t_ntoken_per_chunk
)
place_num
=
(((
audio_seqlen
-
1
)
//
2
+
1
-
2
)
//
2
+
1
)
+
2
pure_audio_len
=
place_num
-
2
added_audio_len
=
0
audio_llm_pos_ids_list
:
List
[
torch
.
Tensor
]
=
[]
for
t_chunk
in
t_index_split_chunk
:
vision_ntoken_per_chunk
=
len
(
t_chunk
)
*
grid_h
*
grid_w
//
(
spatial_merge_size
**
2
)
new_src_item
.
extend
([
video_token_id
]
*
vision_ntoken_per_chunk
)
vision_llm_pos_ids_list
=
cls
.
_get_llm_pos_ids_for_vision
(
start_idx
,
video_idx
,
spatial_merge_size
,
t_chunk
,
grid_hs
,
grid_ws
).
split
(
1
,
dim
=
1
)
llm_pos_ids_list
.
extend
(
vision_llm_pos_ids_list
)
new_src_item
.
extend
(
min
(
t_ntoken_per_chunk
,
pure_audio_len
-
added_audio_len
)
*
[
audio_token_id
])
audio_start_idx
=
start_idx
if
len
(
audio_llm_pos_ids_list
)
==
0
else
audio_llm_pos_ids_list
[
-
1
][
0
].
item
()
+
1
if
min
(
t_ntoken_per_chunk
,
pure_audio_len
-
added_audio_len
)
>
0
:
audio_llm_pos_ids_list
=
(
torch
.
arange
(
min
(
t_ntoken_per_chunk
,
pure_audio_len
-
added_audio_len
)).
expand
(
3
,
-
1
)
+
audio_start_idx
).
split
(
1
,
dim
=
1
)
else
:
audio_llm_pos_ids_list
=
[]
added_audio_len
+=
min
(
t_ntoken_per_chunk
,
pure_audio_len
-
added_audio_len
)
llm_pos_ids_list
.
extend
(
audio_llm_pos_ids_list
)
if
added_audio_len
<
pure_audio_len
:
new_src_item
.
extend
(
(
pure_audio_len
-
added_audio_len
)
*
[
audio_token_id
])
audio_llm_pos_ids_list
=
(
torch
.
arange
(
pure_audio_len
-
added_audio_len
).
expand
(
3
,
-
1
)
+
llm_pos_ids_list
[
-
1
].
max
()
+
1
).
split
(
1
,
dim
=
1
)
llm_pos_ids_list
.
extend
(
audio_llm_pos_ids_list
)
audio_idx
+=
1
video_idx
+=
1
# move to the next token
idx
+=
len
(
new_src_item
)
-
new_src_item_len
llm_positions
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
)
mrope_position_delta
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
).
max
()
+
1
-
len
(
src_item
)
llm_positions
=
llm_positions
[:,
context_len
:
seq_len
]
return
llm_positions
,
mrope_position_delta
@
staticmethod
def
_get_llm_pos_ids_for_vision
(
start_idx
:
int
,
vision_idx
:
int
,
spatial_merge_size
:
int
,
t_index
:
List
[
int
],
grid_hs
:
torch
.
Tensor
,
grid_ws
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
llm_pos_ids_list
=
[]
llm_grid_h
=
grid_hs
[
vision_idx
]
//
spatial_merge_size
llm_grid_w
=
grid_ws
[
vision_idx
]
//
spatial_merge_size
h_index
=
(
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
len
(
t_index
),
-
1
,
llm_grid_w
).
flatten
())
w_index
=
(
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
len
(
t_index
),
llm_grid_h
,
-
1
).
flatten
())
t_index_tensor
=
torch
.
Tensor
(
t_index
).
to
(
llm_grid_h
.
device
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
).
long
().
flatten
()
_llm_pos_ids
=
torch
.
stack
([
t_index_tensor
,
h_index
,
w_index
])
llm_pos_ids_list
.
append
(
_llm_pos_ids
+
start_idx
)
llm_pos_ids
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
)
return
llm_pos_ids
@
staticmethod
def
_split_list_into_ranges
(
lst
:
torch
.
Tensor
,
interval
:
int
)
->
List
[
List
[
int
]]:
ranges
:
List
[
List
[
int
]]
=
[[]
for
_
in
range
((
max
(
lst
)
//
interval
)
+
1
)]
for
num
in
lst
:
index
=
num
//
interval
ranges
[
index
].
append
(
num
)
return
ranges
@
staticmethod
def
get_next_input_positions
(
mrope_position_delta
:
int
,
...
...
@@ -1144,6 +1412,58 @@ class MRotaryEmbedding(RotaryEmbedding):
mrope_position_delta
+
seq_len
,
).
expand
(
3
,
-
1
)
@
classmethod
def
omni_get_updates_use_audio_in_video
(
cls
,
thinker_config
:
PretrainedConfig
,
audio_len
:
int
,
video_grid_thw
:
Union
[
List
[
int
],
torch
.
Tensor
],
video_second_per_grid_t
:
float
,
)
->
List
[
int
]:
"""Get video prompt updates when `use_audio_in_video` is True.
In this case, audio and vision update ids will be split into
chunks and interleaved (details in `_omni_get_input_positions_tensor`).
<|video_bos|><|VIDEO|><|video_eos|> =>
<|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>
"""
audio_token_id
=
thinker_config
.
audio_token_index
video_token_id
=
thinker_config
.
video_token_index
audio_start_token_id
=
thinker_config
.
audio_start_token_id
audio_end_token_id
=
thinker_config
.
audio_end_token_id
seconds_per_chunk
=
thinker_config
.
seconds_per_chunk
spatial_merge_size
=
thinker_config
.
vision_config
.
spatial_merge_size
tokens_per_second
=
getattr
(
thinker_config
.
vision_config
,
"tokens_per_second"
,
25
)
grid_t
=
video_grid_thw
[
0
]
grid_h
=
video_grid_thw
[
1
]
grid_w
=
video_grid_thw
[
2
]
t_ntoken_per_chunk
=
int
(
tokens_per_second
*
seconds_per_chunk
)
t_index
=
(
torch
.
arange
(
grid_t
)
*
video_second_per_grid_t
*
tokens_per_second
).
long
()
t_index_split_chunk
=
cls
.
_split_list_into_ranges
(
t_index
,
t_ntoken_per_chunk
)
updates
=
[
audio_start_token_id
]
added_audio_len
=
0
for
t_chunk
in
t_index_split_chunk
:
vision_ntoken_per_chunk
=
len
(
t_chunk
)
*
grid_h
*
grid_w
//
(
spatial_merge_size
**
2
)
updates
.
extend
([
video_token_id
]
*
vision_ntoken_per_chunk
)
audio_chunk_size
=
min
(
t_ntoken_per_chunk
,
audio_len
-
added_audio_len
)
updates
.
extend
(
audio_chunk_size
*
[
audio_token_id
])
added_audio_len
+=
audio_chunk_size
if
added_audio_len
<
audio_len
:
updates
.
extend
((
audio_len
-
added_audio_len
)
*
[
audio_token_id
])
updates
.
extend
([
audio_end_token_id
])
return
updates
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
...
...
vllm/model_executor/layers/utils.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
"""Utility methods for model layers."""
from
typing
import
Tuple
from
typing
import
Callable
,
Optional
,
Tuple
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
from
vllm.platforms
import
current_platform
def
get_token_bin_counts_and_mask
(
tokens
:
torch
.
Tensor
,
...
...
@@ -47,12 +51,49 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_tokens_tensor
,
vocab_size
,
num_seqs
)
repetition_penalties
=
repetition_penalties
.
unsqueeze
(
dim
=
1
).
repeat
(
1
,
vocab_size
)
logits
[
logits
>
0
]
/=
torch
.
where
(
prompt_mask
|
output_mask
,
repetition_penalties
,
1.0
)[
logits
>
0
]
logits
[
logits
<=
0
]
*=
torch
.
where
(
prompt_mask
|
output_mask
,
repetition_penalties
,
1.0
)[
logits
<=
0
]
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
penalties
=
torch
.
where
(
prompt_mask
|
output_mask
,
repetition_penalties
,
1.0
)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
scaling
=
torch
.
where
(
logits
>
0
,
1.0
/
penalties
,
penalties
)
logits
*=
scaling
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits
-=
frequency_penalties
.
unsqueeze
(
dim
=
1
)
*
output_bin_counts
logits
-=
presence_penalties
.
unsqueeze
(
dim
=
1
)
*
output_mask
return
logits
def rocm_unquantized_gemm(x: torch.Tensor,
                          weight: torch.Tensor,
                          bias: Optional[torch.Tensor] = None):
    """Unquantized linear layer with optional ROCm skinny-GEMM kernels.

    Falls back to ``torch.nn.functional.linear`` unless all of these
    hold: ``envs.VLLM_ROCM_USE_SKINNY_GEMM`` is set, ``on_mi250_mi300()``
    is true, ``x`` is float16 or bfloat16, ``weight.shape[1]`` is a
    multiple of 8, and ``bias`` is None. When they hold, ``x`` is
    flattened to 2-D and dispatched on its row count and the weight
    shape to ``ops.wvSplitK`` or ``ops.LLMM1``; any other shape also
    uses ``torch.nn.functional.linear``.

    Returns:
        The result tensor; for the custom kernels it is viewed as
        ``(*x.shape[:-1], weight.shape[0])``.
    """
    from vllm.platforms.rocm import on_mi250_mi300

    in_features = weight.shape[1]
    skinny_eligible = (envs.VLLM_ROCM_USE_SKINNY_GEMM
                       and on_mi250_mi300()
                       and x.dtype in (torch.float16, torch.bfloat16)
                       and in_features % 8 == 0
                       and bias is None)
    if not skinny_eligible:
        return torch.nn.functional.linear(x, weight, bias)

    flat_x = x.view(-1, x.size(-1))
    num_rows = flat_x.shape[0]
    out_features = weight.shape[0]
    cu_count = current_platform.get_cu_count()

    if out_features > 8 and 0 < num_rows < 4:
        result = ops.wvSplitK(weight, flat_x, cu_count)
    elif out_features % 4 == 0 and num_rows == 1 and in_features <= 8192:
        result = ops.LLMM1(weight, flat_x, 4)
    else:
        # No custom kernel applies to this shape.
        return torch.nn.functional.linear(x, weight, bias)

    return result.view(*x.shape[:-1], out_features)
def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
    """Return the unquantized GEMM callable for the current platform.

    ``rocm_unquantized_gemm`` when ``current_platform.is_rocm()`` is
    true, otherwise ``torch.nn.functional.linear``.
    """
    return (rocm_unquantized_gemm
            if current_platform.is_rocm() else torch.nn.functional.linear)
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
081057de
...
...
@@ -12,6 +12,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
,
method_has_implemented_embedding
)
from
vllm.model_executor.layers.utils
import
dispatch_unquantized_gemm
from
vllm.model_executor.parameter
import
BasevLLMParameter
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
...
...
@@ -40,7 +41,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
F
.
linear
(
x
,
layer
.
weight
,
bias
)
return
dispatch_unquantized_gemm
()
(
x
,
layer
.
weight
,
bias
)
def
embedding
(
self
,
layer
:
torch
.
nn
.
Module
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
vllm/model_executor/model_loader/loader.py
View file @
081057de
...
...
@@ -611,8 +611,12 @@ class ShardedStateLoader(BaseModelLoader):
DEFAULT_PATTERN
=
"model-rank-{rank}-part-{part}.safetensors"
def
__init__
(
self
,
load_config
:
LoadConfig
):
def
__init__
(
self
,
load_config
:
LoadConfig
,
runai_model_streamer
:
bool
=
False
):
super
().
__init__
(
load_config
)
self
.
runai_model_streamer
=
runai_model_streamer
extra_config
=
({}
if
load_config
.
model_loader_extra_config
is
None
else
load_config
.
model_loader_extra_config
.
copy
())
self
.
pattern
=
extra_config
.
pop
(
"pattern"
,
self
.
DEFAULT_PATTERN
)
...
...
@@ -659,7 +663,7 @@ class ShardedStateLoader(BaseModelLoader):
def
_prepare_weights
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
]):
if
os
.
path
.
isdir
(
model_name_or_path
):
if
is_s3
(
model_name_or_path
)
or
os
.
path
.
isdir
(
model_name_or_path
):
return
model_name_or_path
else
:
allow_patterns
=
[
"*.safetensors"
]
...
...
@@ -678,12 +682,13 @@ class ShardedStateLoader(BaseModelLoader):
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
from
safetensors.torch
import
safe_open
from
vllm.distributed
import
get_tensor_model_parallel_rank
local_model_path
=
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
model_weights
=
model_config
.
model
if
hasattr
(
model_config
,
"model_weights"
):
model_weights
=
model_config
.
model_weights
local_model_path
=
model_weights
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
...
...
@@ -695,40 +700,56 @@ class ShardedStateLoader(BaseModelLoader):
local_model_path
,
self
.
pattern
.
format
(
rank
=
rank
,
part
=
"*"
),
)
filepaths
=
glob
.
glob
(
pattern
)
filepaths
=
[]
if
is_s3
(
local_model_path
):
file_pattern
=
f
"*
{
self
.
pattern
.
format
(
rank
=
rank
,
part
=
' * '
)
}
"
filepaths
=
s3_glob
(
path
=
local_model_path
,
allow_pattern
=
[
file_pattern
])
else
:
filepaths
=
glob
.
glob
(
pattern
)
if
not
filepaths
:
# TODO: support un-sharded checkpoints too
raise
ValueError
(
f
"Could not find checkpoint files '
{
pattern
}
', only "
f
"pre-sharded checkpoints are currently supported!"
)
state_dict
=
self
.
_filter_subtensors
(
model
.
state_dict
())
for
path
in
filepaths
:
with
safe_open
(
path
,
framework
=
"pt"
)
as
f
:
for
key
in
f
.
keys
():
# noqa: SIM118
tensor
=
f
.
get_tensor
(
key
)
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
param_data
=
state_dict
[
key
].
data
param_shape
=
state_dict
[
key
].
shape
for
dim
,
size
in
enumerate
(
tensor
.
shape
):
if
size
<
param_shape
[
dim
]:
param_data
=
param_data
.
narrow
(
dim
,
0
,
size
)
if
tensor
.
shape
!=
param_shape
:
logger
.
warning
(
"loading tensor of shape %s into "
"parameter '%s' of shape %s"
,
tensor
.
shape
,
key
,
param_shape
,
)
param_data
.
copy_
(
tensor
)
state_dict
.
pop
(
key
)
for
key
,
tensor
in
self
.
iterate_over_files
(
filepaths
):
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
param_data
=
state_dict
[
key
].
data
param_shape
=
state_dict
[
key
].
shape
for
dim
,
size
in
enumerate
(
tensor
.
shape
):
if
size
<
param_shape
[
dim
]:
param_data
=
param_data
.
narrow
(
dim
,
0
,
size
)
if
tensor
.
shape
!=
param_shape
:
logger
.
warning
(
"loading tensor of shape %s into "
"parameter '%s' of shape %s"
,
tensor
.
shape
,
key
,
param_shape
,
)
param_data
.
copy_
(
tensor
)
state_dict
.
pop
(
key
)
if
state_dict
:
raise
ValueError
(
f
"Missing keys
{
tuple
(
state_dict
)
}
in loaded state!"
)
return
model
.
eval
()
def iterate_over_files(
        self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield ``(key, tensor)`` pairs from the given checkpoint files.

    When ``self.runai_model_streamer`` is set, iteration is delegated to
    ``runai_safetensors_weights_iterator(paths, True)``. Otherwise each
    path is opened with ``safetensors.torch.safe_open`` and every key in
    the file is yielded together with its tensor, file by file.
    """
    if self.runai_model_streamer:
        yield from runai_safetensors_weights_iterator(paths, True)
        return

    # Imported lazily, as in the original implementation.
    from safetensors.torch import safe_open

    for shard_path in paths:
        with safe_open(shard_path, framework="pt") as shard:
            for name in shard.keys():  # noqa: SIM118
                yield name, shard.get_tensor(name)
@
staticmethod
def
save_model
(
model
:
torch
.
nn
.
Module
,
...
...
@@ -1515,4 +1536,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if
load_config
.
load_format
==
LoadFormat
.
RUNAI_STREAMER
:
return
RunaiModelStreamerLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
RUNAI_STREAMER_SHARDED
:
return
ShardedStateLoader
(
load_config
,
runai_model_streamer
=
True
)
return
DefaultModelLoader
(
load_config
)
vllm/model_executor/model_loader/utils.py
View file @
081057de
...
...
@@ -30,15 +30,6 @@ def set_default_torch_dtype(dtype: torch.dtype):
torch
.
set_default_dtype
(
old_dtype
)
def
is_transformers_impl_compatible
(
arch
:
str
,
module
:
Optional
[
"transformers.PreTrainedModel"
]
=
None
)
->
bool
:
mod
=
module
or
getattr
(
transformers
,
arch
,
None
)
if
mod
is
None
:
return
False
return
mod
.
is_backend_compatible
()
def
resolve_transformers_arch
(
model_config
:
ModelConfig
,
architectures
:
list
[
str
]):
for
i
,
arch
in
enumerate
(
architectures
):
...
...
@@ -55,20 +46,32 @@ def resolve_transformers_arch(model_config: ModelConfig,
# "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
# },
auto_modules
=
{
name
:
get_class_from_dynamic_module
(
module
,
model_config
.
model
)
name
:
get_class_from_dynamic_module
(
module
,
model_config
.
model
,
revision
=
model_config
.
revision
)
for
name
,
module
in
sorted
(
auto_map
.
items
(),
key
=
lambda
x
:
x
[
0
])
}
custom_model_module
=
auto_modules
.
get
(
"AutoModel"
)
model_module
=
getattr
(
transformers
,
arch
,
None
)
if
model_module
is
None
:
if
"AutoModel"
not
in
auto_map
:
raise
ValueError
(
f
"Cannot find model module. '
{
arch
}
' is not a registered "
"model in the Transformers library (only relevant if the "
"model is meant to be in Transformers) and 'AutoModel' is "
"not present in the model config's 'auto_map' (relevant "
"if the model is custom)."
)
model_module
=
auto_modules
[
"AutoModel"
]
# TODO(Isotr0py): Further clean up these raises.
# perhaps handled them in _ModelRegistry._raise_for_unsupported?
if
model_config
.
model_impl
==
ModelImpl
.
TRANSFORMERS
:
if
not
is_transformers_impl_compatible
(
arch
,
custom_model_module
):
if
not
model_module
.
is_backend_compatible
(
):
raise
ValueError
(
f
"The Transformers implementation of
{
arch
}
is not "
"compatible with vLLM."
)
architectures
[
i
]
=
"TransformersForCausalLM"
if
model_config
.
model_impl
==
ModelImpl
.
AUTO
:
if
not
is_transformers_impl_compatible
(
arch
,
custom_model_module
):
if
not
model_module
.
is_backend_compatible
(
):
raise
ValueError
(
f
"
{
arch
}
has no vLLM implementation and the Transformers "
"implementation is not compatible with vLLM. Try setting "
...
...
@@ -97,10 +100,10 @@ def get_model_architecture(
architectures
=
[
"QuantMixtralForCausalLM"
]
vllm_supported_archs
=
ModelRegistry
.
get_supported_archs
()
is_
vllm_supported
=
any
(
arch
in
vllm_supported_archs
for
arch
in
architectures
)
if
(
not
is_vllm_supported
or
model_config
.
model_impl
=
=
ModelImpl
.
TRANSFORMERS
):
vllm_
not_
supported
=
not
any
(
arch
in
vllm_supported_archs
for
arch
in
architectures
)
if
(
model_config
.
model_impl
==
ModelImpl
.
TRANSFORMERS
or
model_config
.
model_impl
!
=
ModelImpl
.
VLLM
and
vllm_not_supported
):
architectures
=
resolve_transformers_arch
(
model_config
,
architectures
)
model_cls
,
arch
=
ModelRegistry
.
resolve_model_cls
(
architectures
)
...
...
vllm/model_executor/models/arctic.py
View file @
081057de
...
...
@@ -24,7 +24,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.quantization.deepspeedfp
import
(
DeepSpeedFPConfig
,
DeepSpeedFPParameter
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -435,7 +434,6 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
...
...
@@ -462,14 +460,6 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
...
...
vllm/model_executor/models/aria.py
View file @
081057de
...
...
@@ -15,11 +15,10 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
(
SamplerOutput
,
SamplingMetadata
,
get_sampler
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
...
...
@@ -527,7 +526,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
vocab_size
,
logit_scale
)
self
.
sampler
=
get_sampler
()
def
_validate_image_sizes
(
self
,
images
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
...
...
@@ -653,14 +651,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
loader
=
AutoWeightsLoader
(
self
)
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
vllm/model_executor/models/aya_vision.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0 Adapted from
# https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
from
functools
import
cached_property
from
typing
import
(
Iterable
,
Literal
,
Mapping
,
Optional
,
Sequence
,
Set
,
Tuple
,
TypedDict
,
Union
,
cast
)
...
...
@@ -17,7 +16,6 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import (
from
vllm.config
import
VllmConfig
from
vllm.jsontree
import
json_map_leaves
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalDataDict
,
MultiModalKwargs
...
...
@@ -461,17 +459,3 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
@
cached_property
def
sampler
(
self
):
if
hasattr
(
self
.
language_model
,
"sampler"
):
return
self
.
language_model
.
sampler
return
get_sampler
()
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
vllm/model_executor/models/baichuan.py
View file @
081057de
...
...
@@ -39,7 +39,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -396,7 +395,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
...
...
@@ -423,14 +421,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
...
...
vllm/model_executor/models/bamba.py
View file @
081057de
...
...
@@ -24,7 +24,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2
,
extra_groups_for_head_shards
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -462,7 +461,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
...
...
@@ -538,14 +536,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
...
...
vllm/model_executor/models/bart.py
View file @
081057de
...
...
@@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -791,7 +790,6 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
def
forward
(
self
,
...
...
@@ -828,14 +826,6 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
stacked_params_mapping
=
{
"q_proj"
:
{
"param_name"
:
"qkv_proj"
,
...
...
vllm/model_executor/models/bert.py
View file @
081057de
...
...
@@ -11,8 +11,10 @@ from vllm.compilation.decorators import support_torch_compile
from
vllm.config
import
CacheConfig
,
PoolerConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
(
get_act_and_mul_fn
,
get_act_fn
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.pooler
import
(
CrossEncodingPooler
,
Pooler
,
...
...
@@ -108,6 +110,7 @@ class BertEncoder(nn.Module):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
bias
:
bool
=
True
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
...
...
@@ -118,6 +121,7 @@ class BertEncoder(nn.Module):
BertLayer
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
bias
=
bias
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.layer.
{
layer_idx
}
"
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
...
...
@@ -139,6 +143,7 @@ class BertLayer(nn.Module):
config
:
BertConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
True
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
...
...
@@ -149,19 +154,31 @@ class BertLayer(nn.Module):
layer_norm_eps
=
config
.
layer_norm_eps
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
bias
=
bias
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.attention"
)
self
.
intermediate
=
BertIntermediate
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.intermediate"
)
if
config
.
hidden_act
in
[
"silu"
,
"gelu_and_mul"
]:
self
.
intermediate
=
BertGatedIntermediate
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.intermediate"
)
else
:
self
.
intermediate
=
BertIntermediate
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.intermediate"
)
self
.
output
=
BertOutput
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
layer_norm_eps
=
config
.
layer_norm_eps
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.output"
)
...
...
@@ -181,6 +198,7 @@ class BertAttention(nn.Module):
layer_norm_eps
:
float
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
True
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
prefix
:
str
=
""
,
):
...
...
@@ -190,11 +208,13 @@ class BertAttention(nn.Module):
num_attention_heads
=
num_attention_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
bias
=
bias
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.output"
)
self
.
output
=
BertSelfOutput
(
hidden_size
=
hidden_size
,
layer_norm_eps
=
layer_norm_eps
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.output"
)
...
...
@@ -215,6 +235,7 @@ class BertSelfAttention(nn.Module):
num_attention_heads
:
int
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
True
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
prefix
:
str
=
""
,
):
...
...
@@ -240,7 +261,7 @@ class BertSelfAttention(nn.Module):
head_size
=
self
.
head_dim
,
total_num_heads
=
self
.
total_num_heads
,
total_num_kv_heads
=
self
.
total_num_kv_heads
,
bias
=
True
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
)
...
...
@@ -278,12 +299,13 @@ class BertSelfOutput(nn.Module):
def
__init__
(
self
,
hidden_size
:
int
,
layer_norm_eps
:
float
,
bias
:
bool
=
True
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
dense
=
RowParallelLinear
(
input_size
=
hidden_size
,
output_size
=
hidden_size
,
bias
=
True
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.dense"
)
self
.
LayerNorm
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
layer_norm_eps
)
...
...
@@ -301,12 +323,13 @@ class BertIntermediate(nn.Module):
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
bias
:
bool
=
True
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
dense
=
ColumnParallelLinear
(
input_size
=
hidden_size
,
output_size
=
intermediate_size
,
bias
=
True
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.dense"
)
self
.
intermediate_act_fn
=
get_act_fn
(
hidden_act
)
...
...
@@ -317,19 +340,46 @@ class BertIntermediate(nn.Module):
return
hidden_states
class BertGatedIntermediate(nn.Module):
    """Gated intermediate (feed-forward input) layer for BERT variants.

    Used by the NomicBert and Gte models. A single merged column-parallel
    linear layer produces two ``intermediate_size``-wide outputs (named
    gate and up), which are then combined by the "act and mul" activation
    selected by ``hidden_act``.
    """

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act: str,
                 bias: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the activation and the merged gate/up projection.

        Args:
            hidden_size: Input width of the projection.
            intermediate_size: Width of each of the two merged outputs.
            hidden_act: Activation name passed to ``get_act_and_mul_fn``.
            bias: Whether the merged projection carries a bias.
            quant_config: Optional quantization config forwarded to the
                linear layer.
            prefix: Module-name prefix; the projection is registered as
                ``"<prefix>.gate_up_proj"``.
        """
        super().__init__()
        self.act_fn = get_act_and_mul_fn(hidden_act)
        # Two equally sized output partitions packed into one linear layer.
        merged_output_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project ``hidden_states`` and apply the gated activation."""
        # The linear layer returns a pair; only the first element (the
        # projected tensor) is used here.
        projected, _ = self.gate_up_proj(hidden_states)
        return self.act_fn(projected)
class
BertOutput
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
layer_norm_eps
:
float
,
bias
:
bool
=
True
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
dense
=
RowParallelLinear
(
input_size
=
intermediate_size
,
output_size
=
hidden_size
,
bias
=
True
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.dense"
)
...
...
@@ -343,19 +393,32 @@ class BertOutput(nn.Module):
class
BertModel
(
nn
.
Module
,
SupportsQuant
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"query"
,
"key"
,
"value"
]}
packed_modules_mapping
=
{
"qkv_proj"
:
[
"query"
,
"key"
,
"value"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
embedding_class
:
type
=
BertEmbedding
,
bias
:
bool
=
True
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
add_pooling_layer
:
bool
=
False
):
super
().
__init__
()
"""
For BertModel, all linear layers have bias.
For NomicBertModel, all linear layers do not have bias.
"""
config
=
vllm_config
.
model_config
.
hf_config
self
.
embeddings
=
embedding_class
(
config
)
self
.
encoder
=
BertEncoder
(
vllm_config
=
vllm_config
,
bias
=
bias
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.encoder"
)
self
.
pooler
=
BertPooler
(
config
)
if
add_pooling_layer
else
None
...
...
@@ -387,6 +450,8 @@ class BertModel(nn.Module, SupportsQuant):
(
"qkv_proj"
,
"query"
,
"q"
),
(
"qkv_proj"
,
"key"
,
"k"
),
(
"qkv_proj"
,
"value"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
...
...
@@ -546,3 +611,115 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
,
token_type_ids
=
token_type_ids
)
class NomicBertEmbeddingModel(BertEmbeddingModel):
    """Embedding model wrapper that builds a NomicBert checkpoint on BertModel.

    Checkpoint weight names are rewritten to the BertModel layout through
    ``hf_to_vllm_mapper``, and the HF config is normalised in place so that
    BertModel can read the usual BERT attribute names.
    """

    # Substring rewrites applied to checkpoint weight names (insertion order
    # is kept exactly as listed).
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "emb_ln": "embeddings.LayerNorm",
            "layers": "layer",
            "attn.Wqkv": "attention.self.qkv_proj",
            "attn.out_proj": "attention.output.dense",
            "norm1": "attention.output.LayerNorm",
            "mlp.fc11": "intermediate.up_proj",
            "mlp.fc12": "intermediate.gate_proj",
            "mlp.fc2": "output.dense",
            "norm2": "output.LayerNorm",
        })

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        """Validate and normalise the NomicBert config, then build BertModel.

        The HF config object is mutated in place: NomicBert attribute names
        are copied onto the BERT names that BertModel reads, the position
        embedding type is set to ``"rotary"`` and the activation to
        ``"silu"``. The model is built with ``bias=False``.
        """
        hf_config = vllm_config.model_config.hf_config

        # Only the bias-free swiglu NomicBert configuration is handled.
        assert hf_config.__class__.__name__ == "NomicBertConfig"
        assert hf_config.activation_function == "swiglu"

        # Assume that no linear layer in NomicBertModel has a bias.
        assert not hf_config.mlp_fc1_bias
        assert not hf_config.mlp_fc2_bias
        assert not hf_config.qkv_proj_bias

        # Expose NomicBert settings under the attribute names BertModel uses.
        hf_config.layer_norm_eps = hf_config.layer_norm_epsilon
        hf_config.position_embedding_type = "rotary"
        hf_config.intermediate_size = hf_config.n_inner
        hf_config.hidden_act = "silu"
        hf_config.hidden_size = hf_config.n_embd
        hf_config.num_hidden_layers = hf_config.n_layer

        per_head_dim = hf_config.hidden_size // hf_config.num_attention_heads
        # rotary_dim falls back to the full head size when the config does
        # not define rotary_emb_dim.
        rotary_dim = getattr(hf_config, "rotary_emb_dim", per_head_dim)
        rope_scaling = {
            "rope_type": "dynamic",
            "factor": hf_config.rotary_scaling_factor,
        }
        rotary_kwargs = {
            "head_size": per_head_dim,
            "rotary_dim": rotary_dim,
            "max_position": hf_config.max_trained_positions,
            "base": hf_config.rotary_emb_base,
            "rope_scaling": rope_scaling,
        }

        return BertModel(vllm_config=vllm_config,
                         prefix=prefix,
                         bias=False,
                         rotary_kwargs=rotary_kwargs,
                         embedding_class=BertEmbedding)
class GteEmbeddingModel(BertEmbeddingModel):
    """Embedding model wrapper that builds a Gte checkpoint on BertModel.

    Checkpoint weight names are rewritten to the BertModel layout through
    ``hf_to_vllm_mapper``; the fused ``mlp.up_gate_proj`` weight is split
    into separate up/gate tensors before loading.
    """

    # Substring rewrites applied to checkpoint weight names.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "attention.qkv_proj": "attention.self.qkv_proj",
            "attention.o_proj": "attention.output.dense",
            "attn_ln": "attention.output.LayerNorm",
            "mlp.down_proj": "output.dense",
            "mlp_ln": "output.LayerNorm",
        })

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        """Validate and normalise the Gte config, then build BertModel.

        The HF config object is mutated in place: the position embedding
        type is changed from ``"rope"`` to ``"rotary"`` and the activation
        from ``"gelu"`` to ``"gelu_and_mul"``. After construction the bias
        of every layer's ``gate_up_proj`` is removed.
        """
        config = vllm_config.model_config.hf_config

        assert config.__class__.__name__ == "GteConfig"
        assert config.position_embedding_type == "rope"
        assert config.hidden_act == "gelu"

        config.position_embedding_type = "rotary"
        config.hidden_act = "gelu_and_mul"

        head_dim = config.hidden_size // config.num_attention_heads
        rotary_kwargs = {
            "head_size": head_dim,
            # Falls back to the full head size when rotary_emb_dim is unset.
            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
            "max_position": config.max_position_embeddings,
            "base": config.rope_theta,
        }

        model = BertModel(vllm_config=vllm_config,
                          prefix=prefix,
                          rotary_kwargs=rotary_kwargs,
                          embedding_class=BertEmbedding)

        # In GteModel, gate_up_proj is the only linear layer without a bias.
        # The model is built with bias enabled everywhere, so the bias is
        # dropped afterwards (same hack as vllm/model_executor/models/glm.py).
        # skip_bias_add is set on the linear layer itself; setting it on the
        # enclosing intermediate module would have no effect, since that
        # module never reads such an attribute.
        for layer in model.encoder.layer:
            gate_up_proj = layer.intermediate.gate_up_proj
            gate_up_proj.bias = None
            gate_up_proj.skip_bias_add = True
        return model

    def split_up_gate_proj(
        self, weights: Iterable[Tuple[str, torch.Tensor]]
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Yield weights with fused ``mlp.up_gate_proj`` tensors split in two.

        A tensor whose name contains ``"mlp.up_gate_proj"`` is chunked into
        two halves along dim 0; the first half is yielded under
        ``intermediate.up_proj`` and the second under
        ``intermediate.gate_proj``. Every other weight passes through
        unchanged.
        """
        fused_name = "mlp.up_gate_proj"
        for name, weight in weights:
            if fused_name not in name:
                yield name, weight
                continue
            up, gate = weight.chunk(2, dim=0)
            yield name.replace(fused_name, "intermediate.up_proj"), up
            yield name.replace(fused_name, "intermediate.gate_proj"), gate

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Rename checkpoint weights, split fused tensors and load them."""
        weights = self.hf_to_vllm_mapper.apply(weights)
        weights = self.split_up_gate_proj(weights)
        self.model.load_weights(weights)
vllm/model_executor/models/blip2.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
cached_property
from
typing
import
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
import
torch
...
...
@@ -12,7 +11,6 @@ from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
...
...
@@ -62,6 +60,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
],
cache_config
:
Optional
[
CacheConfig
],
is_cross_attention
:
bool
=
False
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
...
...
@@ -141,7 +140,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
class
Blip2QFormerSelfOutput
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Blip2QFormerConfig
)
->
None
:
def
__init__
(
self
,
config
:
Blip2QFormerConfig
,
prefix
:
str
=
""
)
->
None
:
super
().
__init__
()
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
...
...
@@ -169,6 +168,7 @@ class Blip2QFormerAttention(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
],
cache_config
:
Optional
[
CacheConfig
],
is_cross_attention
:
bool
=
False
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
...
...
@@ -177,9 +177,10 @@ class Blip2QFormerAttention(nn.Module):
quant_config
=
quant_config
,
cache_config
=
cache_config
,
is_cross_attention
=
is_cross_attention
,
prefix
=
f
"
{
prefix
}
.attention"
,
)
self
.
output
=
Blip2QFormerSelfOutput
(
config
)
self
.
output
=
Blip2QFormerSelfOutput
(
config
,
prefix
=
f
"
{
prefix
}
.output"
)
def
forward
(
self
,
...
...
@@ -197,7 +198,7 @@ class Blip2QFormerAttention(nn.Module):
class
Blip2QFormerIntermediate
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Blip2QFormerConfig
)
->
None
:
def
__init__
(
self
,
config
:
Blip2QFormerConfig
,
prefix
:
str
=
""
)
->
None
:
super
().
__init__
()
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
intermediate_size
)
...
...
@@ -211,7 +212,7 @@ class Blip2QFormerIntermediate(nn.Module):
class
Blip2QFormerOutput
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Blip2QFormerConfig
)
->
None
:
def
__init__
(
self
,
config
:
Blip2QFormerConfig
,
prefix
:
str
=
""
)
->
None
:
super
().
__init__
()
self
.
dense
=
nn
.
Linear
(
config
.
intermediate_size
,
config
.
hidden_size
)
...
...
@@ -239,6 +240,7 @@ class Blip2QFormerLayer(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
],
cache_config
:
Optional
[
CacheConfig
],
layer_idx
:
int
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
...
...
@@ -246,7 +248,8 @@ class Blip2QFormerLayer(nn.Module):
self
.
seq_len_dim
=
1
self
.
attention
=
Blip2QFormerAttention
(
config
,
quant_config
=
quant_config
,
cache_config
=
cache_config
)
cache_config
=
cache_config
,
prefix
=
f
"
{
prefix
}
.attention"
)
self
.
layer_idx
=
layer_idx
...
...
@@ -255,13 +258,16 @@ class Blip2QFormerLayer(nn.Module):
config
,
quant_config
=
quant_config
,
cache_config
=
cache_config
,
is_cross_attention
=
True
)
is_cross_attention
=
True
,
prefix
=
f
"
{
prefix
}
.crossattention"
)
self
.
has_cross_attention
=
True
else
:
self
.
has_cross_attention
=
False
self
.
intermediate_query
=
Blip2QFormerIntermediate
(
config
)
self
.
output_query
=
Blip2QFormerOutput
(
config
)
self
.
intermediate_query
=
Blip2QFormerIntermediate
(
config
,
prefix
=
f
"
{
prefix
}
.intermediate_query"
)
self
.
output_query
=
Blip2QFormerOutput
(
config
,
prefix
=
f
"
{
prefix
}
.output_query"
)
def
forward
(
self
,
...
...
@@ -327,6 +333,7 @@ class Blip2QFormerEncoder(nn.Module):
*
,
quant_config
:
Optional
[
QuantizationConfig
],
cache_config
:
Optional
[
CacheConfig
],
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
...
...
@@ -336,7 +343,8 @@ class Blip2QFormerEncoder(nn.Module):
Blip2QFormerLayer
(
config
,
quant_config
=
quant_config
,
cache_config
=
cache_config
,
layer_idx
=
layer_idx
)
layer_idx
=
layer_idx
,
prefix
=
f
"
{
prefix
}
.layer.
{
layer_idx
}
"
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
...
...
@@ -367,6 +375,7 @@ class Blip2QFormerModel(nn.Module):
*
,
quant_config
:
Optional
[
QuantizationConfig
],
cache_config
:
Optional
[
CacheConfig
],
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
...
...
@@ -378,7 +387,8 @@ class Blip2QFormerModel(nn.Module):
self
.
encoder
=
Blip2QFormerEncoder
(
config
,
quant_config
=
quant_config
,
cache_config
=
cache_config
)
cache_config
=
cache_config
,
prefix
=
f
"
{
prefix
}
.encoder"
)
def
forward
(
self
,
...
...
@@ -513,7 +523,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self
.
qformer
=
Blip2QFormerModel
(
config
.
qformer_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qformer"
)
self
.
language_projection
=
nn
.
Linear
(
config
.
qformer_config
.
hidden_size
,
...
...
@@ -530,13 +541,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
@
cached_property
def
sampler
(
self
):
if
hasattr
(
self
.
language_model
,
"sampler"
):
return
self
.
language_model
.
sampler
return
get_sampler
()
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
h
=
w
=
self
.
config
.
vision_config
.
image_size
expected_dims
=
(
3
,
h
,
w
)
...
...
@@ -649,7 +653,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
)
->
Union
[
SamplerOutput
,
IntermediateTensors
]
:
)
->
IntermediateTensors
:
"""Run forward pass for BLIP-2.
One key thing to understand is the `input_ids` already accounts for the
...
...
@@ -707,13 +711,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
...
...
vllm/model_executor/models/bloom.py
View file @
081057de
...
...
@@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -297,7 +296,6 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
self
.
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
...
...
@@ -324,14 +322,6 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
...
...
Prev
1
…
21
22
23
24
25
26
27
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment