Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b9e12416
Commit
b9e12416
authored
May 31, 2024
by
zhuwenwen
Browse files
merge v0.4.3
parents
e5d707db
e9d3aa04
Changes
345
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
313 additions
and
79 deletions
+313
-79
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+146
-0
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+10
-2
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+71
-16
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+75
-53
vllm/model_executor/model_loader/__init__.py
vllm/model_executor/model_loader/__init__.py
+11
-8
No files found.
Too many changes to show.
To preserve performance only
345 of 345+
files are displayed.
Plain diff
Email patch
vllm/model_executor/layers/quantization/utils/quant_utils.py
0 → 100644
View file @
b9e12416
"""This file is used for /tests and /benchmarks"""
import
numpy
import
torch
SUPPORTED_NUM_BITS
=
[
4
,
8
]
SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
def
get_pack_factor
(
num_bits
):
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
return
32
//
num_bits
def
permute_rows
(
q_w
:
torch
.
Tensor
,
w_ref
:
torch
.
Tensor
,
group_size
:
int
):
assert
q_w
.
shape
==
w_ref
.
shape
orig_device
=
q_w
.
device
k_size
,
_
=
q_w
.
shape
g_idx
=
torch
.
zeros
((
k_size
,
),
dtype
=
torch
.
int32
)
for
i
in
range
(
k_size
):
g_idx
[
i
]
=
i
//
group_size
# Simulate act_order by doing a random permutation on K
rand_perm
=
torch
.
randperm
(
k_size
)
g_idx
=
g_idx
[
rand_perm
].
contiguous
()
q_w
=
q_w
[
rand_perm
,
:].
contiguous
()
w_ref
=
w_ref
[
rand_perm
,
:].
contiguous
()
return
(
w_ref
.
to
(
device
=
orig_device
),
q_w
.
to
(
device
=
orig_device
),
g_idx
.
to
(
device
=
orig_device
),
rand_perm
.
to
(
device
=
orig_device
),
)
def
quantize_weights
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
act_order
:
bool
):
orig_device
=
w
.
device
size_k
,
size_n
=
w
.
shape
assert
w
.
is_floating_point
(),
"w must be float"
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
assert
group_size
in
SUPPORTED_GROUP_SIZES
+
[
size_k
],
f
"Unsupported groupsize =
{
group_size
}
"
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
max_q_val
=
2
**
num_bits
-
1
half_q_val
=
(
max_q_val
+
1
)
//
2
# Reshape to [groupsize, -1]
if
group_size
<
size_k
:
w
=
w
.
reshape
((
-
1
,
group_size
,
size_n
))
w
=
w
.
permute
(
1
,
0
,
2
)
w
=
w
.
reshape
((
group_size
,
-
1
))
# Compute scale for each group
s
=
torch
.
max
(
torch
.
abs
(
w
),
0
,
keepdim
=
True
)[
0
]
s
*=
2
/
max_q_val
# 2 => symmetric
# Quantize
q_w
=
torch
.
round
(
w
/
s
).
int
()
q_w
+=
half_q_val
q_w
=
torch
.
clamp
(
q_w
,
0
,
max_q_val
)
# Compute ref (dequantized)
w_ref
=
(
q_w
-
half_q_val
).
half
()
*
s
# Restore original shapes
if
group_size
<
size_k
:
def
reshape_w
(
w
):
w
=
w
.
reshape
((
group_size
,
-
1
,
size_n
))
w
=
w
.
permute
(
1
,
0
,
2
)
w
=
w
.
reshape
((
size_k
,
size_n
)).
contiguous
()
return
w
q_w
=
reshape_w
(
q_w
)
w_ref
=
reshape_w
(
w_ref
)
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
# Apply act_order
g_idx
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w
.
device
)
rand_perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w
.
device
)
if
act_order
:
assert
(
group_size
<
size_k
),
"For act_order, groupsize = {} must be less than size_k = {}"
.
format
(
group_size
,
size_k
)
w_ref
,
q_w
,
g_idx
,
rand_perm
=
permute_rows
(
q_w
,
w_ref
,
group_size
)
return
(
w_ref
.
to
(
device
=
orig_device
),
q_w
.
to
(
device
=
orig_device
),
s
.
to
(
device
=
orig_device
),
g_idx
.
to
(
device
=
orig_device
),
rand_perm
.
to
(
device
=
orig_device
),
)
def
sort_weights
(
q_w
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
):
orig_device
=
q_w
.
device
sort_indices
=
torch
.
argsort
(
g_idx
).
to
(
dtype
=
torch
.
int32
)
# Sort based on g_idx
g_idx
=
g_idx
[
sort_indices
].
contiguous
()
q_w
=
q_w
[
sort_indices
,
:].
contiguous
()
return
(
q_w
.
to
(
device
=
orig_device
),
g_idx
.
to
(
device
=
orig_device
),
sort_indices
.
to
(
device
=
orig_device
),
)
def
gptq_pack
(
q_w
:
torch
.
Tensor
,
num_bits
:
int
,
size_k
:
int
,
size_n
:
int
,
):
assert
q_w
.
shape
==
(
size_k
,
size_n
)
pack_factor
=
get_pack_factor
(
num_bits
)
assert
size_k
%
pack_factor
==
0
orig_device
=
q_w
.
device
q_w
=
q_w
.
cpu
().
numpy
().
astype
(
numpy
.
uint32
)
q_res
=
numpy
.
zeros
((
size_k
//
pack_factor
,
size_n
),
dtype
=
numpy
.
uint32
)
for
i
in
range
(
pack_factor
):
q_res
|=
q_w
[
i
::
pack_factor
,
:]
<<
num_bits
*
i
q_res
=
torch
.
from_numpy
(
q_res
.
astype
(
numpy
.
int32
)).
to
(
orig_device
)
return
q_res
vllm/model_executor/layers/rejection_sampler.py
View file @
b9e12416
...
@@ -12,15 +12,21 @@ class RejectionSampler(nn.Module):
...
@@ -12,15 +12,21 @@ class RejectionSampler(nn.Module):
https://arxiv.org/pdf/2302.01318.pdf.
https://arxiv.org/pdf/2302.01318.pdf.
"""
"""
def
__init__
(
self
,
strict_mode
:
bool
=
False
):
def
__init__
(
self
,
disable_bonus_tokens
:
bool
=
True
,
strict_mode
:
bool
=
False
):
"""Create a rejection sampler.
"""Create a rejection sampler.
Args:
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
during sampling. This catches correctness issues but adds
nontrivial latency.
nontrivial latency.
"""
"""
super
().
__init__
()
super
().
__init__
()
self
.
_disable_bonus_tokens
=
disable_bonus_tokens
self
.
_strict_mode
=
strict_mode
self
.
_strict_mode
=
strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
# NOTE: A "bonus token" is accepted iff all proposal tokens are
...
@@ -116,6 +122,7 @@ class RejectionSampler(nn.Module):
...
@@ -116,6 +122,7 @@ class RejectionSampler(nn.Module):
draft_token_ids
,
draft_token_ids
,
bonus_token_ids
,
bonus_token_ids
,
)
)
return
output_token_ids
return
output_token_ids
def
_batch_modified_rejection_sampling
(
def
_batch_modified_rejection_sampling
(
...
@@ -312,7 +319,8 @@ class RejectionSampler(nn.Module):
...
@@ -312,7 +319,8 @@ class RejectionSampler(nn.Module):
# proposal methods that require KV cache. We can fix it by "prefilling"
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
# https://github.com/vllm-project/vllm/issues/4212
output_with_bonus_tokens
[:,
-
1
]
=
-
1
if
self
.
_disable_bonus_tokens
:
output_with_bonus_tokens
[:,
-
1
]
=
-
1
# Fill the recovered token ids.
# Fill the recovered token ids.
output
.
mul_
(
~
after_false_mask
).
add_
(
output
.
mul_
(
~
after_false_mask
).
add_
(
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
b9e12416
...
@@ -53,6 +53,7 @@ class RotaryEmbedding(nn.Module):
...
@@ -53,6 +53,7 @@ class RotaryEmbedding(nn.Module):
max_position_embeddings
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
base
:
int
,
is_neox_style
:
bool
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
head_size
=
head_size
self
.
head_size
=
head_size
...
@@ -60,9 +61,10 @@ class RotaryEmbedding(nn.Module):
...
@@ -60,9 +61,10 @@ class RotaryEmbedding(nn.Module):
self
.
max_position_embeddings
=
max_position_embeddings
self
.
max_position_embeddings
=
max_position_embeddings
self
.
base
=
base
self
.
base
=
base
self
.
is_neox_style
=
is_neox_style
self
.
is_neox_style
=
is_neox_style
self
.
dtype
=
dtype
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
cache
.
to
(
torch
.
get_default_
dtype
()
)
cache
=
cache
.
to
(
dtype
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
...
@@ -109,7 +111,7 @@ class RotaryEmbedding(nn.Module):
...
@@ -109,7 +111,7 @@ class RotaryEmbedding(nn.Module):
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_pass
=
key
[...,
self
.
rotary_dim
:]
self
.
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
self
.
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
positions
.
device
)
positions
.
device
,
dtype
=
query
.
dtype
)
cos_sin
=
self
.
cos_sin_cache
[
torch
.
add
(
positions
,
offsets
)
cos_sin
=
self
.
cos_sin_cache
[
torch
.
add
(
positions
,
offsets
)
if
offsets
is
not
None
else
positions
]
if
offsets
is
not
None
else
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
...
@@ -143,7 +145,8 @@ class RotaryEmbedding(nn.Module):
...
@@ -143,7 +145,8 @@ class RotaryEmbedding(nn.Module):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
positions
.
device
)
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
positions
.
device
,
dtype
=
query
.
dtype
)
# ops.rotary_embedding()/batched_rotary_embedding()
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
# are in-place operations that update the query and key tensors.
if
offsets
is
not
None
:
if
offsets
is
not
None
:
...
@@ -166,6 +169,29 @@ class RotaryEmbedding(nn.Module):
...
@@ -166,6 +169,29 @@ class RotaryEmbedding(nn.Module):
class
LinearScalingRotaryEmbedding
(
RotaryEmbedding
):
class
LinearScalingRotaryEmbedding
(
RotaryEmbedding
):
"""RotaryEmbedding extended with linear scaling.
"""RotaryEmbedding extended with linear scaling.
It supports multiple scaling factors. Since multiple LoRA adapters may have
different scaling factors, we need multiple cos/sin caches. In this way,
instead of running rotary embedding kernel per lora, we can run multiple
lora in a batched way.
In addition to that, we also keep the cos/sin cache for the scaling factor
of 1 (default) at all times.
Exemplary for two scaling factors x=1, y and z with embeddings
[[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and
[[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
[[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],
we construct the cos/sin cache as follows:
[[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
...
[xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]
We then use offsets to index into the cos/sin cache for
the respective scaling factors.
The offset to cache can be accessed via `scaling_factor_to_offset` API.
Credits to the Reddit user /u/kaiokendev
Credits to the Reddit user /u/kaiokendev
"""
"""
...
@@ -177,16 +203,22 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
...
@@ -177,16 +203,22 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
base
:
int
,
base
:
int
,
is_neox_style
:
bool
,
is_neox_style
:
bool
,
scaling_factors
:
Union
[
List
[
float
],
float
],
scaling_factors
:
Union
[
List
[
float
],
float
],
dtype
:
torch
.
dtype
,
)
->
None
:
)
->
None
:
if
isinstance
(
scaling_factors
,
float
):
if
isinstance
(
scaling_factors
,
float
):
scaling_factors
=
[
scaling_factors
]
scaling_factors
=
[
scaling_factors
]
self
.
scaling_factors
=
scaling_factors
self
.
scaling_factors
:
List
[
float
]
=
scaling_factors
# noqa
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
)
is_neox_style
,
dtype
)
# Lazy initialized.
self
.
_scaling_factor_to_offset
:
Dict
[
float
,
int
]
def
_compute_cos_sin_cache
(
self
)
->
torch
.
Tensor
:
def
_compute_cos_sin_cache
(
self
)
->
torch
.
Tensor
:
inv_freq
=
self
.
_compute_inv_freq
(
self
.
base
)
inv_freq
=
self
.
_compute_inv_freq
(
self
.
base
)
cache_list
=
[]
cache_list
:
List
[
torch
.
Tensor
]
=
[]
# offsets to the next cache in a tensor.
# Each offset corresponds to the same index in scaling_factors.
offsets
:
List
[
int
]
=
[]
for
scaling_factor
in
self
.
scaling_factors
:
for
scaling_factor
in
self
.
scaling_factors
:
# NOTE(woosuk): self.max_position_embeddings is the original
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# maximum length before applying the rope scaling.
...
@@ -200,9 +232,25 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
...
@@ -200,9 +232,25 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
cos
=
freqs
.
cos
()
cos
=
freqs
.
cos
()
sin
=
freqs
.
sin
()
sin
=
freqs
.
sin
()
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
if
not
cache_list
:
offset
=
0
else
:
last_offset
=
offsets
[
-
1
]
next_max_len
=
cache_list
[
-
1
].
shape
[
0
]
offset
=
last_offset
+
next_max_len
offsets
.
append
(
offset
)
cache_list
.
append
(
cache
)
cache_list
.
append
(
cache
)
self
.
_scaling_factor_to_offset
=
{
float
(
scaling_factor
):
offsets
[
i
]
for
i
,
scaling_factor
in
enumerate
(
self
.
scaling_factors
)
}
assert
len
(
self
.
scaling_factors
)
==
len
(
offsets
)
return
torch
.
cat
(
cache_list
,
dim
=
0
)
return
torch
.
cat
(
cache_list
,
dim
=
0
)
@
property
def
scaling_factor_to_offset
(
self
)
->
Dict
[
float
,
int
]:
return
self
.
_scaling_factor_to_offset
class
DynamicNTKScalingRotaryEmbedding
(
RotaryEmbedding
):
class
DynamicNTKScalingRotaryEmbedding
(
RotaryEmbedding
):
"""RotaryEmbedding extended with Dynamic NTK scaling.
"""RotaryEmbedding extended with Dynamic NTK scaling.
...
@@ -218,10 +266,11 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
...
@@ -218,10 +266,11 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
base
:
int
,
base
:
int
,
is_neox_style
:
bool
,
is_neox_style
:
bool
,
scaling_factor
:
float
,
scaling_factor
:
float
,
dtype
:
torch
.
dtype
,
)
->
None
:
)
->
None
:
self
.
scaling_factor
=
scaling_factor
self
.
scaling_factor
=
scaling_factor
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
)
is_neox_style
,
dtype
)
def
_compute_cos_sin_cache
(
self
)
->
torch
.
Tensor
:
def
_compute_cos_sin_cache
(
self
)
->
torch
.
Tensor
:
# NOTE(woosuk): self.max_position_embeddings is the original
# NOTE(woosuk): self.max_position_embeddings is the original
...
@@ -298,6 +347,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
...
@@ -298,6 +347,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
base
:
int
,
base
:
int
,
is_neox_style
:
bool
,
is_neox_style
:
bool
,
scaling_factor
:
float
,
scaling_factor
:
float
,
dtype
:
torch
.
dtype
,
*
,
*
,
extrapolation_factor
:
float
=
1
,
extrapolation_factor
:
float
=
1
,
attn_factor
:
float
=
1
,
attn_factor
:
float
=
1
,
...
@@ -313,7 +363,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
...
@@ -313,7 +363,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
self
.
mscale
=
float
(
self
.
mscale
=
float
(
_yarn_get_mscale
(
self
.
scaling_factor
)
*
attn_factor
)
_yarn_get_mscale
(
self
.
scaling_factor
)
*
attn_factor
)
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
)
is_neox_style
,
dtype
)
def
_compute_inv_freq
(
self
,
scaling_factor
:
float
)
->
torch
.
Tensor
:
def
_compute_inv_freq
(
self
,
scaling_factor
:
float
)
->
torch
.
Tensor
:
pos_freqs
=
self
.
base
**
(
pos_freqs
=
self
.
base
**
(
...
@@ -358,6 +408,7 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
...
@@ -358,6 +408,7 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
original_max_position_embeddings
:
int
,
original_max_position_embeddings
:
int
,
base
:
int
,
base
:
int
,
is_neox_style
:
bool
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
short_factor
:
List
[
float
],
short_factor
:
List
[
float
],
long_factor
:
List
[
float
],
long_factor
:
List
[
float
],
short_mscale
:
float
=
1.1
,
short_mscale
:
float
=
1.1
,
...
@@ -384,14 +435,14 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
...
@@ -384,14 +435,14 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
short_cache
=
self
.
_compute_cos_sin_cache
(
short_cache
=
self
.
_compute_cos_sin_cache
(
original_max_position_embeddings
,
short_factor
,
short_mscale
)
original_max_position_embeddings
,
short_factor
,
short_mscale
)
short_cache
=
short_cache
.
to
(
torch
.
get_default_
dtype
()
)
short_cache
=
short_cache
.
to
(
dtype
)
self
.
register_buffer
(
"short_cos_sin_cache"
,
self
.
register_buffer
(
"short_cos_sin_cache"
,
short_cache
,
short_cache
,
persistent
=
False
)
persistent
=
False
)
long_cache
=
self
.
_compute_cos_sin_cache
(
max_position_embeddings
,
long_cache
=
self
.
_compute_cos_sin_cache
(
max_position_embeddings
,
long_factor
,
long_mscale
)
long_factor
,
long_mscale
)
long_cache
=
long_cache
.
to
(
torch
.
get_default_
dtype
()
)
long_cache
=
long_cache
.
to
(
dtype
)
self
.
register_buffer
(
"long_cos_sin_cache"
,
self
.
register_buffer
(
"long_cos_sin_cache"
,
long_cache
,
long_cache
,
persistent
=
False
)
persistent
=
False
)
...
@@ -462,7 +513,10 @@ def get_rope(
...
@@ -462,7 +513,10 @@ def get_rope(
base
:
int
,
base
:
int
,
is_neox_style
:
bool
=
True
,
is_neox_style
:
bool
=
True
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
)
->
RotaryEmbedding
:
)
->
RotaryEmbedding
:
if
dtype
is
None
:
dtype
=
torch
.
get_default_dtype
()
if
rope_scaling
is
not
None
:
if
rope_scaling
is
not
None
:
# Transforms every value that is a list into a tuple for caching calls
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple
=
{
rope_scaling_tuple
=
{
...
@@ -473,12 +527,12 @@ def get_rope(
...
@@ -473,12 +527,12 @@ def get_rope(
else
:
else
:
rope_scaling_args
=
None
rope_scaling_args
=
None
key
=
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
key
=
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
rope_scaling_args
)
rope_scaling_args
,
dtype
)
if
key
in
_ROPE_DICT
:
if
key
in
_ROPE_DICT
:
return
_ROPE_DICT
[
key
]
return
_ROPE_DICT
[
key
]
if
rope_scaling
is
None
:
if
rope_scaling
is
None
:
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
)
is_neox_style
,
dtype
)
else
:
else
:
scaling_type
=
rope_scaling
[
"type"
]
scaling_type
=
rope_scaling
[
"type"
]
if
scaling_type
!=
"su"
:
if
scaling_type
!=
"su"
:
...
@@ -487,11 +541,11 @@ def get_rope(
...
@@ -487,11 +541,11 @@ def get_rope(
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
max_position
,
base
,
is_neox_style
,
is_neox_style
,
scaling_factor
)
scaling_factor
,
dtype
)
elif
scaling_type
==
"dynamic"
:
elif
scaling_type
==
"dynamic"
:
rotary_emb
=
DynamicNTKScalingRotaryEmbedding
(
rotary_emb
=
DynamicNTKScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_factor
)
scaling_factor
,
dtype
)
elif
scaling_type
==
"yarn"
:
elif
scaling_type
==
"yarn"
:
original_max_position
=
rope_scaling
[
original_max_position
=
rope_scaling
[
"original_max_position_embeddings"
]
"original_max_position_embeddings"
]
...
@@ -504,7 +558,7 @@ def get_rope(
...
@@ -504,7 +558,7 @@ def get_rope(
rotary_emb
=
YaRNScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
rotary_emb
=
YaRNScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
original_max_position
,
original_max_position
,
base
,
is_neox_style
,
base
,
is_neox_style
,
scaling_factor
,
scaling_factor
,
dtype
,
**
extra_kwargs
)
**
extra_kwargs
)
elif
scaling_type
==
"su"
:
elif
scaling_type
==
"su"
:
short_factor
=
rope_scaling
[
"short_factor"
]
short_factor
=
rope_scaling
[
"short_factor"
]
...
@@ -518,7 +572,8 @@ def get_rope(
...
@@ -518,7 +572,8 @@ def get_rope(
}
}
rotary_emb
=
Phi3SuScaledRotaryEmbedding
(
rotary_emb
=
Phi3SuScaledRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
original_max_position
,
head_size
,
rotary_dim
,
max_position
,
original_max_position
,
base
,
is_neox_style
,
short_factor
,
long_factor
,
**
extra_kwargs
)
base
,
is_neox_style
,
dtype
,
short_factor
,
long_factor
,
**
extra_kwargs
)
else
:
else
:
raise
ValueError
(
f
"Unknown RoPE scaling type
{
scaling_type
}
"
)
raise
ValueError
(
f
"Unknown RoPE scaling type
{
scaling_type
}
"
)
_ROPE_DICT
[
key
]
=
rotary_emb
_ROPE_DICT
[
key
]
=
rotary_emb
...
...
vllm/model_executor/layers/sampler.py
View file @
b9e12416
...
@@ -10,8 +10,9 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
...
@@ -10,8 +10,9 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors
,
SamplingTensors
,
SequenceGroupToSample
)
SequenceGroupToSample
)
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
(
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
SamplerOutput
,
SequenceGroupOutput
,
SequenceOutput
)
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
SequenceOutput
)
# (num_token_ids, num_parent_ids) per sequence group.
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType
=
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
SampleResultType
=
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
...
@@ -680,7 +681,9 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
...
@@ -680,7 +681,9 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
"""
"""
vals
=
x
[
torch
.
arange
(
0
,
len
(
x
),
device
=
x
.
device
,
dtype
=
indices
.
dtype
),
vals
=
x
[
torch
.
arange
(
0
,
len
(
x
),
device
=
x
.
device
,
dtype
=
indices
.
dtype
),
indices
]
indices
]
return
(
x
>
vals
[:,
None
]).
long
().
sum
(
1
).
add_
(
1
)
result
=
(
x
>
vals
[:,
None
])
del
vals
return
result
.
sum
(
1
).
add_
(
1
)
def
_get_logprobs
(
def
_get_logprobs
(
...
@@ -782,13 +785,14 @@ def _get_logprobs(
...
@@ -782,13 +785,14 @@ def _get_logprobs(
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
largest_num_logprobs
,
largest_num_logprobs
,
dim
=-
1
)
dim
=-
1
)
top_logprobs
=
top_logprobs
.
cpu
()
top_token_ids
=
top_token_ids
.
cpu
()
else
:
else
:
top_logprobs
,
top_token_ids
=
None
,
None
top_logprobs
,
top_token_ids
=
None
,
None
selected_logprobs
=
selected_logprobs
.
cpu
()
selected_logprobs
=
selected_logprobs
.
to
(
'cpu'
)
ranks
=
ranks
.
cpu
()
ranks
=
ranks
.
to
(
'cpu'
)
if
top_logprobs
is
not
None
and
top_token_ids
is
not
None
:
top_logprobs
=
top_logprobs
.
to
(
'cpu'
)
top_token_ids
=
top_token_ids
.
to
(
'cpu'
)
# Find prompt/sample logprobs.
# Find prompt/sample logprobs.
prompt_logprobs_per_seq_group
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
prompt_logprobs_per_seq_group
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
...
@@ -828,37 +832,48 @@ def _get_prompt_logprob_if_needed(
...
@@ -828,37 +832,48 @@ def _get_prompt_logprob_if_needed(
# Find prompt logprobs
# Find prompt logprobs
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
if
(
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
)
:
if
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
:
prompt_logprobs
=
[]
prompt_logprobs
=
[]
num_logprobs
=
sampling_params
.
prompt_logprobs
num_logprobs
=
sampling_params
.
prompt_logprobs
next_prompt_tokens
=
_get_next_prompt_tokens
(
seq_group
)
next_prompt_tokens
=
_get_next_prompt_tokens
(
seq_group
)
for
token_id
in
next_prompt_tokens
:
# Pre-select indexes and create a list. It is faster than calling .item
# repetitively.
selected_logprob_items
=
selected_logprobs
[
selected_logprobs_idx
:
selected_logprobs_idx
+
len
(
next_prompt_tokens
)].
tolist
()
rank_items
=
ranks
[
selected_logprobs_idx
:
selected_logprobs_idx
+
len
(
next_prompt_tokens
)].
tolist
()
for
idx
,
token_id
in
enumerate
(
next_prompt_tokens
):
# Calculate the prompt logprob of the real prompt tokens.
# Calculate the prompt logprob of the real prompt tokens.
# Use tuple here for performance (to use to_list()).
# {token_id: (logprob, rank_from_vocab)}
# {token_id: (logprob, rank_from_vocab)}
prompt_logprobs_dict
:
Dict
[
int
,
Tuple
[
float
,
int
]]
=
{
prompt_logprobs_dict
:
Dict
[
int
,
Tuple
[
float
,
int
]]
=
{
token_id
:
(
selected_logprobs
[
selected_logprobs_idx
].
item
(),
token_id
:
(
selected_logprob_items
[
idx
],
rank_items
[
idx
])
ranks
[
selected_logprobs_idx
].
item
())
}
}
# Add top K prompt logprobs along with its rank.
# Add top K prompt logprobs along with its rank.
if
num_logprobs
>
0
:
if
num_logprobs
>
0
:
prompt_logprobs_dict
.
update
(
top_ids
=
top_token_ids
[
zip
(
top_logprob_idx
,
:
num_logprobs
].
tolist
()
top_token_ids
[
top_logprob_idx
,
:
num_logprobs
].
tolist
(),
top_probs
=
top_logprobs
[
zip
(
top_logprob_idx
,
:
num_logprobs
].
tolist
()
top_logprobs
[
# Top K is already sorted by rank, so we can use 1 ~
top_logprob_idx
,
:
num_logprobs
].
tolist
(),
# num_logprobs + 1 for rank.
# This is ranks. Since top_logprob is sorted,
top_ranks
=
range
(
1
,
num_logprobs
+
1
)
# we can just use a range here.
prompt_logprobs_dict
.
update
({
range
(
1
,
num_logprobs
+
1
))))
top_id
:
(
top_prob
,
rank
)
for
top_id
,
top_prob
,
rank
in
zip
(
top_ids
,
top_probs
,
top_ranks
)
})
prompt_logprobs
.
append
({
prompt_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_and_rank
)
token_id
:
Logprob
(
*
logprob_and_rank
)
for
token_id
,
logprob_and_rank
in
prompt_logprobs_dict
.
items
()
for
token_id
,
logprob_and_rank
in
prompt_logprobs_dict
.
items
()
})
})
# + 1 to go to the next prompt token.
# + 1 to go to the next prompt token.
top_logprob_idx
+=
1
top_logprob_idx
+=
1
selected_logprobs_idx
+=
1
# + len(next_prompt_tokens) to go to the next prompt.
selected_logprobs_idx
+=
len
(
next_prompt_tokens
)
return
prompt_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
return
prompt_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
...
@@ -874,47 +889,54 @@ def _get_sampled_logprob_if_needed(
...
@@ -874,47 +889,54 @@ def _get_sampled_logprob_if_needed(
):
):
"""Compute the sample logprob if needed."""
"""Compute the sample logprob if needed."""
seq_ids
=
seq_group
.
seq_ids
seq_ids
=
seq_group
.
seq_ids
num_logprobs
=
seq_group
.
sampling_params
.
logprobs
num_logprobs
=
seq_group
.
sampling_params
.
logprobs
or
0
if
num_logprobs
is
None
:
num_logprobs
=
0
sampled_logprobs
:
SampleLogprobs
=
[]
sampled_logprobs
:
SampleLogprobs
=
[]
next_token_ids
,
parent_seq_ids
=
sample_result
next_token_ids
,
parent_seq_ids
=
sample_result
if
seq_group
.
do_sample
:
if
seq_group
.
do_sample
:
assert
len
(
next_token_ids
)
>
0
assert
len
(
next_token_ids
)
>
0
for
(
next_token_id
,
parent_id
)
in
zip
(
next_token_ids
,
parent_seq_ids
):
# Pre-select items from tensor. tolist() is faster than repetitive
# Calculate the sample logprob of the real sampled tokens.
# `.item()` calls.
# Use tuple here for performance (to use to_list()).
selected_logprob_items
=
selected_logprobs
[
# token_id: (logprob, rank_from_vocab)
selected_logprobs_idx
:
selected_logprobs_idx
+
sampled_logprobs_dict
:
Dict
[
int
,
Tuple
[
float
,
int
]]
=
{
len
(
next_token_ids
)].
tolist
()
next_token_id
:
rank_items
=
ranks
[
selected_logprobs_idx
:
selected_logprobs_idx
+
(
selected_logprobs
[
selected_logprobs_idx
].
item
(),
len
(
next_token_ids
)].
tolist
()
ranks
[
selected_logprobs_idx
].
item
())
for
idx
,
(
next_token_id
,
parent_id
)
in
enumerate
(
zip
(
next_token_ids
,
parent_seq_ids
)):
# Get the logprob of a sampled token.
sampled_logprobs_dict
=
{
next_token_id
:
(
selected_logprob_items
[
idx
],
rank_items
[
idx
])
}
}
# +1 to go to the next sampled token. Note that
# Get top K logprobs.
# selected_logprobs can contain duplicates unlike top_logprobs
if
num_logprobs
>
0
:
# when beam search is enabled.
top_ids
=
top_token_ids
[
top_logprob_idx
+
selected_logprobs_idx
+=
1
parent_id
,
:
num_logprobs
].
tolist
()
top_probs
=
top_logprobs
[
top_logprob_idx
+
# Second, add top K logprobs along with its rank.
parent_id
,
:
num_logprobs
].
tolist
()
if
num_logprobs
>=
0
:
# Top K is already sorted by rank, so we can use 1 ~
sampled_logprobs_dict
.
update
(
# num_logprobs + 1 for rank.
zip
(
top_ranks
=
range
(
1
,
num_logprobs
+
1
)
top_token_ids
[
top_logprob_idx
+
sampled_logprobs_dict
.
update
({
parent_id
,
:
num_logprobs
].
tolist
(),
top_id
:
(
top_prob
,
rank
)
zip
(
for
top_id
,
top_prob
,
rank
in
zip
(
top_ids
,
top_probs
,
top_logprobs
[
top_logprob_idx
+
top_ranks
)
parent_id
,
:
num_logprobs
].
tolist
(),
})
# This is rank. Since top_logprob is sorted, we
# can just use a range here.
range
(
1
,
num_logprobs
+
1
))))
sampled_logprobs
.
append
({
sampled_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_and_rank
)
token_id
:
Logprob
(
*
logprob_and_rank
)
for
token_id
,
logprob_and_rank
in
for
token_id
,
logprob_and_rank
in
sampled_logprobs_dict
.
items
()
sampled_logprobs_dict
.
items
()
})
})
# There are len(seq_ids) number of sampled tokens for the current
# sequence group in top_logprobs. Jump to the next seq_group.
# NOTE: This part of code is not intuitive. `selected_logprobs` include
# logprobs for the current step, which has len(next_token_ids) tokens
# per sequence group. `logprobs` includes logprobs from the previous
# steps, which has len(seq_ids) tokens per sequence group.
# Iterate to the next sequence group in a batch.
selected_logprobs_idx
+=
len
(
next_token_ids
)
# Iterate to the next sequence group in a batch.
top_logprob_idx
+=
len
(
seq_ids
)
top_logprob_idx
+=
len
(
seq_ids
)
return
sampled_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
return
sampled_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
...
@@ -1000,7 +1022,7 @@ def _build_sampler_output(
...
@@ -1000,7 +1022,7 @@ def _build_sampler_output(
seq_outputs
.
append
(
seq_outputs
.
append
(
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
logprobs
))
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
logprobs
))
sampler_output
.
append
(
sampler_output
.
append
(
SequenceGroupOutput
(
seq_outputs
,
group_prompt_logprobs
))
Completion
SequenceGroupOutput
(
seq_outputs
,
group_prompt_logprobs
))
# If not specified, store None values in SamplerOutput.
# If not specified, store None values in SamplerOutput.
if
on_device_tensors
is
not
None
:
if
on_device_tensors
is
not
None
:
...
...
vllm/model_executor/model_loader/__init__.py
View file @
b9e12416
...
@@ -2,26 +2,29 @@ from typing import Optional
...
@@ -2,26 +2,29 @@ from typing import Optional
from
torch
import
nn
from
torch
import
nn
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.model_executor.model_loader.loader
import
(
BaseModelLoader
,
from
vllm.model_executor.model_loader.loader
import
(
BaseModelLoader
,
get_model_loader
)
get_model_loader
)
from
vllm.model_executor.model_loader.utils
import
(
from
vllm.model_executor.model_loader.utils
import
(
get_architecture_class_name
,
get_model_architecture
)
get_architecture_class_name
,
get_model_architecture
)
def
get_model
(
def
get_model
(
*
,
model_config
:
ModelConfig
,
load_config
:
LoadConfig
,
*
,
model_config
:
ModelConfig
,
load_config
:
LoadConfig
,
device_config
:
DeviceConfig
,
parallel_config
:
ParallelConfig
,
device_config
:
DeviceConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
scheduler_config
:
SchedulerConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
])
->
nn
.
Module
:
vision_language_config
:
Optional
[
VisionLanguageConfig
],
cache_config
:
CacheConfig
)
->
nn
.
Module
:
loader
=
get_model_loader
(
load_config
)
loader
=
get_model_loader
(
load_config
)
return
loader
.
load_model
(
model_config
=
model_config
,
return
loader
.
load_model
(
model_config
=
model_config
,
device_config
=
device_config
,
device_config
=
device_config
,
lora_config
=
lora_config
,
lora_config
=
lora_config
,
vision_language_config
=
vision_language_config
,
vision_language_config
=
vision_language_config
,
parallel_config
=
parallel_config
,
parallel_config
=
parallel_config
,
scheduler_config
=
scheduler_config
)
scheduler_config
=
scheduler_config
,
cache_config
=
cache_config
)
__all__
=
[
__all__
=
[
...
...
Prev
1
…
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment