Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
500b93c8
Commit
500b93c8
authored
Jul 25, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1
parents
99426767
38c4b7e8
Changes
282
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1991 additions
and
261 deletions
+1991
-261
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
...l_executor/layers/quantization/utils/marlin_utils_test.py
+41
-10
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+149
-1
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+67
-21
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+86
-20
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+38
-3
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+100
-47
vllm/model_executor/layers/spec_decode_base_sampler.py
vllm/model_executor/layers/spec_decode_base_sampler.py
+34
-11
vllm/model_executor/layers/typical_acceptance_sampler.py
vllm/model_executor/layers/typical_acceptance_sampler.py
+2
-2
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+8
-4
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+41
-24
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+101
-11
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+28
-7
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+1045
-0
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+23
-4
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+47
-1
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+25
-11
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+78
-48
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+2
-1
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+2
-1
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+74
-34
No files found.
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
View file @
500b93c8
...
@@ -2,11 +2,13 @@
...
@@ -2,11 +2,13 @@
from
typing
import
List
from
typing
import
List
import
numpy
import
numpy
as
np
import
torch
import
torch
from
.marlin_utils
import
GPTQ_MARLIN_TILE
,
marlin_permute_scales
from
.marlin_utils
import
(
GPTQ_MARLIN_TILE
,
marlin_permute_scales
,
from
.quant_utils
import
get_pack_factor
,
quantize_weights
,
sort_weights
marlin_zero_points
)
from
.quant_utils
import
(
get_pack_factor
,
quantize_weights
,
quantize_weights_with_zp
,
sort_weights
)
class
MarlinWorkspace
:
class
MarlinWorkspace
:
...
@@ -46,14 +48,14 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
...
@@ -46,14 +48,14 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
pack_factor
=
get_pack_factor
(
num_bits
)
pack_factor
=
get_pack_factor
(
num_bits
)
orig_device
=
q_w
.
device
orig_device
=
q_w
.
device
q_w
=
q_w
.
cpu
().
numpy
().
astype
(
n
umpy
.
uint32
)
q_w
=
q_w
.
cpu
().
numpy
().
astype
(
n
p
.
uint32
)
q_packed
=
n
umpy
.
zeros
((
q_w
.
shape
[
0
],
q_w
.
shape
[
1
]
//
pack_factor
),
q_packed
=
n
p
.
zeros
((
q_w
.
shape
[
0
],
q_w
.
shape
[
1
]
//
pack_factor
),
dtype
=
n
umpy
.
uint32
)
dtype
=
n
p
.
uint32
)
for
i
in
range
(
pack_factor
):
for
i
in
range
(
pack_factor
):
q_packed
|=
q_w
[:,
i
::
pack_factor
]
<<
num_bits
*
i
q_packed
|=
q_w
[:,
i
::
pack_factor
]
<<
num_bits
*
i
q_packed
=
torch
.
from_numpy
(
q_packed
.
astype
(
n
umpy
.
int32
)).
to
(
orig_device
)
q_packed
=
torch
.
from_numpy
(
q_packed
.
astype
(
n
p
.
int32
)).
to
(
orig_device
)
return
q_packed
return
q_packed
...
@@ -74,12 +76,12 @@ def get_weight_perm(num_bits: int):
...
@@ -74,12 +76,12 @@ def get_weight_perm(num_bits: int):
for
j
in
range
(
4
):
for
j
in
range
(
4
):
perm_list
.
extend
([
p
+
256
*
j
for
p
in
perm1
])
perm_list
.
extend
([
p
+
256
*
j
for
p
in
perm1
])
perm
=
n
umpy
.
array
(
perm_list
)
perm
=
n
p
.
array
(
perm_list
)
if
num_bits
==
4
:
if
num_bits
==
4
:
interleave
=
n
umpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
interleave
=
n
p
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
elif
num_bits
==
8
:
elif
num_bits
==
8
:
interleave
=
n
umpy
.
array
([
0
,
2
,
1
,
3
])
interleave
=
n
p
.
array
([
0
,
2
,
1
,
3
])
else
:
else
:
raise
Exception
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
raise
Exception
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
...
@@ -118,3 +120,32 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
...
@@ -118,3 +120,32 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
res_list
[
i
]
=
res_list
[
i
].
to
(
w
.
device
)
res_list
[
i
]
=
res_list
[
i
].
to
(
w
.
device
)
return
res_list
return
res_list
def
awq_marlin_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
):
size_k
,
size_n
=
w
.
shape
# Normalize group_size
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
# Detect num groups
assert
size_k
%
group_size
==
0
num_groups
=
size_k
//
group_size
# Quantize with zp
w_ref
,
q_w
,
s
,
zp
=
quantize_weights_with_zp
(
w
,
num_bits
,
group_size
)
# Reformat to marlin
weight_perm
=
get_weight_perm
(
num_bits
)
marlin_q_w
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
weight_perm
)
marlin_s
=
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
)
marlin_zp
=
marlin_zero_points
(
zp
,
num_groups
,
size_n
,
num_bits
)
# Create result
res_list
=
[
w_ref
,
marlin_q_w
,
marlin_s
,
marlin_zp
]
for
i
in
range
(
len
(
res_list
)):
res_list
[
i
]
=
res_list
[
i
].
to
(
w
.
device
)
return
res_list
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
500b93c8
...
@@ -106,6 +106,67 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
...
@@ -106,6 +106,67 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
)
)
def
quantize_weights_with_zp
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
):
orig_device
=
w
.
device
size_k
,
size_n
=
w
.
shape
assert
w
.
is_floating_point
(),
"w must be float"
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
assert
group_size
in
SUPPORTED_GROUP_SIZES
+
[
size_k
],
f
"Unsupported groupsize =
{
group_size
}
"
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
max_q_val
=
2
**
num_bits
-
1
min_q_val
=
0
# Reshape to [groupsize, -1]
if
group_size
<
size_k
:
w
=
w
.
reshape
((
-
1
,
group_size
,
size_n
))
w
=
w
.
permute
(
1
,
0
,
2
)
w
=
w
.
reshape
((
group_size
,
-
1
))
# Compute scale for each group
max
=
torch
.
max
(
w
,
0
,
keepdim
=
True
)[
0
]
min
=
torch
.
min
(
w
,
0
,
keepdim
=
True
)[
0
]
s
=
(
max
-
min
).
clamp
(
min
=
1e-5
)
/
max_q_val
# Compute zero-point for each group
zp
=
(
-
torch
.
round
(
min
/
s
)).
clamp
(
min_q_val
,
max_q_val
).
int
()
# Quantize
q_w
=
torch
.
round
(
w
/
s
).
int
()
+
zp
q_w
=
torch
.
clamp
(
q_w
,
min_q_val
,
max_q_val
)
# Compute ref (dequantized)
w_ref
=
(
q_w
-
zp
).
half
()
*
s
# Restore original shapes
if
group_size
<
size_k
:
def
reshape_w
(
w
):
w
=
w
.
reshape
((
group_size
,
-
1
,
size_n
))
w
=
w
.
permute
(
1
,
0
,
2
)
w
=
w
.
reshape
((
size_k
,
size_n
)).
contiguous
()
return
w
q_w
=
reshape_w
(
q_w
)
w_ref
=
reshape_w
(
w_ref
)
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
zp
=
zp
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
(
w_ref
.
to
(
device
=
orig_device
),
q_w
.
to
(
device
=
orig_device
),
s
.
to
(
device
=
orig_device
),
zp
.
to
(
device
=
orig_device
),
)
def
sort_weights
(
q_w
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
):
def
sort_weights
(
q_w
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
):
orig_device
=
q_w
.
device
orig_device
=
q_w
.
device
...
@@ -122,7 +183,7 @@ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
...
@@ -122,7 +183,7 @@ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
)
)
def
gptq_pack
(
def
pack_rows
(
q_w
:
torch
.
Tensor
,
q_w
:
torch
.
Tensor
,
num_bits
:
int
,
num_bits
:
int
,
size_k
:
int
,
size_k
:
int
,
...
@@ -144,3 +205,90 @@ def gptq_pack(
...
@@ -144,3 +205,90 @@ def gptq_pack(
q_res
=
torch
.
from_numpy
(
q_res
.
astype
(
numpy
.
int32
)).
to
(
orig_device
)
q_res
=
torch
.
from_numpy
(
q_res
.
astype
(
numpy
.
int32
)).
to
(
orig_device
)
return
q_res
return
q_res
def
pack_cols
(
q_w
:
torch
.
Tensor
,
num_bits
:
int
,
size_k
:
int
,
size_n
:
int
,
):
assert
q_w
.
shape
==
(
size_k
,
size_n
)
pack_factor
=
get_pack_factor
(
num_bits
)
assert
size_n
%
pack_factor
==
0
orig_device
=
q_w
.
device
q_w
=
q_w
.
cpu
().
numpy
().
astype
(
numpy
.
uint32
)
q_res
=
numpy
.
zeros
((
size_k
,
size_n
//
pack_factor
),
dtype
=
numpy
.
uint32
)
for
i
in
range
(
pack_factor
):
q_res
|=
q_w
[:,
i
::
pack_factor
]
<<
num_bits
*
i
q_res
=
torch
.
from_numpy
(
q_res
.
astype
(
numpy
.
int32
)).
to
(
orig_device
)
q_res
=
q_res
.
contiguous
()
return
q_res
def
unpack_cols
(
packed_q_w
:
torch
.
Tensor
,
num_bits
:
int
,
size_k
:
int
,
size_n
:
int
,
):
pack_factor
=
get_pack_factor
(
num_bits
)
assert
size_n
%
pack_factor
==
0
assert
packed_q_w
.
shape
==
(
size_k
,
size_n
//
pack_factor
),
"packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}"
.
format
(
packed_q_w
.
shape
,
size_k
,
size_n
,
pack_factor
)
orig_device
=
packed_q_w
.
device
packed_q_w_cpu
=
packed_q_w
.
cpu
().
numpy
().
astype
(
numpy
.
uint32
)
q_res
=
numpy
.
zeros
((
size_k
,
size_n
),
dtype
=
numpy
.
uint32
)
mask
=
(
1
<<
num_bits
)
-
1
for
i
in
range
(
pack_factor
):
vals
=
packed_q_w_cpu
&
mask
packed_q_w_cpu
>>=
num_bits
q_res
[:,
i
::
pack_factor
]
=
vals
q_res
=
torch
.
from_numpy
(
q_res
.
astype
(
numpy
.
int32
)).
to
(
orig_device
)
q_res
=
q_res
.
contiguous
()
return
q_res
def
gptq_pack
(
q_w
:
torch
.
Tensor
,
num_bits
:
int
,
size_k
:
int
,
size_n
:
int
,
):
return
pack_rows
(
q_w
,
num_bits
,
size_k
,
size_n
)
def
awq_pack
(
q_w
:
torch
.
Tensor
,
num_bits
:
int
,
size_k
:
int
,
size_n
:
int
,
):
assert
q_w
.
shape
==
(
size_k
,
size_n
)
# Interleave column dim (for the dequantize code) and pack it to int32
if
num_bits
==
4
:
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
elif
num_bits
==
8
:
interleave
=
numpy
.
array
([
0
,
2
,
1
,
3
])
else
:
raise
Exception
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
q_w
=
q_w
.
reshape
((
-
1
,
len
(
interleave
)))[:,
interleave
].
ravel
()
q_w
=
q_w
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
pack_cols
(
q_w
,
num_bits
,
size_k
,
size_n
)
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
500b93c8
...
@@ -104,49 +104,95 @@ def apply_fp8_linear(
...
@@ -104,49 +104,95 @@ def apply_fp8_linear(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
input_scale_ub
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
cutlass_fp8_supported
:
bool
=
True
,
cutlass_fp8_supported
:
bool
=
True
,
use_per_token_if_dynamic
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
# If static, layer.input_scale is scalar and x_scale is input_scale.
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if
cutlass_fp8_supported
:
if
cutlass_fp8_supported
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input
,
input_scale
)
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input
,
input_scale
,
scale_ub
=
input_scale_ub
,
use_per_token_if_dynamic
=
use_per_token_if_dynamic
)
# Fused GEMM_DQ
# Fused GEMM_DQ
output
=
ops
.
cutlass_scaled_mm
(
qinput
,
return
ops
.
cutlass_scaled_mm
(
qinput
,
weight
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
else
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input
,
input_scale
,
batch_dim_padding
=
17
)
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
weight
,
weight
,
out_dtype
=
input
.
dtype
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
bias
=
bias
)
return
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
else
:
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input
,
input_scale
,
batch_dim_padding
=
17
,
use_per_token_if_dynamic
=
use_per_token_if_dynamic
)
per_tensor_weights
=
(
weight_scale
.
numel
()
==
1
)
per_tensor_activations
=
(
x_scale
.
numel
()
==
1
)
if
per_tensor_weights
and
per_tensor_activations
:
# Fused GEMM_DQ
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
return
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
else
:
# Fallback for channelwise case, where we use unfused DQ
# due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
torch
.
float32
)
# Unpad (undo batch_dim_padding)
output
=
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
# DQ
# C = sw * sx * (X * W) + bias
output
=
output
*
x_scale
*
weight_scale
.
t
()
if
bias
is
not
None
:
output
=
output
+
bias
return
output
.
to
(
dtype
=
input
.
dtype
)
def
apply_int8_linear
(
def
apply_int8_linear
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
):
):
# ops.scaled_int8_quant supports both dynamic and static quant.
# ops.scaled_int8_quant supports both dynamic and static quant.
...
...
vllm/model_executor/layers/rejection_sampler.py
View file @
500b93c8
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch
import
torch.jit
import
torch.jit
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
)
SpecDecode
Stochastic
BaseSampler
)
class
RejectionSampler
(
SpecDecodeBaseSampler
):
class
RejectionSampler
(
SpecDecode
Stochastic
BaseSampler
):
"""Apply modified rejection sampling as described in "Accelerating Large
"""Apply modified rejection sampling as described in "Accelerating Large
Language Model Decoding with Speculative Sampling"
Language Model Decoding with Speculative Sampling"
https://arxiv.org/pdf/2302.01318.pdf.
https://arxiv.org/pdf/2302.01318.pdf.
...
@@ -36,6 +36,7 @@ class RejectionSampler(SpecDecodeBaseSampler):
...
@@ -36,6 +36,7 @@ class RejectionSampler(SpecDecodeBaseSampler):
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
generators
:
List
[
Optional
[
torch
.
Generator
]],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Sample token ids using rejection sampling. This accepts or rejects
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
tokens proposed by the draft model using the probability of each token
...
@@ -82,6 +83,7 @@ class RejectionSampler(SpecDecodeBaseSampler):
...
@@ -82,6 +83,7 @@ class RejectionSampler(SpecDecodeBaseSampler):
target_probs
,
target_probs
,
draft_probs
,
draft_probs
,
draft_token_ids
,
draft_token_ids
,
generators
,
))
))
output_token_ids
=
self
.
_create_output
(
output_token_ids
=
self
.
_create_output
(
...
@@ -94,10 +96,11 @@ class RejectionSampler(SpecDecodeBaseSampler):
...
@@ -94,10 +96,11 @@ class RejectionSampler(SpecDecodeBaseSampler):
return
output_token_ids
return
output_token_ids
def
_batch_modified_rejection_sampling
(
def
_batch_modified_rejection_sampling
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
generators
:
List
[
Optional
[
torch
.
Generator
]],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Perform modified rejection sampling on each sequence.
"""Perform modified rejection sampling on each sequence.
...
@@ -114,22 +117,33 @@ class RejectionSampler(SpecDecodeBaseSampler):
...
@@ -114,22 +117,33 @@ class RejectionSampler(SpecDecodeBaseSampler):
# shape [batch_size, k]
# shape [batch_size, k]
accepted
=
self
.
_get_accepted
(
target_probs
,
draft_probs
,
accepted
=
self
.
_get_accepted
(
target_probs
,
draft_probs
,
draft_token_ids
)
draft_token_ids
,
generators
)
recovered_probs
=
self
.
_get_recovered_probs
(
recovered_probs
=
self
.
_get_recovered_probs
(
target_probs
,
draft_probs
).
reshape
(
batch_size
*
k
,
vocab_size
)
target_probs
,
draft_probs
).
reshape
(
batch_size
*
k
,
vocab_size
)
seed_indices
,
non_seed_indices
=
self
.
_split_batch_by_seeded
(
generators
,
k
=
k
)
# NOTE: the recovered_probs are overwritten by this method.
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids
=
_multinomial
(
recovered_probs
,
recovered_token_ids
=
_multinomial
(
num_samples
=
1
).
reshape
(
recovered_probs
,
batch_size
,
k
)
num_samples
=
1
,
k
=
k
,
generators
=
generators
,
seed_indices
=
seed_indices
,
# this arg is unused when None but torch.jit requires a list
non_seed_indices
=
non_seed_indices
or
[],
).
reshape
(
batch_size
,
k
)
return
accepted
,
recovered_token_ids
return
accepted
,
recovered_token_ids
def
_get_accepted
(
def
_get_accepted
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
generators
:
List
[
Optional
[
torch
.
Generator
]],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
r
"""Create bool matrix over the proposed draft tokens. If
r
"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
True, then a token can be accepted, else it should be
...
@@ -164,10 +178,28 @@ class RejectionSampler(SpecDecodeBaseSampler):
...
@@ -164,10 +178,28 @@ class RejectionSampler(SpecDecodeBaseSampler):
selected_target_probs
=
target_probs
[
batch_indices
,
probs_indicies
,
selected_target_probs
=
target_probs
[
batch_indices
,
probs_indicies
,
draft_token_ids
]
draft_token_ids
]
uniform_rand
=
torch
.
rand
(
batch_size
,
seed_indices
,
non_seed_indices
=
self
.
_split_batch_by_seeded
(
k
,
generators
)
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
)
if
len
(
seed_indices
)
==
0
:
uniform_rand
=
torch
.
rand_like
(
selected_target_probs
)
else
:
uniform_rand
=
torch
.
empty_like
(
selected_target_probs
)
for
idx
in
seed_indices
:
uniform_rand
[
idx
,
:]
=
torch
.
rand
(
1
,
k
,
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
,
generator
=
generators
[
idx
])
if
non_seed_indices
:
uniform_rand
[
non_seed_indices
,
:]
=
torch
.
rand
(
len
(
non_seed_indices
),
k
,
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
)
capped_ratio
=
torch
.
minimum
(
capped_ratio
=
torch
.
minimum
(
selected_target_probs
/
selected_draft_probs
,
selected_target_probs
/
selected_draft_probs
,
torch
.
full
((
1
,
),
1
,
device
=
target_probs
.
device
))
torch
.
full
((
1
,
),
1
,
device
=
target_probs
.
device
))
...
@@ -240,6 +272,27 @@ class RejectionSampler(SpecDecodeBaseSampler):
...
@@ -240,6 +272,27 @@ class RejectionSampler(SpecDecodeBaseSampler):
"""
"""
return
torch
.
finfo
(
self
.
probs_dtype
).
tiny
return
torch
.
finfo
(
self
.
probs_dtype
).
tiny
# partition batch into indices for which a generator is provided
# and indicies for which no generator is provided
@
staticmethod
def
_split_batch_by_seeded
(
generators
:
List
[
Optional
[
torch
.
Generator
]],
k
:
int
=
1
,
)
->
Tuple
[
List
[
int
],
Optional
[
List
[
int
]]]:
if
all
(
generator
is
None
for
generator
in
generators
):
seed_indices
:
List
[
int
]
=
[]
non_seed_indices
:
Optional
[
List
[
int
]]
=
None
else
:
seed_indices
,
non_seed_indices
=
[],
[]
for
i
,
generator
in
enumerate
(
generators
):
if
generator
is
None
:
non_seed_indices
.
extend
(
range
(
k
*
i
,
k
*
(
i
+
1
)))
else
:
seed_indices
.
extend
(
range
(
k
*
i
,
k
*
(
i
+
1
)))
return
seed_indices
,
non_seed_indices
# torch.multinomial forces a GPU<->CPU sync.
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Therefore, we use an optimized implementation instead that skips the sync.
...
@@ -250,12 +303,25 @@ class RejectionSampler(SpecDecodeBaseSampler):
...
@@ -250,12 +303,25 @@ class RejectionSampler(SpecDecodeBaseSampler):
def
_multinomial
(
def
_multinomial
(
probs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
num_samples
:
int
,
num_samples
:
int
,
k
:
int
,
generators
:
List
[
Optional
[
torch
.
Generator
]],
seed_indices
:
List
[
int
],
non_seed_indices
:
List
[
int
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
num_samples
>
1
:
if
num_samples
>
1
:
# This is equivalent to torch.repeat_interleaved (which also
# This is equivalent to torch.repeat_interleaved (which also
# forces a GPU<->CPU sync).
# forces a GPU<->CPU sync).
probs
=
probs
[:,
None
,
:].
expand
(
probs
.
shape
[
0
],
num_samples
,
probs
=
probs
[:,
None
,
:].
expand
(
probs
.
shape
[
0
],
num_samples
,
probs
.
shape
[
1
]).
contiguous
().
view
(
probs
.
shape
[
1
]).
contiguous
().
view
(
-
1
,
probs
.
shape
[
1
])
-
1
,
probs
.
shape
[
1
])
q
=
torch
.
empty_like
(
probs
).
exponential_
(
1.0
)
q
=
torch
.
empty_like
(
probs
)
if
len
(
seed_indices
)
==
0
:
q
.
exponential_
(
1.0
)
else
:
q
[
non_seed_indices
].
exponential_
(
1.0
)
for
idx
in
seed_indices
:
q
[
idx
].
exponential_
(
1.0
,
generator
=
generators
[
idx
//
k
])
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
vllm/model_executor/layers/rotary_embedding.py
View file @
500b93c8
...
@@ -733,6 +733,36 @@ class GemmaRotaryEmbedding(RotaryEmbedding):
...
@@ -733,6 +733,36 @@ class GemmaRotaryEmbedding(RotaryEmbedding):
return
inv_freq
return
inv_freq
class
ExtendedRotaryEmbedding
(
RotaryEmbedding
):
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
inv_freqs
=
super
().
_compute_inv_freq
(
base
)
return
self
.
apply_scaling
(
inv_freqs
)
def
apply_scaling
(
self
,
freqs
:
torch
.
Tensor
):
scale_factor
=
8
low_freq_factor
=
1
high_freq_factor
=
4
old_context_len
=
8192
low_freq_wavelen
=
old_context_len
/
low_freq_factor
high_freq_wavelen
=
old_context_len
/
high_freq_factor
new_freqs
=
[]
for
freq
in
freqs
:
wavelen
=
2
*
math
.
pi
/
freq
if
wavelen
<
high_freq_wavelen
:
new_freqs
.
append
(
freq
)
elif
wavelen
>
low_freq_wavelen
:
new_freqs
.
append
(
freq
/
scale_factor
)
else
:
assert
low_freq_wavelen
!=
high_freq_wavelen
smooth
=
(
old_context_len
/
wavelen
-
low_freq_factor
)
/
(
high_freq_factor
-
low_freq_factor
)
new_freqs
.
append
((
1
-
smooth
)
*
freq
/
scale_factor
+
smooth
*
freq
)
return
torch
.
tensor
(
new_freqs
,
dtype
=
freqs
.
dtype
,
device
=
freqs
.
device
)
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
...
@@ -764,12 +794,17 @@ def get_rope(
...
@@ -764,12 +794,17 @@ def get_rope(
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
)
is_neox_style
,
dtype
)
else
:
else
:
scaling_type
=
rope_scaling
[
"type"
]
scaling_type
=
rope_scaling
[
"type"
]
if
"type"
in
rope_scaling
else
rope_scaling
[
"rope_type"
]
# The correct one should be "longrope" but keep "su" here
# The correct one should be "longrope" but keep "su" here
# for backward compatible
# for backward compatible
if
scaling_type
!=
"su"
and
scaling_type
!=
"longrope"
:
if
scaling_type
not
in
{
"su"
,
"longrope"
,
"llama3"
}
:
scaling_factor
=
rope_scaling
[
"factor"
]
scaling_factor
=
rope_scaling
[
"factor"
]
if
scaling_type
==
"linear"
:
if
scaling_type
==
"llama3"
:
rotary_emb
=
ExtendedRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
)
elif
scaling_type
==
"linear"
:
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
max_position
,
base
,
is_neox_style
,
is_neox_style
,
...
...
vllm/model_executor/layers/sampler.py
View file @
500b93c8
...
@@ -47,6 +47,32 @@ class Sampler(nn.Module):
...
@@ -47,6 +47,32 @@ class Sampler(nn.Module):
# speculative decoding.
# speculative decoding.
self
.
include_gpu_probs_tensor
=
False
self
.
include_gpu_probs_tensor
=
False
def
_init_sampling_tensors
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
):
"""The goal here is to reuse sampling tensors between similar decode
runs. This is possible because sampling logic does not change between
decodes of the same sequences.
"""
_
,
vocab_size
=
logits
.
shape
# First free any existing stored sampling tensors.
# This is necessary because some sampling tensors may
# have pinned memory.
self
.
_sampling_tensors
=
None
# Initialize new sampling tensors
(
sampling_tensors
,
do_penalties
,
do_top_p_top_k
,
do_min_p
)
=
SamplingTensors
.
from_sampling_metadata
(
sampling_metadata
,
vocab_size
,
logits
.
device
,
logits
.
dtype
)
self
.
_sampling_tensors
=
sampling_tensors
self
.
_do_penalties
=
do_penalties
self
.
_do_top_p_top_k
=
do_top_p_top_k
self
.
_do_min_p
=
do_min_p
def
forward
(
def
forward
(
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
...
@@ -60,12 +86,23 @@ class Sampler(nn.Module):
...
@@ -60,12 +86,23 @@ class Sampler(nn.Module):
assert
logits
is
not
None
assert
logits
is
not
None
_
,
vocab_size
=
logits
.
shape
_
,
vocab_size
=
logits
.
shape
logits
=
_apply_min_tokens_penalty
(
logits
,
sampling_metadata
)
# Prepare sampling tensors with pinned memory to avoid blocking.
# Prepare sampling tensors with pinned memory to avoid blocking.
(
sampling_tensors
,
do_penalties
,
do_top_p_top_k
,
if
not
sampling_metadata
.
reuse_sampling_tensors
:
do_min_p
)
=
SamplingTensors
.
from_sampling_metadata
(
self
.
_init_sampling_tensors
(
logits
,
sampling_metadata
)
sampling_metadata
,
vocab_size
,
logits
.
device
,
logits
.
dtype
)
elif
self
.
_do_penalties
:
# In this case, the sampling tensors logic depends on
# "output_tokens" of a sequence. As a result, we cannot
# reuse sampling tensors, since "output_tokens" changes
# between decode runs.
self
.
_init_sampling_tensors
(
logits
,
sampling_metadata
)
assert
self
.
_sampling_tensors
is
not
None
sampling_tensors
=
self
.
_sampling_tensors
do_penalties
=
self
.
_do_penalties
do_top_p_top_k
=
self
.
_do_top_p_top_k
do_min_p
=
self
.
_do_min_p
logits
=
_apply_min_tokens_penalty
(
logits
,
sampling_metadata
)
# Apply presence and frequency penalties.
# Apply presence and frequency penalties.
if
do_penalties
:
if
do_penalties
:
...
@@ -77,7 +114,7 @@ class Sampler(nn.Module):
...
@@ -77,7 +114,7 @@ class Sampler(nn.Module):
# Apply temperature scaling.
# Apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
# Use in-place division to avoid creating a new tensor.
logits
.
div_
(
sampling_tensors
.
temperatures
.
unsqueeze
_
(
dim
=
1
))
logits
.
div_
(
sampling_tensors
.
temperatures
.
unsqueeze
(
dim
=
1
))
if
do_top_p_top_k
:
if
do_top_p_top_k
:
logits
=
_apply_top_k_top_p
(
logits
,
sampling_tensors
.
top_ps
,
logits
=
_apply_top_k_top_p
(
logits
,
sampling_tensors
.
top_ps
,
...
@@ -109,13 +146,19 @@ class Sampler(nn.Module):
...
@@ -109,13 +146,19 @@ class Sampler(nn.Module):
on_device_tensors
=
None
on_device_tensors
=
None
# Get the logprobs query results.
# Get the logprobs query results.
prompt_logprobs
,
sample_logprobs
=
_get_logprobs
(
prompt_logprobs
=
None
logprobs
,
sampling_metadata
,
sample_results
)
sample_logprobs
=
None
return
_build_sampler_output
(
sample_results
,
if
not
sampling_metadata
.
skip_sampler_cpu_output
:
sampling_metadata
,
prompt_logprobs
,
sample_logprobs
=
_get_logprobs
(
prompt_logprobs
,
logprobs
,
sampling_metadata
,
sample_results
)
sample_logprobs
,
on_device_tensors
=
on_device_tensors
)
return
_build_sampler_output
(
sample_results
,
sampling_metadata
,
prompt_logprobs
,
sample_logprobs
,
on_device_tensors
=
on_device_tensors
,
skip_sampler_cpu_output
=
sampling_metadata
.
skip_sampler_cpu_output
)
@
property
@
property
def
_should_modify_greedy_probs_inplace
(
self
)
->
bool
:
def
_should_modify_greedy_probs_inplace
(
self
)
->
bool
:
...
@@ -535,24 +578,29 @@ def _sample_with_torch(
...
@@ -535,24 +578,29 @@ def _sample_with_torch(
# GPU<->CPU sync happens in the loop below.
# GPU<->CPU sync happens in the loop below.
# This also converts the sample output to Python objects.
# This also converts the sample output to Python objects.
for
sampling_type
in
SamplingType
:
if
not
sampling_metadata
.
skip_sampler_cpu_output
:
if
sampling_type
not
in
sample_metadata
:
for
sampling_type
in
SamplingType
:
continue
if
sampling_type
not
in
sample_metadata
:
(
seq_group_id
,
seq_groups
)
=
sample_metadata
[
sampling_type
]
continue
if
sampling_type
==
SamplingType
.
GREEDY
:
(
seq_group_id
,
seq_groups
)
=
sample_metadata
[
sampling_type
]
sample_results
=
_greedy_sample
(
seq_groups
,
greedy_samples
)
if
sampling_type
==
SamplingType
.
GREEDY
:
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
sample_results
=
_greedy_sample
(
seq_groups
,
greedy_samples
)
sample_results
=
_random_sample
(
seq_groups
,
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
multinomial_samples
[
sampling_type
])
SamplingType
.
RANDOM_SEED
):
elif
sampling_type
==
SamplingType
.
BEAM
:
sample_results
=
_random_sample
(
sample_results
=
_beam_search_sample
(
seq_groups
,
seq_groups
,
multinomial_samples
[
sampling_type
])
beam_search_logprobs
)
elif
sampling_type
==
SamplingType
.
BEAM
:
sample_results_dict
.
update
(
zip
(
seq_group_id
,
sample_results
))
sample_results
=
_beam_search_sample
(
seq_groups
,
beam_search_logprobs
)
sample_results_dict
.
update
(
zip
(
seq_group_id
,
sample_results
))
sample_results
=
[
sample_results_dict
.
get
(
i
,
([],
[]))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
]
else
:
sample_results
=
[]
sample_results
=
[
sample_results_dict
.
get
(
i
,
([],
[]))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
]
return
sample_results
,
sampled_token_ids_tensor
return
sample_results
,
sampled_token_ids_tensor
...
@@ -997,10 +1045,11 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
...
@@ -997,10 +1045,11 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
def
_build_sampler_output
(
def
_build_sampler_output
(
sample_results
:
SampleResultType
,
sample_results
:
SampleResultType
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
prompt_logprobs
:
List
[
Optional
[
PromptLogprobs
]],
prompt_logprobs
:
Optional
[
List
[
Optional
[
PromptLogprobs
]]
]
,
sample_logprobs
:
List
[
SampleLogprobs
],
sample_logprobs
:
Optional
[
List
[
SampleLogprobs
]
]
,
on_device_tensors
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
on_device_tensors
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]],
torch
.
Tensor
]],
skip_sampler_cpu_output
:
bool
=
False
,
)
->
SamplerOutput
:
)
->
SamplerOutput
:
"""Construct Python objects with the output of sampling.
"""Construct Python objects with the output of sampling.
...
@@ -1010,22 +1059,26 @@ def _build_sampler_output(
...
@@ -1010,22 +1059,26 @@ def _build_sampler_output(
allows post-processing without copies to CPU/serialization, e.g. in
allows post-processing without copies to CPU/serialization, e.g. in
speculative decoding rejection sampling.
speculative decoding rejection sampling.
"""
"""
sampler_output
:
List
[
CompletionSequenceGroupOutput
]
=
[]
sampler_output
:
List
[
CompletionSequenceGroupOutput
]
=
[]
for
(
seq_group
,
sample_result
,
group_prompt_logprobs
,
if
not
skip_sampler_cpu_output
:
group_sample_logprobs
)
in
zip
(
sampling_metadata
.
seq_groups
,
assert
prompt_logprobs
is
not
None
sample_results
,
prompt_logprobs
,
assert
sample_logprobs
is
not
None
sample_logprobs
):
seq_ids
=
seq_group
.
seq_ids
for
(
seq_group
,
sample_result
,
group_prompt_logprobs
,
next_token_ids
,
parent_ids
=
sample_result
group_sample_logprobs
)
in
zip
(
sampling_metadata
.
seq_groups
,
seq_outputs
:
List
[
SequenceOutput
]
=
[]
sample_results
,
prompt_logprobs
,
for
parent_id
,
next_token_id
,
logprobs
in
zip
(
parent_ids
,
sample_logprobs
):
next_token_ids
,
seq_ids
=
seq_group
.
seq_ids
group_sample_logprobs
):
next_token_ids
,
parent_ids
=
sample_result
seq_outputs
.
append
(
seq_outputs
:
List
[
SequenceOutput
]
=
[]
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
logprobs
))
for
parent_id
,
next_token_id
,
logprobs
in
zip
(
sampler_output
.
append
(
parent_ids
,
next_token_ids
,
group_sample_logprobs
):
CompletionSequenceGroupOutput
(
seq_outputs
,
group_prompt_logprobs
))
seq_outputs
.
append
(
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
logprobs
))
sampler_output
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
group_prompt_logprobs
))
# If not specified, store None values in SamplerOutput.
# If not specified, store None values in SamplerOutput.
if
on_device_tensors
is
not
None
:
if
on_device_tensors
is
not
None
:
...
...
vllm/model_executor/layers/spec_decode_base_sampler.py
View file @
500b93c8
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
typing
import
Optional
from
typing
import
List
,
Optional
import
torch
import
torch
import
torch.jit
import
torch.jit
...
@@ -54,16 +54,6 @@ class SpecDecodeBaseSampler(nn.Module):
...
@@ -54,16 +54,6 @@ class SpecDecodeBaseSampler(nn.Module):
def
token_id_dtype
(
self
):
def
token_id_dtype
(
self
):
return
torch
.
int64
return
torch
.
int64
@
abstractmethod
def
forward
(
self
,
target_probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
_create_output
(
def
_create_output
(
self
,
self
,
accepted
:
torch
.
Tensor
,
# [batch_size, k]
accepted
:
torch
.
Tensor
,
# [batch_size, k]
...
@@ -217,3 +207,36 @@ class SpecDecodeBaseSampler(nn.Module):
...
@@ -217,3 +207,36 @@ class SpecDecodeBaseSampler(nn.Module):
assert
torch
.
all
(
bonus_token_ids
>=
0
)
assert
torch
.
all
(
bonus_token_ids
>=
0
)
assert
torch
.
all
(
draft_token_ids
<
vocab_size
)
assert
torch
.
all
(
draft_token_ids
<
vocab_size
)
assert
torch
.
all
(
draft_token_ids
>=
0
)
assert
torch
.
all
(
draft_token_ids
>=
0
)
class
SpecDecodeDeterministicBaseSampler
(
SpecDecodeBaseSampler
):
"""Base class for samplers used for Speculative Decoding verification
step which are deterministic.
"""
@
abstractmethod
def
forward
(
self
,
target_probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
class
SpecDecodeStochasticBaseSampler
(
SpecDecodeBaseSampler
):
"""Base class for samplers used for Speculative Decoding verification
step which are stochastic
"""
@
abstractmethod
def
forward
(
self
,
target_probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
generators
:
List
[
Optional
[
torch
.
Generator
]],
)
->
torch
.
Tensor
:
raise
NotImplementedError
vllm/model_executor/layers/typical_acceptance_sampler.py
View file @
500b93c8
...
@@ -2,10 +2,10 @@ import torch
...
@@ -2,10 +2,10 @@ import torch
import
torch.jit
import
torch.jit
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
)
SpecDecode
Deterministic
BaseSampler
)
class
TypicalAcceptanceSampler
(
SpecDecodeBaseSampler
):
class
TypicalAcceptanceSampler
(
SpecDecode
Deterministic
BaseSampler
):
"""Apply typical acceptance sampling as described in section 3.3.1 in
"""Apply typical acceptance sampling as described in section 3.3.1 in
"MEDUSA: Simple LLM Inference Acceleration Framework with
"MEDUSA: Simple LLM Inference Acceleration Framework with
Multiple Decoding Heads"
Multiple Decoding Heads"
...
...
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
500b93c8
...
@@ -161,6 +161,7 @@ class VocabParallelEmbedding(torch.nn.Module):
...
@@ -161,6 +161,7 @@ class VocabParallelEmbedding(torch.nn.Module):
org_num_embeddings: original vocabulary size (without LoRA).
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
padding_size: padding size for the vocabulary.
quant_config: quant config for the layer
quant_config: quant config for the layer
prefix: full name of the layer in the state dict
"""
# noqa: E501
"""
# noqa: E501
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -169,7 +170,8 @@ class VocabParallelEmbedding(torch.nn.Module):
...
@@ -169,7 +170,8 @@ class VocabParallelEmbedding(torch.nn.Module):
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
org_num_embeddings
:
Optional
[
int
]
=
None
,
org_num_embeddings
:
Optional
[
int
]
=
None
,
padding_size
:
int
=
DEFAULT_VOCAB_PADDING_SIZE
,
padding_size
:
int
=
DEFAULT_VOCAB_PADDING_SIZE
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
# Keep the input dimensions.
# Keep the input dimensions.
...
@@ -195,7 +197,7 @@ class VocabParallelEmbedding(torch.nn.Module):
...
@@ -195,7 +197,7 @@ class VocabParallelEmbedding(torch.nn.Module):
linear_method
=
None
linear_method
=
None
if
quant_config
is
not
None
:
if
quant_config
is
not
None
:
linear_method
=
quant_config
.
get_quant_method
(
self
)
linear_method
=
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
if
linear_method
is
None
:
if
linear_method
is
None
:
linear_method
=
UnquantizedLinearMethod
()
linear_method
=
UnquantizedLinearMethod
()
self
.
linear_method
:
QuantizeMethodBase
=
linear_method
self
.
linear_method
:
QuantizeMethodBase
=
linear_method
...
@@ -382,9 +384,11 @@ class ParallelLMHead(VocabParallelEmbedding):
...
@@ -382,9 +384,11 @@ class ParallelLMHead(VocabParallelEmbedding):
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
org_num_embeddings
:
Optional
[
int
]
=
None
,
org_num_embeddings
:
Optional
[
int
]
=
None
,
padding_size
:
int
=
DEFAULT_VOCAB_PADDING_SIZE
,
padding_size
:
int
=
DEFAULT_VOCAB_PADDING_SIZE
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
(
num_embeddings
,
embedding_dim
,
params_dtype
,
super
().
__init__
(
num_embeddings
,
embedding_dim
,
params_dtype
,
org_num_embeddings
,
padding_size
,
quant_config
)
org_num_embeddings
,
padding_size
,
quant_config
,
prefix
)
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
num_embeddings_per_partition
,
torch
.
empty
(
self
.
num_embeddings_per_partition
,
...
...
vllm/model_executor/model_loader/loader.py
View file @
500b93c8
...
@@ -32,7 +32,8 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -32,7 +32,8 @@ from vllm.model_executor.model_loader.weight_utils import (
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
get_quant_config
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
get_quant_config
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
safetensors_weights_iterator
)
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.models.interfaces
import
(
supports_lora
,
from
vllm.model_executor.models.interfaces
import
(
has_inner_state
,
supports_lora
,
supports_vision
)
supports_vision
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -69,10 +70,10 @@ def _get_quantization_config(
...
@@ -69,10 +70,10 @@ def _get_quantization_config(
def
_get_model_initialization_kwargs
(
def
_get_model_initialization_kwargs
(
model_class
:
Type
[
nn
.
Module
],
model_class
:
Type
[
nn
.
Module
],
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
)
->
Dict
[
str
,
Any
]:
scheduler_config
:
Optional
[
SchedulerConfig
]
=
None
)
->
Dict
[
str
,
Any
]:
"""Get extra kwargs for model initialization."""
"""Get extra kwargs for model initialization."""
extra_kwargs
:
Dict
[
str
,
Any
]
=
{}
extra_kwargs
:
Dict
[
str
,
Any
]
=
{}
...
@@ -93,13 +94,19 @@ def _get_model_initialization_kwargs(
...
@@ -93,13 +94,19 @@ def _get_model_initialization_kwargs(
extra_kwargs
[
"multimodal_config"
]
=
multimodal_config
extra_kwargs
[
"multimodal_config"
]
=
multimodal_config
if
has_inner_state
(
model_class
)
and
scheduler_config
:
extra_kwargs
[
"scheduler_config"
]
=
scheduler_config
return
extra_kwargs
return
extra_kwargs
def
_initialize_model
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
,
def
_initialize_model
(
lora_config
:
Optional
[
LoRAConfig
],
model_config
:
ModelConfig
,
multimodal_config
:
Optional
[
MultiModalConfig
],
load_config
:
LoadConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
cache_config
:
CacheConfig
,
scheduler_config
:
Optional
[
SchedulerConfig
]
=
None
)
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
"""Initialize a model with the given configurations."""
model_class
=
get_model_architecture
(
model_config
)[
0
]
model_class
=
get_model_architecture
(
model_config
)[
0
]
quant_config
=
_get_quantization_config
(
model_config
,
load_config
)
quant_config
=
_get_quantization_config
(
model_config
,
load_config
)
...
@@ -108,7 +115,8 @@ def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
...
@@ -108,7 +115,8 @@ def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
**
_get_model_initialization_kwargs
(
**
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
multimodal_config
))
model_class
,
lora_config
,
multimodal_config
,
scheduler_config
))
class
BaseModelLoader
(
ABC
):
class
BaseModelLoader
(
ABC
):
...
@@ -156,6 +164,7 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -156,6 +164,7 @@ class DefaultModelLoader(BaseModelLoader):
cache_dir
=
self
.
load_config
.
download_dir
,
cache_dir
=
self
.
load_config
.
download_dir
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
revision
=
revision
,
revision
=
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
)
else
:
else
:
model_path
=
model
model_path
=
model
...
@@ -191,9 +200,13 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -191,9 +200,13 @@ class DefaultModelLoader(BaseModelLoader):
allow_patterns
+=
[
"*.pt"
]
allow_patterns
+=
[
"*.pt"
]
if
not
is_local
:
if
not
is_local
:
hf_folder
=
download_weights_from_hf
(
model_name_or_path
,
hf_folder
=
download_weights_from_hf
(
self
.
load_config
.
download_dir
,
model_name_or_path
,
allow_patterns
,
revision
)
self
.
load_config
.
download_dir
,
allow_patterns
,
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
else
:
else
:
hf_folder
=
model_name_or_path
hf_folder
=
model_name_or_path
...
@@ -269,7 +282,7 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -269,7 +282,7 @@ class DefaultModelLoader(BaseModelLoader):
with
torch
.
device
(
device_config
.
device
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
multimodal_config
,
lora_config
,
multimodal_config
,
cache_config
)
cache_config
,
scheduler_config
)
model
.
load_weights
(
model
.
load_weights
(
self
.
_get_weights_iterator
(
model_config
.
model
,
self
.
_get_weights_iterator
(
model_config
.
model
,
model_config
.
revision
,
model_config
.
revision
,
...
@@ -282,10 +295,6 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -282,10 +295,6 @@ class DefaultModelLoader(BaseModelLoader):
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
if
quant_method
is
not
None
:
quant_method
.
process_weights_after_loading
(
module
)
quant_method
.
process_weights_after_loading
(
module
)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if
hasattr
(
module
,
"process_weights_after_loading"
):
module
.
process_weights_after_loading
()
return
model
.
eval
()
return
model
.
eval
()
...
@@ -309,7 +318,7 @@ class DummyModelLoader(BaseModelLoader):
...
@@ -309,7 +318,7 @@ class DummyModelLoader(BaseModelLoader):
with
torch
.
device
(
device_config
.
device
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
multimodal_config
,
lora_config
,
multimodal_config
,
cache_config
)
cache_config
,
scheduler_config
)
# NOTE(woosuk): For accurate performance evaluation, we assign
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
# random values to the weights.
initialize_dummy_weights
(
model
)
initialize_dummy_weights
(
model
)
...
@@ -488,9 +497,13 @@ class ShardedStateLoader(BaseModelLoader):
...
@@ -488,9 +497,13 @@ class ShardedStateLoader(BaseModelLoader):
return
model_name_or_path
return
model_name_or_path
else
:
else
:
allow_patterns
=
[
"*.safetensors"
]
allow_patterns
=
[
"*.safetensors"
]
return
download_weights_from_hf
(
model_name_or_path
,
return
download_weights_from_hf
(
self
.
load_config
.
download_dir
,
model_name_or_path
,
allow_patterns
,
revision
)
self
.
load_config
.
download_dir
,
allow_patterns
,
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
...
@@ -662,8 +675,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -662,8 +675,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
matching_files
=
fnmatch
.
filter
(
repo_files
,
pattern
)
matching_files
=
fnmatch
.
filter
(
repo_files
,
pattern
)
if
matching_files
:
if
matching_files
:
hf_folder
=
download_weights_from_hf
(
hf_folder
=
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
model_name_or_path
,
[
pattern
],
revision
)
self
.
load_config
.
download_dir
,
[
pattern
],
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
return
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
pattern
)),
pattern
return
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
pattern
)),
pattern
raise
RuntimeError
(
raise
RuntimeError
(
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
500b93c8
...
@@ -6,7 +6,7 @@ import json
...
@@ -6,7 +6,7 @@ import json
import
os
import
os
import
tempfile
import
tempfile
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Any
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
filelock
import
filelock
import
huggingface_hub.constants
import
huggingface_hub.constants
...
@@ -22,6 +22,7 @@ from vllm.logger import init_logger
...
@@ -22,6 +22,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.quantization
import
(
QuantizationConfig
,
from
vllm.model_executor.layers.quantization
import
(
QuantizationConfig
,
get_quantization_config
)
get_quantization_config
)
from
vllm.model_executor.layers.quantization.schema
import
QuantParamSchema
from
vllm.model_executor.layers.quantization.schema
import
QuantParamSchema
from
vllm.utils
import
print_warning_once
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -188,6 +189,7 @@ def download_weights_from_hf(
...
@@ -188,6 +189,7 @@ def download_weights_from_hf(
cache_dir
:
Optional
[
str
],
cache_dir
:
Optional
[
str
],
allow_patterns
:
List
[
str
],
allow_patterns
:
List
[
str
],
revision
:
Optional
[
str
]
=
None
,
revision
:
Optional
[
str
]
=
None
,
ignore_patterns
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
)
->
str
:
)
->
str
:
"""Download model weights from Hugging Face Hub.
"""Download model weights from Hugging Face Hub.
...
@@ -199,6 +201,9 @@ def download_weights_from_hf(
...
@@ -199,6 +201,9 @@ def download_weights_from_hf(
weight files. Files matched by any of the patterns will be
weight files. Files matched by any of the patterns will be
downloaded.
downloaded.
revision (Optional[str]): The revision of the model.
revision (Optional[str]): The revision of the model.
ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
filter out the weight files. Files matched by any of the patterns
will be ignored.
Returns:
Returns:
str: The path to the downloaded model weights.
str: The path to the downloaded model weights.
...
@@ -222,6 +227,7 @@ def download_weights_from_hf(
...
@@ -222,6 +227,7 @@ def download_weights_from_hf(
hf_folder
=
snapshot_download
(
hf_folder
=
snapshot_download
(
model_name_or_path
,
model_name_or_path
,
allow_patterns
=
allow_patterns
,
allow_patterns
=
allow_patterns
,
ignore_patterns
=
ignore_patterns
,
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
tqdm_class
=
DisabledTqdm
,
tqdm_class
=
DisabledTqdm
,
revision
=
revision
,
revision
=
revision
,
...
@@ -312,6 +318,13 @@ def filter_files_not_needed_for_inference(
...
@@ -312,6 +318,13 @@ def filter_files_not_needed_for_inference(
return
hf_weights_files
return
hf_weights_files
# explicitly use pure text format, with a newline at the end
# this makes it impossible to see the animation in the progress bar
# but will avoid messing up with ray or multiprocessing, which wraps
# each line of output with some prefix.
_BAR_FORMAT
=
"{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]
\n
"
# noqa: E501
def
np_cache_weights_iterator
(
def
np_cache_weights_iterator
(
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
],
hf_folder
:
str
,
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
],
hf_folder
:
str
,
hf_weights_files
:
List
[
str
]
hf_weights_files
:
List
[
str
]
...
@@ -320,6 +333,8 @@ def np_cache_weights_iterator(
...
@@ -320,6 +333,8 @@ def np_cache_weights_iterator(
Will dump the model weights to numpy files if they are not already dumped.
Will dump the model weights to numpy files if they are not already dumped.
"""
"""
enable_tqdm
=
not
torch
.
distributed
.
is_initialized
(
)
or
torch
.
distributed
.
get_rank
()
==
0
# Convert the model weights from torch tensors to numpy arrays for
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
# faster loading.
np_folder
=
os
.
path
.
join
(
hf_folder
,
"np"
)
np_folder
=
os
.
path
.
join
(
hf_folder
,
"np"
)
...
@@ -330,7 +345,12 @@ def np_cache_weights_iterator(
...
@@ -330,7 +345,12 @@ def np_cache_weights_iterator(
with
get_lock
(
model_name_or_path
,
cache_dir
):
with
get_lock
(
model_name_or_path
,
cache_dir
):
if
not
os
.
path
.
exists
(
weight_names_file
):
if
not
os
.
path
.
exists
(
weight_names_file
):
weight_names
:
List
[
str
]
=
[]
weight_names
:
List
[
str
]
=
[]
for
bin_file
in
hf_weights_files
:
for
bin_file
in
tqdm
(
hf_weights_files
,
desc
=
"Loading np_cache checkpoint shards"
,
disable
=
not
enable_tqdm
,
bar_format
=
_BAR_FORMAT
,
):
state
=
torch
.
load
(
bin_file
,
map_location
=
"cpu"
)
state
=
torch
.
load
(
bin_file
,
map_location
=
"cpu"
)
for
name
,
param
in
state
.
items
():
for
name
,
param
in
state
.
items
():
param_path
=
os
.
path
.
join
(
np_folder
,
name
)
param_path
=
os
.
path
.
join
(
np_folder
,
name
)
...
@@ -354,7 +374,14 @@ def safetensors_weights_iterator(
...
@@ -354,7 +374,14 @@ def safetensors_weights_iterator(
hf_weights_files
:
List
[
str
]
hf_weights_files
:
List
[
str
]
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Iterate over the weights in the model safetensor files."""
"""Iterate over the weights in the model safetensor files."""
for
st_file
in
hf_weights_files
:
enable_tqdm
=
not
torch
.
distributed
.
is_initialized
(
)
or
torch
.
distributed
.
get_rank
()
==
0
for
st_file
in
tqdm
(
hf_weights_files
,
desc
=
"Loading safetensors checkpoint shards"
,
disable
=
not
enable_tqdm
,
bar_format
=
_BAR_FORMAT
,
):
with
safe_open
(
st_file
,
framework
=
"pt"
)
as
f
:
with
safe_open
(
st_file
,
framework
=
"pt"
)
as
f
:
for
name
in
f
.
keys
():
# noqa: SIM118
for
name
in
f
.
keys
():
# noqa: SIM118
param
=
f
.
get_tensor
(
name
)
param
=
f
.
get_tensor
(
name
)
...
@@ -365,7 +392,14 @@ def pt_weights_iterator(
...
@@ -365,7 +392,14 @@ def pt_weights_iterator(
hf_weights_files
:
List
[
str
]
hf_weights_files
:
List
[
str
]
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Iterate over the weights in the model bin/pt files."""
"""Iterate over the weights in the model bin/pt files."""
for
bin_file
in
hf_weights_files
:
enable_tqdm
=
not
torch
.
distributed
.
is_initialized
(
)
or
torch
.
distributed
.
get_rank
()
==
0
for
bin_file
in
tqdm
(
hf_weights_files
,
desc
=
"Loading pt checkpoint shards"
,
disable
=
not
enable_tqdm
,
bar_format
=
_BAR_FORMAT
,
):
state
=
torch
.
load
(
bin_file
,
map_location
=
"cpu"
)
state
=
torch
.
load
(
bin_file
,
map_location
=
"cpu"
)
for
name
,
param
in
state
.
items
():
for
name
,
param
in
state
.
items
():
yield
name
,
param
yield
name
,
param
...
@@ -431,11 +465,6 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
...
@@ -431,11 +465,6 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
def
default_weight_loader
(
param
:
torch
.
Tensor
,
def
default_weight_loader
(
param
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
)
->
None
:
loaded_weight
:
torch
.
Tensor
)
->
None
:
"""Default weight loader."""
"""Default weight loader."""
# If the weight on disk does not have a shape, give it one
# (such scales for AutoFp8).
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
assert
param
.
size
()
==
loaded_weight
.
size
()
assert
param
.
size
()
==
loaded_weight
.
size
()
param
.
data
.
copy_
(
loaded_weight
)
param
.
data
.
copy_
(
loaded_weight
)
...
@@ -444,6 +473,7 @@ def initialize_dummy_weights(
...
@@ -444,6 +473,7 @@ def initialize_dummy_weights(
model
:
torch
.
nn
.
Module
,
model
:
torch
.
nn
.
Module
,
low
:
float
=
-
1e-3
,
low
:
float
=
-
1e-3
,
high
:
float
=
1e-3
,
high
:
float
=
1e-3
,
seed
:
int
=
1234
,
)
->
None
:
)
->
None
:
"""Initialize model weights with random values.
"""Initialize model weights with random values.
...
@@ -451,14 +481,74 @@ def initialize_dummy_weights(
...
@@ -451,14 +481,74 @@ def initialize_dummy_weights(
measurements. Additionally, the model weights should not cause NaNs in the
measurements. Additionally, the model weights should not cause NaNs in the
forward pass. We empirically found that initializing the weights with
forward pass. We empirically found that initializing the weights with
values between -1e-3 and 1e-3 works well for most models.
values between -1e-3 and 1e-3 works well for most models.
We use per-parameter random seed, so that dummy weights are consistent,
even if the model is partitioned across multiple devices. When the seed
is fixed, the random values generated by this function only depends on
the parameter's number of elements and its data type.
"""
"""
for
param
in
model
.
state_dict
().
values
():
for
param
in
model
.
state_dict
().
values
():
if
torch
.
is_floating_point
(
param
):
if
torch
.
is_floating_point
(
param
):
generator
=
torch
.
Generator
(
device
=
param
.
data
.
device
)
generator
.
manual_seed
(
seed
)
if
torch
.
finfo
(
param
.
data
.
dtype
).
bits
<
16
:
if
torch
.
finfo
(
param
.
data
.
dtype
).
bits
<
16
:
# uniform_ doesn't support < 16-bit datatypes (FP8)
# uniform_ doesn't support < 16-bit datatypes (FP8)
dtype
=
param
.
data
.
dtype
dtype
=
param
.
data
.
dtype
tmp_param
=
param
.
data
.
to
(
torch
.
float16
)
tmp_param
=
param
.
data
.
to
(
torch
.
float16
)
tmp_param
=
tmp_param
.
uniform_
(
low
,
high
).
to
(
dtype
)
tmp_param
=
tmp_param
.
uniform_
(
low
,
high
,
generator
=
generator
).
to
(
dtype
)
param
.
data
.
copy_
(
tmp_param
)
param
.
data
.
copy_
(
tmp_param
)
else
:
else
:
param
.
uniform_
(
low
,
high
)
param
.
uniform_
(
low
,
high
,
generator
=
generator
)
def
maybe_remap_kv_scale_name
(
name
:
str
,
params_dict
:
dict
)
->
Optional
[
str
]:
"""Remap the name of FP8 k/v_scale parameters.
This function handles the remapping of FP8 k/v_scale parameter names.
It detects if the given name ends with a suffix and attempts to remap
it to the expected name format in the model. If the remapped name is not
found in the params_dict, a warning is printed and None is returned.
Args:
name (str): The original loaded checkpoint parameter name.
params_dict (dict): Dictionary containing the model's named parameters.
Returns:
str: The remapped parameter name if successful, or the original name
if no remapping is needed.
None: If the remapped name is not found in params_dict.
"""
if
name
.
endswith
(
".kv_scale"
):
print_warning_once
(
"DEPRECATED. Found kv_scale in the checkpoint. "
"This format is deprecated in favor of separate k_scale and "
"v_scale tensors and will be removed in a future release. "
"Functionally, we will remap kv_scale to k_scale and duplicate "
"k_scale to v_scale"
)
# NOTE: we remap the deprecated kv_scale to k_scale
remapped_name
=
name
.
replace
(
".kv_scale"
,
".attn.k_scale"
)
if
remapped_name
not
in
params_dict
:
print_warning_once
(
f
"Found kv_scale in the checkpoint (e.g.
{
name
}
), "
"but not found the expected name in the model "
f
"(e.g.
{
remapped_name
}
). kv_scale is "
"not loaded."
)
return
None
return
remapped_name
possible_scale_names
=
[
".k_scale"
,
".v_scale"
]
for
scale_name
in
possible_scale_names
:
if
name
.
endswith
(
scale_name
):
remapped_name
=
name
.
replace
(
scale_name
,
f
".attn
{
scale_name
}
"
)
if
remapped_name
not
in
params_dict
:
print_warning_once
(
f
"Found
{
scale_name
}
in the checkpoint (e.g.
{
name
}
), "
"but not found the expected name in the model "
f
"(e.g.
{
remapped_name
}
).
{
scale_name
}
is "
"not loaded."
)
return
None
return
remapped_name
# If there were no matches, return the untouched param name
return
name
vllm/model_executor/models/__init__.py
View file @
500b93c8
import
functools
import
importlib
import
importlib
from
typing
import
Dict
,
List
,
Optional
,
Type
from
typing
import
Dict
,
List
,
Optional
,
Type
...
@@ -15,6 +16,10 @@ _GENERATION_MODELS = {
...
@@ -15,6 +16,10 @@ _GENERATION_MODELS = {
"BaiChuanForCausalLM"
:
(
"baichuan"
,
"BaiChuanForCausalLM"
),
# baichuan-7b
"BaiChuanForCausalLM"
:
(
"baichuan"
,
"BaiChuanForCausalLM"
),
# baichuan-7b
"BaichuanForCausalLM"
:
(
"baichuan"
,
"BaichuanForCausalLM"
),
# baichuan-13b
"BaichuanForCausalLM"
:
(
"baichuan"
,
"BaichuanForCausalLM"
),
# baichuan-13b
"BloomForCausalLM"
:
(
"bloom"
,
"BloomForCausalLM"
),
"BloomForCausalLM"
:
(
"bloom"
,
"BloomForCausalLM"
),
#TODO(ywang96): remove this when huggingface fixes the model repo
"ChameleonForCausalLM"
:
(
"chameleon"
,
"ChameleonForConditionalGeneration"
),
"ChameleonForConditionalGeneration"
:
(
"chameleon"
,
"ChameleonForConditionalGeneration"
),
"ChatGLMModel"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"ChatGLMModel"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"ChatGLMForConditionalGeneration"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"ChatGLMForConditionalGeneration"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"CohereForCausalLM"
:
(
"commandr"
,
"CohereForCausalLM"
),
"CohereForCausalLM"
:
(
"commandr"
,
"CohereForCausalLM"
),
...
@@ -86,18 +91,37 @@ _ROCM_UNSUPPORTED_MODELS: List[str] = []
...
@@ -86,18 +91,37 @@ _ROCM_UNSUPPORTED_MODELS: List[str] = []
# Models partially supported by ROCm.
# Models partially supported by ROCm.
# Architecture -> Reason.
# Architecture -> Reason.
_ROCM_SWA_REASON
=
(
"Sliding window attention (SWA) is not yet supported in "
"Triton flash attention. For half-precision SWA support, "
"please use CK flash attention by setting "
"`VLLM_USE_TRITON_FLASH_ATTN=0`"
)
_ROCM_PARTIALLY_SUPPORTED_MODELS
:
Dict
[
str
,
str
]
=
{
_ROCM_PARTIALLY_SUPPORTED_MODELS
:
Dict
[
str
,
str
]
=
{
"Qwen2ForCausalLM"
:
"Qwen2ForCausalLM"
:
"Sliding window attention is not yet supported in ROCm's flash attention"
,
_ROCM_SWA_REASON
,
"MistralForCausalLM"
:
"MistralForCausalLM"
:
"Sliding window attention is not yet supported in ROCm's flash attention"
,
_ROCM_SWA_REASON
,
"MixtralForCausalLM"
:
"MixtralForCausalLM"
:
"Sliding window attention is not yet supported in ROCm's flash attention"
,
_ROCM_SWA_REASON
,
"PaliGemmaForConditionalGeneration"
:
(
"ROCm flash attention does not yet "
"fully support 32-bit precision on PaliGemma"
),
"Phi3VForCausalLM"
:
(
"ROCm Triton flash attention may run into compilation errors due to "
"excessive use of shared memory. If this happens, disable Triton FA "
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"
)
}
}
class
ModelRegistry
:
class
ModelRegistry
:
@staticmethod
@functools.lru_cache(maxsize=128)
def _get_model(model_arch: str):
    """Import the module registered for ``model_arch`` and return its class.

    The ``(module_name, class_name)`` pair is looked up in ``_MODELS``; an
    unknown architecture therefore raises ``KeyError``. Returns ``None``
    when the imported module has no attribute with the registered class
    name. Results are memoized per architecture name by ``lru_cache``.
    """
    module_name, model_cls_name = _MODELS[model_arch]
    module_path = f"vllm.model_executor.models.{module_name}"
    imported = importlib.import_module(module_path)
    return getattr(imported, model_cls_name, None)
@
staticmethod
@
staticmethod
def
load_model_cls
(
model_arch
:
str
)
->
Optional
[
Type
[
nn
.
Module
]]:
def
load_model_cls
(
model_arch
:
str
)
->
Optional
[
Type
[
nn
.
Module
]]:
if
model_arch
in
_OOT_MODELS
:
if
model_arch
in
_OOT_MODELS
:
...
@@ -114,10 +138,7 @@ class ModelRegistry:
...
@@ -114,10 +138,7 @@ class ModelRegistry:
"Model architecture %s is partially supported by ROCm: %s"
,
"Model architecture %s is partially supported by ROCm: %s"
,
model_arch
,
_ROCM_PARTIALLY_SUPPORTED_MODELS
[
model_arch
])
model_arch
,
_ROCM_PARTIALLY_SUPPORTED_MODELS
[
model_arch
])
module_name
,
model_cls_name
=
_MODELS
[
model_arch
]
return
ModelRegistry
.
_get_model
(
model_arch
)
module
=
importlib
.
import_module
(
f
"vllm.model_executor.models.
{
module_name
}
"
)
return
getattr
(
module
,
model_cls_name
,
None
)
@
staticmethod
@
staticmethod
def
get_supported_archs
()
->
List
[
str
]:
def
get_supported_archs
()
->
List
[
str
]:
...
...
vllm/model_executor/models/chameleon.py
0 → 100644
View file @
500b93c8
from
functools
import
cached_property
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
)
import
torch
import
torch.nn.functional
as
F
from
PIL
import
Image
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
repeat_and_pad_image_tokens
)
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.transformers_utils.configs
import
(
ChameleonConfig
,
ChameleonVQVAEConfig
)
from
vllm.utils
import
print_warning_once
from
.interfaces
import
SupportsVision
logger
=
init_logger
(
__name__
)
# The values below come from the preprocessor/processor files rather than
# from the model config, so they are hardcoded in this module for now.

# Width/height (in pixels) used for dummy images and pixel-value validation.
CHAMELEON_CROP_SIZE_WIDTH = 512
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH
# Number of placeholder tokens a single image expands to in the prompt.
CHAMELEON_IMAGE_SEQ_LENGTH = 1024
# Token id of the image placeholder that gets repeated.
CHAMELEON_IMAGE_TOKEN_ID = 8711
# Token ids padded to the left / right of the repeated image placeholder.
CHAMELEON_IMAGE_START_TOKEN_ID = 8197
CHAMELEON_IMAGE_END_TOKEN_ID = 8196
# Token id appended after the processed prompt.
CHAMELEON_SEP_TOKEN_ID = 8710
class ChameleonImagePixelInputs(TypedDict):
    """Typed dictionary carrying raw image pixel values for Chameleon."""

    # Discriminator tag; the only allowed value is "pixel_values".
    type: Literal["pixel_values"]
    # Shape: `(batch_size, num_channels, height, width)`
    data: torch.Tensor
def get_max_chameleon_image_tokens(ctx: InputContext):
    """Return the maximum number of prompt tokens one image can occupy.

    ``ctx`` is accepted so the function matches the callback signature used
    when it is registered with the multimodal registry, but it is not read:
    the result is always ``CHAMELEON_IMAGE_SEQ_LENGTH``.
    """
    max_image_tokens = CHAMELEON_IMAGE_SEQ_LENGTH
    return max_image_tokens
def dummy_seq_data_for_chameleon(
    seq_len: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Build dummy sequence data: image placeholders followed by zeros.

    Args:
        seq_len: Target total number of token ids.
        image_token_id: Token id repeated for the image placeholder span.
        image_feature_size_override: Number of image placeholder tokens;
            defaults to ``CHAMELEON_IMAGE_SEQ_LENGTH`` when ``None``.

    Returns:
        A ``SequenceData`` wrapping the generated token-id list.
    """
    image_feature_size = (CHAMELEON_IMAGE_SEQ_LENGTH
                          if image_feature_size_override is None else
                          image_feature_size_override)

    image_part = [image_token_id] * image_feature_size
    # A negative repeat count yields an empty list, so no zero padding is
    # added when the image span already exceeds `seq_len`.
    padding_part = [0] * (seq_len - image_feature_size)
    return SequenceData(image_part + padding_part)
def dummy_image_for_chameleon(
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create a black RGB dummy image wrapped as multimodal data.

    Args:
        image_width_override: Image width; defaults to
            ``CHAMELEON_CROP_SIZE_WIDTH`` when ``None``.
        image_height_override: Image height; defaults to
            ``CHAMELEON_CROP_SIZE_HEIGHT`` when ``None``.

    Returns:
        A dict with the single key ``"image"`` mapped to the PIL image.
    """
    width = (CHAMELEON_CROP_SIZE_WIDTH
             if image_width_override is None else image_width_override)
    height = (CHAMELEON_CROP_SIZE_HEIGHT
              if image_height_override is None else image_height_override)

    return {"image": Image.new("RGB", (width, height), color=0)}
def dummy_data_for_chameleon(ctx: InputContext, seq_len: int):
    """Produce the dummy ``(sequence data, multimodal data)`` pair.

    ``ctx`` is not read; it is accepted to match the callback signature
    used when this function is registered with the input registry. The
    sequence uses ``CHAMELEON_IMAGE_TOKEN_ID`` as the image placeholder and
    the image uses the default crop size.
    """
    dummy_sequence = dummy_seq_data_for_chameleon(
        seq_len, image_token_id=CHAMELEON_IMAGE_TOKEN_ID)
    dummy_image = dummy_image_for_chameleon()
    return dummy_sequence, dummy_image
def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
    """
    Processing input prompt to insert required tokens for image placeholder.

    Inputs without image data are returned unchanged. Otherwise the image
    placeholder token is repeated ``CHAMELEON_IMAGE_SEQ_LENGTH`` times,
    wrapped with the image start/end tokens, and the separator token is
    appended to the token ids (and to the text prompt when one was given).

    See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
    """ # noqa
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
        repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
        pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
        pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,
    )

    # Appending sep token for chat mode to follow default processor
    # behavior.
    # The text prompt is optional (it is read with `.get` above), so it may
    # be None when the request carried only token ids; concatenating onto
    # None would raise TypeError, hence the guard.
    if new_prompt is not None:
        new_prompt += tokenizer.sep_token
    new_token_ids += [CHAMELEON_SEP_TOKEN_ID]

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
class ChameleonLayerNorm(nn.LayerNorm):
    """LayerNorm that normalizes over the last dimension only.

    The affine parameters are allocated by ``nn.LayerNorm`` with the full
    ``hidden_size`` shape, while the statistics are computed over the last
    dimension alone; the affine transform is then applied by broadcasting.
    """

    def __init__(self, hidden_size, *args, **kwargs):
        """Create the layer.

        Args:
            hidden_size: Shape of the affine parameters. May be an int or a
                sequence of ints; only its last entry is used as the
                normalized dimension.
            *args, **kwargs: Forwarded unchanged to ``nn.LayerNorm``.
        """
        super().__init__(hidden_size, *args, **kwargs)
        # `nn.LayerNorm` has already stored `hidden_size` as a tuple (an int
        # becomes a 1-tuple), so indexing the stored tuple works for both
        # int and sequence inputs.
        self.normalized_shape = (self.normalized_shape[-1], )

    def forward(self, hidden_states):
        """Normalize over the last dim, then apply the affine parameters."""
        # Use the configured epsilon (default 1e-5) rather than a literal so
        # that an `eps` passed to the constructor is honored.
        hidden_states = F.layer_norm(hidden_states,
                                     self.normalized_shape,
                                     None,
                                     None,
                                     eps=self.eps)
        # The affine parameters are None when the layer was constructed
        # without them (e.g. elementwise_affine=False).
        if self.weight is not None:
            hidden_states = hidden_states * self.weight
        if self.bias is not None:
            hidden_states = hidden_states + self.bias
        return hidden_states
# Copied from vllm.model_executor.models.llama.LlamaMLP -> ChameleonMLP
class ChameleonMLP(nn.Module):
    """Feed-forward block: merged gate/up projection, activation, down
    projection.

    Only the ``"silu"`` activation is accepted; any other value raises
    ``ValueError``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        super().__init__()
        # Two equally sized outputs produced by one merged projection.
        merged_output_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=merged_output_sizes,
            bias=bias,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=bias,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the gate/up projection, the activation, then the down
        projection; the second return value of each projection is dropped."""
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected
# Modified from vllm.model_executor.models.llama.LlamaAttention -> ChameleonAttention #noqa
class ChameleonAttention(nn.Module):
    """Self-attention block with per-head LayerNorm applied to Q and K.

    The forward pass runs: fused QKV projection -> split into q/k/v ->
    ``ChameleonLayerNorm`` on q and k -> rotary embedding -> attention ->
    output projection. Head counts are divided by the tensor-parallel
    world size.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 4096,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        # With at least as many KV heads as TP ranks the KV heads are
        # partitioned across ranks; otherwise they are replicated. Either
        # way the larger count must be a multiple of the smaller one.
        kv_heads_fit = (self.total_num_kv_heads % tp_size == 0
                        if self.total_num_kv_heads >= tp_size else
                        tp_size % self.total_num_kv_heads == 0)
        assert kv_heads_fit
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )

        q_norm_shape = (self.num_heads, self.head_dim)
        k_norm_shape = (self.num_kv_heads, self.head_dim)
        self.q_norm = ChameleonLayerNorm(q_norm_shape)
        self.k_norm = ChameleonLayerNorm(k_norm_shape)

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def _apply_qk_norm(self, q: torch.Tensor,
                       k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Normalize q and k per head and return them flattened again.

        Each tensor is reshaped to ``(-1, heads, head_dim)`` for the norm and
        then viewed back with the last two dimensions merged.
        """
        # reshape for layernorm
        q_heads = q.reshape(-1, self.num_heads, self.head_dim)
        k_heads = k.reshape(-1, self.num_kv_heads, self.head_dim)

        q_normed = self.q_norm(q_heads)
        k_normed = self.k_norm(k_heads)

        q_flat = q_normed.view(*q_normed.shape[:-2], -1)
        k_flat = k_normed.view(*k_normed.shape[:-2], -1)
        return q_flat, k_flat

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected
        output."""
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)

        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        projected, _ = self.o_proj(attn_output)
        return projected
class ChameleonDecoderLayer(nn.Module):
    """Decoder layer that normalizes before attention and before the MLP.

    The layer threads a separate ``residual`` tensor through the stack: the
    norm layers are called with ``(hidden_states, residual)`` and return an
    updated pair.
    """

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_scaling = getattr(config, "rope_scaling", None)
        # When the config carries a truthy original context length, copy it
        # into the rope scaling dict (this mutates the config's dict).
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=rope_scaling,
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 4096),
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one decoder layer.

        Args:
            residual: Running residual from the previous layer, or ``None``
                for the first layer, in which case the incoming
                ``hidden_states`` become the residual.

        Returns:
            ``(hidden_states, residual)`` to be passed to the next layer.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.mlp(normed)

        return mlp_out, residual
class ChameleonSwinDecoderLayer(nn.Module):
    """Decoder layer variant that normalizes *after* attention and MLP.

    Each sub-block computes ``x + norm(f(x))``: the norm is applied to the
    sub-block output and the result is added to the sub-block input.
    """

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_scaling = getattr(config, "rope_scaling", None)
        # When the config carries a truthy original context length, copy it
        # into the rope scaling dict (this mutates the config's dict).
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=rope_scaling,
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 4096),
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one post-norm decoder layer.

        The incoming ``residual`` argument is not read; it is overwritten
        with ``hidden_states`` immediately. The parameter exists so the
        call signature matches ``ChameleonDecoderLayer.forward``.

        Returns:
            ``(hidden_states, residual)`` where ``hidden_states`` is the
            full layer output and ``residual`` is the activation that was
            fed to the MLP.
        """
        attn_input = hidden_states
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.input_layernorm(attn_out) + attn_input

        # Fully Connected
        residual = hidden_states
        mlp_out = self.mlp(hidden_states)
        hidden_states = residual + self.post_attention_layernorm(mlp_out)

        # NOTE(review): `hidden_states` already includes `residual`. A caller
        # that forwards both to a norm which adds its residual argument
        # (as ChameleonModel.forward appears to do) would add the pre-MLP
        # activation a second time -- confirm against the reference model.
        return hidden_states, residual
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
class ChameleonVQVAEVectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer over a learned codebook.

    Reads ``num_embeddings`` and ``embed_dim`` from the config, plus an
    optional ``beta`` loss weight (default ``0.25``).
    """

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.re_embed = self.num_embeddings

    def forward(self, hidden_state: torch.Tensor):
        """Quantize a channels-first 4-D feature map.

        Returns:
            ``(quantized, loss, indices)``: the quantized map in the input
            layout, a scalar loss, and the flat codebook index chosen for
            every spatial position.
        """
        # Move channels last so each row of the flattened view is one vector.
        hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
        flat_inputs = hidden_state.view(-1, self.embedding_dim)
        codebook = self.embedding.weight

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        input_sq = torch.sum(flat_inputs**2, dim=1, keepdim=True)
        code_sq = torch.sum(codebook**2, dim=1)
        cross = torch.einsum("bd,dn->bn", flat_inputs,
                             codebook.transpose(0, 1))
        distances = input_sq + code_sq - 2 * cross

        min_encoding_indices = torch.argmin(distances, dim=1)
        hidden_state_quant = self.embedding(min_encoding_indices).view(
            hidden_state.shape)

        # compute loss for embedding
        detached_quant_term = torch.mean(
            (hidden_state_quant.detach() - hidden_state)**2)
        detached_input_term = torch.mean(
            (hidden_state_quant - hidden_state.detach())**2)
        loss = detached_quant_term + self.beta * detached_input_term

        # preserve gradients
        hidden_state_quant = hidden_state + (hidden_state_quant -
                                             hidden_state).detach()

        # reshape back to match original input shape
        hidden_state_quant = hidden_state_quant.permute(0, 3, 1,
                                                        2).contiguous()

        return hidden_state_quant, loss, min_encoding_indices
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
    """Stride-2 3x3 convolution that keeps the channel count unchanged."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels,
                              in_channels,
                              kernel_size=3,
                              stride=2,
                              padding=0)

    def forward(self, hidden_states: torch.Tensor):
        """Zero-pad one column on the right and one row at the bottom, then
        apply the convolution."""
        # no asymmetric padding in torch conv, must do it ourselves
        padded = F.pad(hidden_states,
                       pad=(0, 1, 0, 1),
                       mode="constant",
                       value=0)
        return self.conv(padded)
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
class ChameleonVQVAEEncoderResnetBlock(nn.Module):
    """Residual block: two (GroupNorm -> x*sigmoid(x) -> 3x3 conv) stages.

    When the input and output channel counts differ, the skip path is
    projected with either a 3x3 convolution (``conv_shortcut=True``) or a
    1x1 convolution.
    """

    def __init__(
        self,
        config: ChameleonVQVAEConfig,
        in_channels: int,
        out_channels=None,
        conv_shortcut=False,
    ):
        """Create the block.

        Args:
            config: Provides the ``dropout`` probability.
            in_channels: Number of input channels.
            out_channels: Number of output channels; ``None`` means "same
                as ``in_channels``".
            conv_shortcut: Use a 3x3 instead of a 1x1 convolution on the
                skip path when the channel counts differ.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None \
            else out_channels
        self.use_conv_shortcut = conv_shortcut

        # Every layer below is sized from the resolved `self.out_channels`;
        # using the raw `out_channels` argument would pass None to the
        # layer constructors when the default is used.
        resolved_out = self.out_channels

        self.norm1 = torch.nn.GroupNorm(num_groups=32,
                                        num_channels=in_channels,
                                        eps=1e-6,
                                        affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     resolved_out,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        self.norm2 = torch.nn.GroupNorm(num_groups=32,
                                        num_channels=resolved_out,
                                        eps=1e-6,
                                        affine=True)
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = torch.nn.Conv2d(resolved_out,
                                     resolved_out,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     resolved_out,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    resolved_out,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, hidden_states: torch.Tensor):
        """Return ``shortcut(input) + block(input)``."""
        residual = hidden_states

        hidden_states = self.norm1(hidden_states)
        # In-place x * sigmoid(x) on the fresh tensor returned by the norm;
        # the caller's input tensor is not modified.
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        # Project the skip path only when the channel counts differ.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
class ChameleonVQVAEEncoderAttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature
    map, with a residual connection.

    Query, key, value and output projections are 1x1 convolutions that keep
    the channel count unchanged.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        def pointwise_conv():
            # 1x1 convolution mapping `in_channels` to `in_channels`.
            return torch.nn.Conv2d(in_channels,
                                   in_channels,
                                   kernel_size=1,
                                   stride=1,
                                   padding=0)

        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, hidden_states: torch.Tensor):
        """Return ``hidden_states`` plus the projected attention output."""
        normed = self.norm(hidden_states)
        query_states = self.q(normed)
        key_states = self.k(normed)
        value_states = self.v(normed)

        # compute attention
        batch_size, channels, height, width = query_states.shape
        num_positions = height * width
        # (batch, positions, channels) x (batch, channels, positions)
        query_flat = query_states.reshape(batch_size, channels,
                                          num_positions).permute(0, 2, 1)
        key_flat = key_states.reshape(batch_size, channels, num_positions)
        scores = torch.bmm(query_flat, key_flat)
        scores = scores * (int(channels)**(-0.5))
        attn_weights = F.softmax(scores, dim=2)

        # attend to values
        value_flat = value_states.reshape(batch_size, channels, num_positions)
        attended = torch.bmm(value_flat, attn_weights.permute(0, 2, 1))
        attended = attended.reshape(batch_size, channels, height, width)

        return hidden_states + self.proj_out(attended)
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
class ChameleonVQVAEEncoder(nn.Module):
    """Convolutional encoder of the Chameleon VQ-VAE.

    Structure: input conv -> one stage per entry of
    ``config.channel_multiplier`` (residual blocks, optional attention
    blocks, and a stride-2 downsample on every stage except the last) ->
    middle (residual, attention-or-identity, residual) -> GroupNorm,
    x*sigmoid(x), output conv.
    """

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        channel_multiplier = config.channel_multiplier
        base_channels = config.base_channels
        latent_channels = config.latent_channels

        self.num_resolutions = len(channel_multiplier)
        self.num_res_blocks = config.num_res_blocks

        self.conv_in = torch.nn.Conv2d(config.in_channels,
                                       base_channels,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = config.resolution
        # Input-channel multiplier per stage: the first stage consumes the
        # output of `conv_in`, each later stage the previous stage's output.
        self.in_channel_multiplier = (1, ) + tuple(channel_multiplier)

        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block_in = base_channels * self.in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            # Attention is attached at this stage only when the current
            # resolution is listed and the attention type is "vanilla".
            use_attn = (config.attn_resolutions is not None
                        and curr_res in config.attn_resolutions
                        and config.attn_type == "vanilla")

            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            for _ in range(self.num_res_blocks):
                res_blocks.append(
                    ChameleonVQVAEEncoderResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    ))
                block_in = block_out
                if use_attn:
                    attn_blocks.append(
                        ChameleonVQVAEEncoderAttnBlock(block_in))

            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if i_level != self.num_resolutions - 1:
                stage.downsample = ChameleonVQVAEEncoderConvDownsample(
                    block_in)
                curr_res = curr_res // 2
            self.down.append(stage)

        self.mid = nn.Module()
        self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )
        if config.attn_type == "vanilla":
            self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(block_in)
        else:
            self.mid.attn_1 = nn.Identity()
        self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )

        self.norm_out = torch.nn.GroupNorm(num_groups=32,
                                           num_channels=block_in,
                                           eps=1e-6,
                                           affine=True)
        out_latent_channels = (2 * latent_channels
                               if config.double_latent else latent_channels)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            out_latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.Tensor):
        """Encode ``pixel_values`` into the latent feature map."""
        # downsampling
        hidden_state = self.conv_in(pixel_values)
        final_level = self.num_resolutions - 1
        for i_level in range(self.num_resolutions):
            stage = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                hidden_state = stage.block[i_block](hidden_state)
                if len(stage.attn) > 0:
                    hidden_state = stage.attn[i_block](hidden_state)
            if i_level != final_level:
                hidden_state = stage.downsample(hidden_state)

        # middle
        hidden_state = self.mid.block_1(hidden_state)
        hidden_state = self.mid.attn_1(hidden_state)
        hidden_state = self.mid.block_2(hidden_state)

        # end
        hidden_state = self.norm_out(hidden_state)
        hidden_state *= torch.sigmoid(hidden_state)
        hidden_state = self.conv_out(hidden_state)
        return hidden_state
# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
class ChameleonVQVAE(nn.Module):
    """Image tokenizer: encoder, 1x1 quantization conv and vector quantizer.

    The module is switched to eval mode at construction time.
    """

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        latent_channels = config.latent_channels
        embed_dim = config.embed_dim

        self.encoder = ChameleonVQVAEEncoder(config)
        self.quantize = ChameleonVQVAEVectorQuantizer(config)
        self.quant_conv = torch.nn.Conv2d(latent_channels, embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, latent_channels, 1)
        self.eval()  # Chameleon's VQ model is frozen

    def encode(
        self, pixel_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode and quantize ``pixel_values``.

        Returns:
            ``(quant, emb_loss, indices)`` exactly as produced by the
            quantizer.
        """
        latents = self.quant_conv(self.encoder(pixel_values))
        quant, emb_loss, indices = self.quantize(latents)
        return quant, emb_loss, indices
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa
class ChameleonImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.

    Image tokens are the vocabulary entries whose name starts with
    ``"IMGIMG"``. The image index is decoded from the rest of the name:
    the final character is dropped and the letters ``A``..``J`` are read as
    the digits ``0``..``9``.
    """

    def __init__(self, vocab_map: Dict[str, int]):
        self.vocab_map = vocab_map
        # None when the vocabulary has no "<image>" entry.
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def val2name(self):
        """Inverse of ``vocab_map``: token id -> token name."""
        return dict(zip(self.vocab_map.values(), self.vocab_map.keys()))

    @cached_property
    def image_tokens(self):
        """Sorted token ids of all ``"IMGIMG"``-prefixed entries."""
        return sorted(token_id for token_name, token_id in
                      self.vocab_map.items()
                      if token_name.startswith("IMGIMG"))

    @cached_property
    def bpe2img(self):
        """Map each image BPE token id to its image-codebook index."""
        letters_to_digits = str.maketrans("ABCDEFGHIJ", "0123456789")
        prefix_len = len("IMGIMG")

        def remap(old_name: str) -> str:
            # Drop the prefix and the final character, then turn the
            # letter-coded digits into real digits.
            return old_name[prefix_len:-1].translate(letters_to_digits)

        names = self.val2name
        return {tok: int(remap(names[tok])) for tok in self.image_tokens}

    @cached_property
    def img2bpe(self):
        """Inverse of ``bpe2img``: image-codebook index -> BPE token id."""
        return dict(zip(self.bpe2img.values(), self.bpe2img.keys()))

    @cached_property
    def bpe2img_search_tensors(self):
        """Pair of tensors: sorted BPE ids and sorted image indices."""
        sorted_bpe = torch.tensor(sorted(self.bpe2img.keys()))
        sorted_img = torch.tensor(sorted(self.bpe2img.values()))
        return sorted_bpe, sorted_img

    @cached_property
    def img2bpe_mapping_tensor(self):
        """Dense int lookup table indexed by image index; unmapped slots
        stay zero."""
        table_size = max(self.img2bpe.keys()) + 1
        mapping = torch.zeros(table_size, dtype=torch.int)
        for img_index, bpe_id in self.img2bpe.items():
            mapping[img_index] = bpe_id
        return mapping

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Translate image indices to BPE token ids.

        The lookup is done on CPU and the result is moved back to the
        device of ``img_batch``.
        """
        target_device = img_batch.device
        cpu_indices = img_batch.to("cpu")
        return self.img2bpe_mapping_tensor[cpu_indices].to(target_device)
class ChameleonModel(nn.Module):
    """Chameleon backbone: token embedding, decoder stack, final norm, and
    the VQ-VAE used to turn images into tokens.

    ``config.swin_norm`` selects ``ChameleonSwinDecoderLayer`` for every
    layer; otherwise ``ChameleonDecoderLayer`` is used.
    """

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(
            config.vocabulary_map)

        if self.config.swin_norm:
            decoder_layer = ChameleonSwinDecoderLayer
        else:
            decoder_layer = ChameleonDecoderLayer
        self.layers = nn.ModuleList([
            decoder_layer(config=config,
                          cache_config=cache_config,
                          quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.vqmodel = ChameleonVQVAE(config.vq_config)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Tokenizes images into discrete tokens with VQGAN module and converts
        the obtained image tokens into BPE tokens. The result has one row
        per image in the batch.
        """
        num_images = pixel_values.shape[0]
        _, _, image_toks = self.vqmodel.encode(pixel_values)
        bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
        return bpe_toks.view(num_images, -1)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder stack and the final norm.

        ``inputs_embeds`` takes precedence over ``input_ids`` when given.
        ``kv_caches`` is indexed by layer position.
        """
        if inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        # NOTE(review): the swin decoder layers return `hidden_states` that
        # already include their residual; passing `residual` to the norm
        # again may double-count it on that path -- confirm against the
        # reference implementation.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
    """Chameleon decoder wrapped with an LM head, logits processor and sampler.

    When ``pixel_values`` are supplied to ``forward``, they are converted to
    token ids by the inner ``ChameleonModel`` and written over the positions
    of ``input_ids`` that hold the image placeholder token before the
    decoder runs.
    """

    def __init__(
        self,
        config: ChameleonConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the inner model, LM head, logits processor and sampler.

        Args:
            config: Model configuration; ``vocab_size``, ``hidden_size`` and
                ``tie_word_embeddings`` are read here, and ``logit_scale``
                if present (defaults to 1.0).
            multimodal_config: Stored on the instance as-is.
            cache_config: Forwarded to ``ChameleonModel``.
            quant_config: Forwarded to ``ChameleonModel``.
        """
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config
        self.model = ChameleonModel(config, cache_config, quant_config)
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
        )
        if config.tie_word_embeddings:
            # Share the embedding matrix between input and output layers.
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                logit_scale)
        self.sampler = Sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Return ``data`` unchanged if every image has the expected shape.

        Raises:
            ValueError: If ``data.shape[1:]`` is not
                ``(3, CHAMELEON_CROP_SIZE_HEIGHT, CHAMELEON_CROP_SIZE_WIDTH)``.
        """
        expected_dims = (3, CHAMELEON_CROP_SIZE_HEIGHT,
                         CHAMELEON_CROP_SIZE_WIDTH)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]:
        """Extract and validate ``pixel_values`` from the keyword arguments.

        Returns:
            ``None`` when no ``pixel_values`` entry is present (or it is
            ``None``); otherwise a ``ChameleonImagePixelInputs`` holding the
            validated tensor.

        Raises:
            ValueError: If ``pixel_values`` is not a ``torch.Tensor`` or has
                an unexpected shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return ChameleonImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the decoder, first substituting image tokens if images are given.

        Args:
            input_ids: Token ids; positions equal to the image placeholder id
                are overwritten with the image's token ids.
            positions: Forwarded to the inner model.
            kv_caches: Forwarded to the inner model.
            attn_metadata: Forwarded to the inner model.
            intermediate_tensors: Accepted but not used by this method.
            **kwargs: May contain ``pixel_values``.

        Returns:
            The hidden states produced by ``self.model``.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            assert self.model.vqmodel is not None
            image_tokens = self.model.get_image_tokens(image_input["data"].to(
                self.config.torch_dtype))
            image_token_id = self.model.vocabulary_mapping.image_token_id
            special_image_mask = input_ids == image_token_id
            image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
            # Fill the placeholder positions, in order, with the image tokens.
            input_ids = input_ids.masked_scatter(special_image_mask,
                                                 image_tokens)

        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits and mask out the image-token columns.

        Returns:
            The logits with every column listed in
            ``self.model.vocabulary_mapping.image_tokens`` set to the minimum
            finite value of the logits dtype, or ``None`` if the logits
            processor produced no logits.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

        # NOTE(review): the logits processor can return None (presumably on
        # tensor-parallel ranks that do not receive the gathered logits -
        # confirm); indexing None would raise TypeError, so only mask when
        # logits are actually available.
        if logits is not None:
            # Disallow image tokens, which do not include the special
            # begin-image and end-image tokens.
            image_tokens = self.model.vocabulary_mapping.image_tokens
            logits[:, image_tokens] = torch.finfo(logits.dtype).min

        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate to ``self.sampler`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Separate q/k/v and gate/up projection weights are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters with a shard id;
        ``vqmodel`` weights and everything else are loaded one-to-one.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue

            use_default_weight_loading = False
            if "vqmodel" in name:
                if self.model.vqmodel is not None:
                    # We only do sharding for language model and
                    # not vqvae for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # No stacked mapping consumed this tensor: load it
                    # directly under its (possibly remapped) name.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            print_warning_once(
                                "Found kv scale in the checkpoint (e.g. "
                                f"{name}), but not found the expected name in "
                                f"the model (e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded.")
                            continue
                        else:
                            name = remapped_kv_scale_name
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)

            # vqmodel tensors are loaded unsharded; names missing from the
            # model are silently ignored.
            if use_default_weight_loading and name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
vllm/model_executor/models/gpt2.py
View file @
500b93c8
...
@@ -51,6 +51,7 @@ class GPT2Attention(nn.Module):
...
@@ -51,6 +51,7 @@ class GPT2Attention(nn.Module):
config
:
GPT2Config
,
config
:
GPT2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -68,12 +69,14 @@ class GPT2Attention(nn.Module):
...
@@ -68,12 +69,14 @@ class GPT2Attention(nn.Module):
total_num_heads
,
total_num_heads
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.c_attn"
,
)
)
self
.
c_proj
=
RowParallelLinear
(
self
.
c_proj
=
RowParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.c_proj"
,
)
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
...
@@ -101,6 +104,7 @@ class GPT2MLP(nn.Module):
...
@@ -101,6 +104,7 @@ class GPT2MLP(nn.Module):
intermediate_size
:
int
,
intermediate_size
:
int
,
config
:
GPT2Config
,
config
:
GPT2Config
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
...
@@ -109,12 +113,14 @@ class GPT2MLP(nn.Module):
...
@@ -109,12 +113,14 @@ class GPT2MLP(nn.Module):
intermediate_size
,
intermediate_size
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.c_fc"
,
)
)
self
.
c_proj
=
RowParallelLinear
(
self
.
c_proj
=
RowParallelLinear
(
intermediate_size
,
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.c_proj"
,
)
)
self
.
act
=
get_act_fn
(
config
.
activation_function
,
quant_config
,
self
.
act
=
get_act_fn
(
config
.
activation_function
,
quant_config
,
intermediate_size
)
intermediate_size
)
...
@@ -133,6 +139,7 @@ class GPT2Block(nn.Module):
...
@@ -133,6 +139,7 @@ class GPT2Block(nn.Module):
config
:
GPT2Config
,
config
:
GPT2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
...
@@ -140,9 +147,15 @@ class GPT2Block(nn.Module):
...
@@ -140,9 +147,15 @@ class GPT2Block(nn.Module):
hidden_size
)
hidden_size
)
self
.
ln_1
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_1
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
attn
=
GPT2Attention
(
config
,
cache_config
,
quant_config
)
self
.
attn
=
GPT2Attention
(
config
,
cache_config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
self
.
ln_2
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_2
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
GPT2MLP
(
inner_dim
,
config
,
quant_config
)
self
.
mlp
=
GPT2MLP
(
inner_dim
,
config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -175,6 +188,7 @@ class GPT2Model(nn.Module):
...
@@ -175,6 +188,7 @@ class GPT2Model(nn.Module):
config
:
GPT2Config
,
config
:
GPT2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -186,7 +200,9 @@ class GPT2Model(nn.Module):
...
@@ -186,7 +200,9 @@ class GPT2Model(nn.Module):
self
.
wpe
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
embed_dim
)
self
.
wpe
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
embed_dim
)
self
.
start_layer
,
self
.
end_layer
,
self
.
h
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
h
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
:
GPT2Block
(
config
,
cache_config
,
quant_config
))
lambda
prefix
:
GPT2Block
(
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.h"
)
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
def
forward
(
def
forward
(
...
@@ -229,7 +245,10 @@ class GPT2LMHeadModel(nn.Module):
...
@@ -229,7 +245,10 @@ class GPT2LMHeadModel(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
transformer
=
GPT2Model
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
GPT2Model
(
config
,
cache_config
,
quant_config
,
prefix
=
"transformer"
)
self
.
lm_head
=
self
.
transformer
.
wte
self
.
lm_head
=
self
.
transformer
.
wte
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/interfaces.py
View file @
500b93c8
...
@@ -3,7 +3,7 @@ from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
...
@@ -3,7 +3,7 @@ from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
from
typing_extensions
import
TypeGuard
from
typing_extensions
import
TypeGuard
from
vllm.config
import
LoRAConfig
,
MultiModalConfig
from
vllm.config
import
LoRAConfig
,
MultiModalConfig
,
SchedulerConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -142,3 +142,49 @@ def _supports_lora(
...
@@ -142,3 +142,49 @@ def _supports_lora(
return
isinstance
(
model
,
_SupportsLoRAType
)
return
isinstance
(
model
,
_SupportsLoRAType
)
return
isinstance
(
model
,
SupportsLoRA
)
return
isinstance
(
model
,
SupportsLoRA
)
@runtime_checkable
class HasInnerState(Protocol):
    """The interface required for all models that have inner state."""

    # Class-level marker; classes inheriting from this protocol get the
    # value ``True`` by default.
    has_inner_state: ClassVar[Literal[True]] = True
    """
        A flag that indicates this model has inner state.
        Models that has inner state usually need access to the scheduler_config
        for max_num_seqs ,etc... (Currently only used by Jamba)
    """

    # Implementations accept an optional, keyword-only ``scheduler_config``.
    def __init__(self,
                 *,
                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
        ...
@runtime_checkable
class _HasInnerStateType(Protocol):
    """Counterpart of ``HasInnerState`` used when checking class objects.

    It declares the same members as ``HasInnerState`` but gives
    ``has_inner_state`` no default value; ``has_inner_state()`` checks
    against this protocol when its argument is a type.
    """

    # Declared without a value, unlike ``HasInnerState.has_inner_state``.
    has_inner_state: ClassVar[Literal[True]]

    def __init__(self,
                 *,
                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
        ...
@overload
def has_inner_state(model: object) -> TypeGuard[HasInnerState]:
    ...


@overload
def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]:
    ...


def has_inner_state(
    model: Union[Type[object], object]
) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]:
    """Report whether ``model`` (an instance or a class) has inner state.

    Class objects are checked against ``_HasInnerStateType``; everything
    else is checked against ``HasInnerState``.
    """
    # Pick the protocol matching the kind of object we were handed.
    protocol = _HasInnerStateType if isinstance(model, type) else HasInnerState
    return isinstance(model, protocol)
vllm/model_executor/models/jamba.py
View file @
500b93c8
...
@@ -13,7 +13,7 @@ from transformers import JambaConfig
...
@@ -13,7 +13,7 @@ from transformers import JambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
...
@@ -32,10 +32,12 @@ from vllm.model_executor.layers.sampler import Sampler
...
@@ -32,10 +32,12 @@ from vllm.model_executor.layers.sampler import Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
HasInnerState
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.worker.model_runner
import
_BATCH_SIZES_TO_CAPTURE
from
vllm.worker.model_runner
import
(
_BATCH_SIZES_TO_CAPTURE
,
_get_graph_batch_size
)
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
...
@@ -612,7 +614,7 @@ class JambaModel(nn.Module):
...
@@ -612,7 +614,7 @@ class JambaModel(nn.Module):
return
hidden_states
return
hidden_states
class
JambaForCausalLM
(
nn
.
Module
):
class
JambaForCausalLM
(
nn
.
Module
,
HasInnerState
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
:
[
"q_proj"
,
"q_proj"
,
...
@@ -640,9 +642,11 @@ class JambaForCausalLM(nn.Module):
...
@@ -640,9 +642,11 @@ class JambaForCausalLM(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
scheduler_config
:
Optional
[
SchedulerConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
scheduler_config
=
scheduler_config
self
.
model
=
JambaModel
(
config
,
self
.
model
=
JambaModel
(
config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
...
@@ -689,6 +693,8 @@ class JambaForCausalLM(nn.Module):
...
@@ -689,6 +693,8 @@ class JambaForCausalLM(nn.Module):
for
key
in
[
"request_ids_to_seq_ids"
,
"finished_requests_ids"
])
for
key
in
[
"request_ids_to_seq_ids"
,
"finished_requests_ids"
])
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
self
.
_release_mamba_cache
(
finished_requests_ids
)
batch_size
=
input_ids
.
shape
[
0
]
batch_size
=
input_ids
.
shape
[
0
]
if
attn_metadata
.
prefill_metadata
:
if
attn_metadata
.
prefill_metadata
:
batch_size
=
len
(
request_ids_to_seq_ids
)
batch_size
=
len
(
request_ids_to_seq_ids
)
...
@@ -696,9 +702,8 @@ class JambaForCausalLM(nn.Module):
...
@@ -696,9 +702,8 @@ class JambaForCausalLM(nn.Module):
current_seqlen_agnostic_cache
,
current_seqlen_agnostic_cache
,
indices
,
indices
,
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
batch_size
)
batch_size
,
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
finished_requests_ids
)
self
.
_release_mamba_cache
(
finished_requests_ids
)
else
:
else
:
# CUDA graph capturing runs
# CUDA graph capturing runs
current_seqlen_agnostic_cache
,
indices
=
(
current_seqlen_agnostic_cache
,
indices
=
(
...
@@ -760,10 +765,15 @@ class JambaForCausalLM(nn.Module):
...
@@ -760,10 +765,15 @@ class JambaForCausalLM(nn.Module):
return
indices_for_current_run
return
indices_for_current_run
def
_prepare_current_run_mamba_cache
(
def
_prepare_current_run_mamba_cache
(
self
,
request_ids_to_seq_ids
:
Dict
[
str
,
list
[
int
]],
batch_size
:
int
self
,
request_ids_to_seq_ids
:
Dict
[
str
,
list
[
int
]],
batch_size
:
int
,
finished_requests_ids
:
List
[
str
]
)
->
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
List
[
int
]]:
)
->
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
List
[
int
]]:
indices_for_current_run
=
[]
indices_for_current_run
=
[]
for
request_id
,
seqs_id
in
request_ids_to_seq_ids
.
items
():
for
request_id
,
seqs_id
in
request_ids_to_seq_ids
.
items
():
if
request_id
in
finished_requests_ids
:
# Do not allocate cache for requests that run
# and finish right after
continue
indices_for_current_run
+=
self
.
_assign_seq_id_to_mamba_cache
(
indices_for_current_run
+=
self
.
_assign_seq_id_to_mamba_cache
(
request_id
,
seqs_id
)
request_id
,
seqs_id
)
## Pad the batch in case of running batch that was not captured via CG
## Pad the batch in case of running batch that was not captured via CG
...
@@ -787,16 +797,17 @@ class JambaForCausalLM(nn.Module):
...
@@ -787,16 +797,17 @@ class JambaForCausalLM(nn.Module):
assert
all
(
assert
all
(
key
in
kwargs
key
in
kwargs
for
key
in
[
"request_ids_to_seq_ids"
,
"finished_requests_ids"
])
for
key
in
[
"request_ids_to_seq_ids"
,
"finished_requests_ids"
])
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
self
.
_release_mamba_cache
(
finished_requests_ids
)
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
cg_batch_size
=
input_buffers
[
'input_ids'
].
shape
[
0
]
cg_batch_size
=
input_buffers
[
'input_ids'
].
shape
[
0
]
(
(
current_mamba_cache
,
current_mamba_cache
,
indices
,
indices
,
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
cg_batch_size
)
cg_batch_size
,
finished_requests_ids
)
self
.
current_indices
=
indices
self
.
current_indices
=
indices
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
self
.
_release_mamba_cache
(
finished_requests_ids
)
for
input_buffer
,
current_cache_buffer
in
zip
(
for
input_buffer
,
current_cache_buffer
in
zip
(
input_buffers
[
"seqlen_agnostic_capture_inputs"
],
input_buffers
[
"seqlen_agnostic_capture_inputs"
],
...
@@ -860,9 +871,12 @@ class JambaForCausalLM(nn.Module):
...
@@ -860,9 +871,12 @@ class JambaForCausalLM(nn.Module):
layers_type
=
self
.
config
.
layers_block_type
layers_type
=
self
.
config
.
layers_block_type
mamba_layers
=
sum
(
mamba_layers
=
sum
(
[
layer_type
==
"mamba"
for
layer_type
in
layers_type
])
[
layer_type
==
"mamba"
for
layer_type
in
layers_type
])
max_batch_size
=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
+
10
max_batch_size
=
(
_get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
max
(
_BATCH_SIZES_TO_CAPTURE
))
+
10
conv_state_shape
,
temporal_state_shape
=
self
.
_get_mamba_cache_shape
()
conv_state_shape
,
temporal_state_shape
=
self
.
_get_mamba_cache_shape
()
assert
conv_state_shape
is
not
None
and
temporal_state_shape
is
not
None
assert
conv_state_shape
is
not
None
and
temporal_state_shape
is
not
None
for
buffername
in
[
"mamba_cache"
,
"mamba_gc_cache_buffer"
]:
for
buffername
in
[
"mamba_cache"
,
"mamba_gc_cache_buffer"
]:
buffer
=
(
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
buffer
=
(
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
conv_state_shape
,
conv_state_shape
,
...
...
vllm/model_executor/models/llama.py
View file @
500b93c8
...
@@ -41,18 +41,20 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -41,18 +41,20 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
get_compressed_tensors_cache_scale
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
kv_cache_scales_loader
)
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.utils
import
is_hip
,
print_warning_once
from
vllm.utils
import
is_hip
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
from
.utils
import
is_pp_missing_parameter
,
make_layers
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
...
@@ -67,17 +69,20 @@ class LlamaMLP(nn.Module):
...
@@ -67,17 +69,20 @@ class LlamaMLP(nn.Module):
hidden_act
:
str
,
hidden_act
:
str
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
bias
:
bool
=
False
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
input_size
=
hidden_size
,
input_size
=
hidden_size
,
output_sizes
=
[
intermediate_size
]
*
2
,
output_sizes
=
[
intermediate_size
]
*
2
,
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
)
self
.
down_proj
=
RowParallelLinear
(
input_size
=
intermediate_size
,
self
.
down_proj
=
RowParallelLinear
(
input_size
=
intermediate_size
,
output_size
=
hidden_size
,
output_size
=
hidden_size
,
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.down_proj"
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
"Only silu is supported for now."
)
...
@@ -94,6 +99,7 @@ class LlamaAttention(nn.Module):
...
@@ -94,6 +99,7 @@ class LlamaAttention(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
LlamaConfig
,
hidden_size
:
int
,
hidden_size
:
int
,
num_heads
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
...
@@ -103,6 +109,7 @@ class LlamaAttention(nn.Module):
...
@@ -103,6 +109,7 @@ class LlamaAttention(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
bias
:
bool
=
False
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -120,7 +127,9 @@ class LlamaAttention(nn.Module):
...
@@ -120,7 +127,9 @@ class LlamaAttention(nn.Module):
# the KV heads across multiple tensor parallel GPUs.
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_kv_heads
==
0
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
self
.
head_dim
=
hidden_size
//
self
.
total_num_heads
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self
.
head_dim
=
getattr
(
config
,
"head_dim"
,
self
.
hidden_size
//
self
.
total_num_heads
)
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
scaling
=
self
.
head_dim
**-
0.5
...
@@ -134,12 +143,14 @@ class LlamaAttention(nn.Module):
...
@@ -134,12 +143,14 @@ class LlamaAttention(nn.Module):
total_num_kv_heads
=
self
.
total_num_kv_heads
,
total_num_kv_heads
=
self
.
total_num_kv_heads
,
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
input_size
=
self
.
total_num_heads
*
self
.
head_dim
,
input_size
=
self
.
total_num_heads
*
self
.
head_dim
,
output_size
=
hidden_size
,
output_size
=
hidden_size
,
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
...
@@ -180,6 +191,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -180,6 +191,7 @@ class LlamaDecoderLayer(nn.Module):
config
:
LlamaConfig
,
config
:
LlamaConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -196,6 +208,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -196,6 +208,7 @@ class LlamaDecoderLayer(nn.Module):
attention_bias
=
getattr
(
config
,
"attention_bias"
,
False
)
or
getattr
(
attention_bias
=
getattr
(
config
,
"attention_bias"
,
False
)
or
getattr
(
config
,
"bias"
,
False
)
config
,
"bias"
,
False
)
self
.
self_attn
=
LlamaAttention
(
self
.
self_attn
=
LlamaAttention
(
config
=
config
,
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_heads
=
config
.
num_attention_heads
,
num_kv_heads
=
getattr
(
config
,
"num_key_value_heads"
,
num_kv_heads
=
getattr
(
config
,
"num_key_value_heads"
,
...
@@ -206,6 +219,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -206,6 +219,7 @@ class LlamaDecoderLayer(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
bias
=
attention_bias
,
bias
=
attention_bias
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
)
self
.
mlp
=
LlamaMLP
(
self
.
mlp
=
LlamaMLP
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
...
@@ -213,6 +227,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -213,6 +227,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
bias
=
getattr
(
config
,
"mlp_bias"
,
False
),
bias
=
getattr
(
config
,
"mlp_bias"
,
False
),
prefix
=
f
"
{
prefix
}
.mlp"
,
)
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
...
@@ -256,6 +271,7 @@ class LlamaModel(nn.Module):
...
@@ -256,6 +271,7 @@ class LlamaModel(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -264,17 +280,26 @@ class LlamaModel(nn.Module):
...
@@ -264,17 +280,26 @@ class LlamaModel(nn.Module):
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
self
.
org_vocab_size
=
config
.
vocab_size
self
.
org_vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
if
get_pp_group
().
is_first_rank
or
(
config
.
tie_word_embeddings
self
.
vocab_size
,
and
get_pp_group
().
is_last_rank
):
config
.
hidden_size
,
self
.
embed_tokens
=
VocabParallelEmbedding
(
org_num_embeddings
=
config
.
vocab_size
,
self
.
vocab_size
,
)
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
)
else
:
self
.
embed_tokens
=
PPMissingLayer
()
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
:
LlamaDecoderLayer
(
config
=
config
,
lambda
prefix
:
LlamaDecoderLayer
(
config
=
config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
))
quant_config
=
quant_config
,
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
if
get_pp_group
().
is_last_rank
:
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
norm
=
PPMissingLayer
()
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
return
self
.
embed_tokens
(
input_ids
)
...
@@ -366,27 +391,33 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -366,27 +391,33 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self
.
model
=
LlamaModel
(
config
,
self
.
model
=
LlamaModel
(
config
,
cache_config
,
cache_config
,
quant_config
,
quant_config
,
lora_config
=
lora_config
)
lora_config
=
lora_config
,
self
.
unpadded_vocab_size
=
config
.
vocab_size
prefix
=
"model"
)
if
lora_config
:
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
if
lora_config
:
self
.
unpadded_vocab_size
,
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
config
.
hidden_size
,
self
.
lm_head
=
ParallelLMHead
(
org_num_embeddings
=
config
.
vocab_size
,
self
.
unpadded_vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
config
.
hidden_size
,
# We need bigger padding if using lora for kernel
org_num_embeddings
=
config
.
vocab_size
,
# compatibility
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
# We need bigger padding if using lora for kernel
quant_config
=
quant_config
,
# compatibility
)
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
if
config
.
tie_word_embeddings
:
quant_config
=
quant_config
,
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
)
if
config
.
tie_word_embeddings
:
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
sampler
=
Sampler
()
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
...
@@ -449,6 +480,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -449,6 +480,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
# Models trained using ColossalAI may include these tensors in
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
# the checkpoint. Skip them.
continue
continue
if
scale_name
:
=
get_compressed_tensors_cache_scale
(
name
):
# Loading kv cache scales for compressed-tensors quantization
param
=
params_dict
[
scale_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
loaded_weight
=
loaded_weight
[
0
]
weight_loader
(
param
,
loaded_weight
)
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
...
@@ -470,18 +509,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -470,18 +509,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
# Remapping the name of FP8 kv-scale.
# Remapping the name of FP8 kv-scale.
if
name
.
endswith
(
"kv_scale"
):
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
remapped_kv_scale_name
=
name
.
replace
(
if
name
is
None
:
".kv_scale"
,
".attn.kv_scale"
)
continue
if
remapped_kv_scale_name
not
in
params_dict
:
print_warning_once
(
f
"Found kv scale in the checkpoint (e.g.
{
name
}
), "
"but not found the expected name in the model "
f
"(e.g.
{
remapped_kv_scale_name
}
). kv-scale is "
"not loaded."
)
continue
else
:
name
=
remapped_kv_scale_name
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
...
...
vllm/model_executor/models/llava.py
View file @
500b93c8
...
@@ -155,7 +155,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
...
@@ -155,7 +155,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
quant_config
=
quant_config
)
quant_config
=
quant_config
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
config
.
text_config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
vllm/model_executor/models/llava_next.py
View file @
500b93c8
...
@@ -249,7 +249,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
...
@@ -249,7 +249,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
quant_config
=
quant_config
)
quant_config
=
quant_config
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
config
.
text_config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
self
.
image_newline
=
nn
.
Parameter
(
self
.
image_newline
=
nn
.
Parameter
(
...
...
vllm/model_executor/models/mixtral.py
View file @
500b93c8
...
@@ -29,7 +29,7 @@ from transformers import MixtralConfig
...
@@ -29,7 +29,7 @@ from transformers import MixtralConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
...
@@ -42,12 +42,13 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -42,12 +42,13 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.utils
import
print_warning_once
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
from
.utils
import
is_pp_missing_parameter
,
make_layers
class
MixtralMoE
(
nn
.
Module
):
class
MixtralMoE
(
nn
.
Module
):
...
@@ -66,7 +67,8 @@ class MixtralMoE(nn.Module):
...
@@ -66,7 +67,8 @@ class MixtralMoE(nn.Module):
intermediate_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
):
tp_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -75,7 +77,8 @@ class MixtralMoE(nn.Module):
...
@@ -75,7 +77,8 @@ class MixtralMoE(nn.Module):
num_experts
,
num_experts
,
bias
=
False
,
bias
=
False
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
quant_config
=
None
)
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.gate"
)
self
.
experts
=
FusedMoE
(
num_experts
=
num_experts
,
self
.
experts
=
FusedMoE
(
num_experts
=
num_experts
,
top_k
=
top_k
,
top_k
=
top_k
,
...
@@ -85,7 +88,8 @@ class MixtralMoE(nn.Module):
...
@@ -85,7 +88,8 @@ class MixtralMoE(nn.Module):
reduce_results
=
True
,
reduce_results
=
True
,
renormalize
=
True
,
renormalize
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
tp_size
=
tp_size
)
tp_size
=
tp_size
,
prefix
=
f
"
{
prefix
}
.experts"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE: hidden_states can have either 1D or 2D shape.
# NOTE: hidden_states can have either 1D or 2D shape.
...
@@ -108,6 +112,7 @@ class MixtralAttention(nn.Module):
...
@@ -108,6 +112,7 @@ class MixtralAttention(nn.Module):
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -138,12 +143,14 @@ class MixtralAttention(nn.Module):
...
@@ -138,12 +143,14 @@ class MixtralAttention(nn.Module):
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
self
.
head_dim
,
...
@@ -181,6 +188,7 @@ class MixtralDecoderLayer(nn.Module):
...
@@ -181,6 +188,7 @@ class MixtralDecoderLayer(nn.Module):
config
:
MixtralConfig
,
config
:
MixtralConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -193,13 +201,15 @@ class MixtralDecoderLayer(nn.Module):
...
@@ -193,13 +201,15 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads
=
config
.
num_key_value_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
rope_theta
=
rope_theta
,
rope_theta
=
rope_theta
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
)
self
.
block_sparse_moe
=
MixtralMoE
(
self
.
block_sparse_moe
=
MixtralMoE
(
num_experts
=
config
.
num_local_experts
,
num_experts
=
config
.
num_local_experts
,
top_k
=
config
.
num_experts_per_tok
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.block_sparse_moe"
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
...
@@ -242,6 +252,7 @@ class MixtralModel(nn.Module):
...
@@ -242,6 +252,7 @@ class MixtralModel(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
...
@@ -255,12 +266,14 @@ class MixtralModel(nn.Module):
...
@@ -255,12 +266,14 @@ class MixtralModel(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
org_num_embeddings
=
config
.
vocab_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
MixtralDecoderLayer
(
config
,
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
cache_config
,
config
.
num_hidden_layers
,
quant_config
=
quant_config
)
lambda
prefix
:
MixtralDecoderLayer
(
for
_
in
range
(
config
.
num_hidden_layers
)
config
,
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
])
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
def
forward
(
...
@@ -269,14 +282,25 @@ class MixtralModel(nn.Module):
...
@@ -269,14 +282,25 @@ class MixtralModel(nn.Module):
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
get_pp_group
().
is_first_rank
:
residual
=
None
hidden_states
=
self
.
embed_tokens
(
input_ids
)
for
i
in
range
(
len
(
self
.
layers
)):
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
attn_metadata
,
kv_caches
[
i
-
self
.
start_layer
],
residual
)
attn_metadata
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
...
@@ -320,7 +344,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
...
@@ -320,7 +344,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
self
.
model
=
MixtralModel
(
config
,
self
.
model
=
MixtralModel
(
config
,
cache_config
,
cache_config
,
quant_config
,
quant_config
,
lora_config
=
lora_config
)
lora_config
=
lora_config
,
prefix
=
"model"
)
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
@@ -347,7 +372,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
...
@@ -347,7 +372,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
@@ -356,6 +381,20 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
...
@@ -356,6 +381,20 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
"residual"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
sample
(
def
sample
(
self
,
self
,
logits
:
Optional
[
torch
.
Tensor
],
logits
:
Optional
[
torch
.
Tensor
],
...
@@ -392,6 +431,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
...
@@ -392,6 +431,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
@@ -402,6 +445,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
...
@@ -402,6 +445,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
weight_loader
(
param
,
...
@@ -414,20 +460,14 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
...
@@ -414,20 +460,14 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
# Remapping the name of FP8 kv-scale.
# Remapping the name of FP8 kv-scale.
if
name
.
endswith
(
"kv_scale"
):
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
remapped_kv_scale_name
=
name
.
replace
(
if
name
is
None
:
".kv_scale"
,
".attn.kv_scale"
)
continue
if
remapped_kv_scale_name
not
in
params_dict
:
print_warning_once
(
"Found kv scale in the checkpoint "
f
"(e.g.
{
name
}
), but not found the expected "
f
"name in the model "
f
"(e.g.
{
remapped_kv_scale_name
}
). "
"kv-scale is not loaded."
)
continue
else
:
name
=
remapped_kv_scale_name
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
Prev
1
…
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment