Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0640f227
Commit
0640f227
authored
Sep 09, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.0' into v0.6.0-dev
parents
82f1ffdf
32e7db25
Changes
335
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
751 additions
and
213 deletions
+751
-213
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+15
-11
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+247
-43
vllm/model_executor/layers/spec_decode_base_sampler.py
vllm/model_executor/layers/spec_decode_base_sampler.py
+25
-18
vllm/model_executor/layers/typical_acceptance_sampler.py
vllm/model_executor/layers/typical_acceptance_sampler.py
+4
-3
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+10
-4
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+126
-79
vllm/model_executor/model_loader/neuron.py
vllm/model_executor/model_loader/neuron.py
+82
-19
vllm/model_executor/model_loader/openvino.py
vllm/model_executor/model_loader/openvino.py
+1
-2
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+2
-2
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+3
-0
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+2
-2
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+2
-2
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+2
-2
vllm/model_executor/models/blip.py
vllm/model_executor/models/blip.py
+89
-1
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+12
-6
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+2
-2
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+6
-3
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+2
-2
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+117
-10
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+2
-2
No files found.
vllm/model_executor/layers/rotary_embedding.py
View file @
0640f227
...
@@ -503,8 +503,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
...
@@ -503,8 +503,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
short_factor
:
List
[
float
],
short_factor
:
List
[
float
],
long_factor
:
List
[
float
],
long_factor
:
List
[
float
],
short_mscale
:
float
=
1.0
,
short_mscale
:
Optional
[
float
]
=
None
,
long_mscale
:
float
=
1.0
,
long_mscale
:
Optional
[
float
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -523,18 +523,22 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
...
@@ -523,18 +523,22 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
self
.
base
=
base
self
.
base
=
base
self
.
short_factor
=
short_factor
self
.
short_factor
=
short_factor
self
.
long_factor
=
long_factor
self
.
long_factor
=
long_factor
self
.
short_mscale
=
short_mscale
self
.
long_mscale
=
long_mscale
scale
=
(
self
.
max_position_embeddings
/
self
.
original_max_position_embeddings
)
scale
=
self
.
max_position_embeddings
/
\
self
.
original_max_position_embeddings
if
scale
<=
1.0
:
if
scale
<=
1.0
:
self
.
scaling_factor
=
1.0
scaling_factor
=
1.0
else
:
else
:
self
.
scaling_factor
=
math
.
sqrt
(
scaling_factor
=
math
.
sqrt
(
1
+
math
.
log
(
scale
)
/
1
+
math
.
log
(
scale
)
/
math
.
log
(
self
.
original_max_position_embeddings
))
math
.
log
(
self
.
original_max_position_embeddings
))
if
short_mscale
is
None
:
short_mscale
=
scaling_factor
if
long_mscale
is
None
:
long_mscale
=
scaling_factor
self
.
short_mscale
=
short_mscale
self
.
long_mscale
=
long_mscale
short_cache
=
self
.
_compute_cos_sin_cache
(
short_cache
=
self
.
_compute_cos_sin_cache
(
original_max_position_embeddings
,
short_factor
,
short_mscale
)
original_max_position_embeddings
,
short_factor
,
short_mscale
)
...
@@ -571,8 +575,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
...
@@ -571,8 +575,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
inv_freq
=
self
.
_compute_inv_freq
(
rescale_factors
)
inv_freq
=
self
.
_compute_inv_freq
(
rescale_factors
)
t
=
torch
.
arange
(
max_position_embeddings
,
dtype
=
torch
.
float
)
t
=
torch
.
arange
(
max_position_embeddings
,
dtype
=
torch
.
float
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
freqs
.
cos
()
*
mscale
*
self
.
scaling_factor
cos
=
freqs
.
cos
()
*
mscale
sin
=
freqs
.
sin
()
*
mscale
*
self
.
scaling_factor
sin
=
freqs
.
sin
()
*
mscale
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
return
cache
return
cache
...
...
vllm/model_executor/layers/sampler.py
View file @
0640f227
"""A layer that samples the next tokens from the model's outputs."""
"""A layer that samples the next tokens from the model's outputs."""
import
itertools
import
itertools
import
warnings
import
warnings
from
dataclasses
import
dataclass
from
importlib.util
import
find_spec
from
importlib.util
import
find_spec
from
math
import
inf
from
math
import
inf
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
msgspec
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.triton_utils
import
HAS_TRITON
if
HAS_TRITON
:
if
HAS_TRITON
:
...
@@ -19,8 +22,7 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
...
@@ -19,8 +22,7 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SequenceGroupToSample
)
SequenceGroupToSample
)
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
)
SequenceOutput
)
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
import
flashinfer.sampling
import
flashinfer.sampling
...
@@ -35,6 +37,116 @@ else:
...
@@ -35,6 +37,116 @@ else:
# (num_token_ids, num_parent_ids) per sequence group.
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType
=
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
SampleResultType
=
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
# Types of temporary data structures used for
# computing sample_result
SampleMetadataType
=
Dict
[
SamplingType
,
Tuple
[
List
[
int
],
List
[
SequenceGroupToSample
]]]
MultinomialSamplesType
=
Dict
[
SamplingType
,
torch
.
Tensor
]
SampleResultsDictType
=
Dict
[
int
,
Tuple
[
List
[
int
],
List
[
int
]]]
# Encapsulates temporary data structures for computing
# sample_result.
#
# * For multi-step scheduling: must be returned
# by `Sampler.forward()` and used later to compute the pythonized
# sample_result
#
# * For single-step scheduling: consumed immediately
# inside `Sampler.forward()` to compute pythonized sample_result.
@
dataclass
class
SampleResultArgsType
:
sample_metadata
:
SampleMetadataType
multinomial_samples
:
MultinomialSamplesType
sample_results_dict
:
SampleResultsDictType
sampling_metadata
:
SamplingMetadata
greedy_samples
:
Optional
[
torch
.
Tensor
]
beam_search_logprobs
:
Optional
[
torch
.
Tensor
]
# Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling)
# sample result types
MaybeDeferredSampleResultType
=
Union
[
SampleResultType
,
SampleResultArgsType
]
# Abbreviation of the _sample() return type
SampleReturnType
=
Tuple
[
MaybeDeferredSampleResultType
,
Optional
[
torch
.
Tensor
]]
class
SamplerOutput
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token.
This data structure implements methods, so it can be used like a list, but
also has optional fields for device tensors.
"""
outputs
:
List
[
CompletionSequenceGroupOutput
]
# On-device tensor containing probabilities of each token.
sampled_token_probs
:
Optional
[
torch
.
Tensor
]
=
None
# On-device tensor containing the logprobs of each token.
logprobs
:
Optional
[
"torch.Tensor"
]
=
None
# Holds either (1) the pythonized sampler result (single-step scheduling)
# or (2) what will be arguments for later deferred pythonization of the
# sampler result (muliti-step scheduling)
deferred_sample_results_args
:
Optional
[
SampleResultArgsType
]
=
None
# On-device tensor containing the sampled token ids.
sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
# CPU tensor containing the sampled token ids. Used during multi-step to
# return the sampled token ids from last rank to AsyncLLMEngine to be
# 'broadcasted' to all other PP ranks for next step.
sampled_token_ids_cpu
:
Optional
[
torch
.
Tensor
]
=
None
# Spec decode metrics populated by workers.
spec_decode_worker_metrics
:
Optional
[
SpecDecodeWorkerMetrics
]
=
None
# Optional last hidden states from the model.
hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
# Optional prefill hidden states from the model
# (used for models like EAGLE).
prefill_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
# Time taken in the forward pass for this across all workers
model_forward_time
:
Optional
[
float
]
=
None
# Time taken in the model execute function. This will include model forward,
# block/sync across workers, cpu-gpu sync time and sampling time.
model_execute_time
:
Optional
[
float
]
=
None
def
__getitem__
(
self
,
idx
:
int
):
return
self
.
outputs
[
idx
]
def
__setitem__
(
self
,
idx
:
int
,
value
):
self
.
outputs
[
idx
]
=
value
def
__len__
(
self
):
return
len
(
self
.
outputs
)
def
__eq__
(
self
,
other
:
object
):
return
isinstance
(
other
,
self
.
__class__
)
and
self
.
outputs
==
other
.
outputs
def
__repr__
(
self
)
->
str
:
"""Show the shape of a tensor instead of its values to reduce noise.
"""
sampled_token_probs_repr
=
(
"None"
if
self
.
sampled_token_probs
is
None
else
self
.
sampled_token_probs
.
shape
)
sampled_token_ids_repr
=
(
"None"
if
self
.
sampled_token_ids
is
None
else
self
.
sampled_token_ids
.
shape
)
return
(
f
"SamplerOutput(outputs=
{
self
.
outputs
}
, "
f
"sampled_token_probs=
{
sampled_token_probs_repr
}
, "
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
, "
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
)"
)
class
Sampler
(
nn
.
Module
):
class
Sampler
(
nn
.
Module
):
"""Samples the next tokens from the model's outputs.
"""Samples the next tokens from the model's outputs.
...
@@ -98,6 +210,19 @@ class Sampler(nn.Module):
...
@@ -98,6 +210,19 @@ class Sampler(nn.Module):
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
"""
"""
Single-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Pythonize sampling result & logprobs tensor
Multi-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Defer Pythonization of sampling result & logprobs
tensor
* Encapsulate arguments required for deferred Pythonization
in the :class:`SamplerOutput` structure
Args:
Args:
logits: (num_tokens, vocab_size).
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
sampling_metadata: Metadata for sampling.
...
@@ -150,7 +275,7 @@ class Sampler(nn.Module):
...
@@ -150,7 +275,7 @@ class Sampler(nn.Module):
logprobs
=
torch
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float
)
logprobs
=
torch
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float
)
# Sample the next tokens.
# Sample the next tokens.
sample_results
,
maybe_sampled_tokens_tensor
=
_sample
(
maybe_deferred_
sample_results
,
maybe_sampled_tokens_tensor
=
_sample
(
probs
,
probs
,
logprobs
,
logprobs
,
sampling_metadata
,
sampling_metadata
,
...
@@ -160,20 +285,28 @@ class Sampler(nn.Module):
...
@@ -160,20 +285,28 @@ class Sampler(nn.Module):
)
)
if
self
.
include_gpu_probs_tensor
:
if
self
.
include_gpu_probs_tensor
:
# Since we will defer sampler result Pythonization,
# preserve GPU-side tensors in support of later
# deferred pythonization of logprobs
assert
maybe_sampled_tokens_tensor
is
not
None
assert
maybe_sampled_tokens_tensor
is
not
None
on_device_tensors
=
(
probs
,
logprobs
,
maybe_sampled_tokens_tensor
)
on_device_tensors
=
(
probs
,
logprobs
,
maybe_sampled_tokens_tensor
)
else
:
else
:
# Since Pythonization has already happened, don't preserve
# GPU-side tensors.
on_device_tensors
=
None
on_device_tensors
=
None
# Get the logprobs query results.
# Get the logprobs query results.
prompt_logprobs
=
None
prompt_logprobs
=
None
sample_logprobs
=
None
sample_logprobs
=
None
if
not
sampling_metadata
.
skip_sampler_cpu_output
:
if
not
sampling_metadata
.
skip_sampler_cpu_output
:
prompt_logprobs
,
sample_logprobs
=
_get_logprobs
(
# Pythonize logprobs now (GPU -> CPU); do not defer.
logprobs
,
sampling_metadata
,
sample_results
)
assert
not
isinstance
(
maybe_deferred_sample_results
,
SampleResultArgsType
)
prompt_logprobs
,
sample_logprobs
=
get_logprobs
(
logprobs
,
sampling_metadata
,
maybe_deferred_sample_results
)
return
_build_sampler_output
(
return
_build_sampler_output
(
sample_results
,
maybe_deferred_
sample_results
,
sampling_metadata
,
sampling_metadata
,
prompt_logprobs
,
prompt_logprobs
,
sample_logprobs
,
sample_logprobs
,
...
@@ -543,6 +676,60 @@ def _top_k_top_p_multinomial_with_flashinfer(
...
@@ -543,6 +676,60 @@ def _top_k_top_p_multinomial_with_flashinfer(
return
batch_next_token_ids
.
view
(
-
1
,
num_samples
)
return
batch_next_token_ids
.
view
(
-
1
,
num_samples
)
def
get_pythonized_sample_results
(
sample_result_args
:
SampleResultArgsType
)
->
SampleResultType
:
'''This function consumes GPU-side sampler results and computes
Pythonized CPU-side sampler results (GPU -> CPU sync.)
Single-step scheduling: this function is invoked at sampling-time
for immediate Pythonization.
Multi-step scheduling: Pythonization is deferred until after multiple
GPU-side steps have been completed.
Args:
sample_result_args: GPU-side inputs to the Pythonization process
Returns:
Pythonized sampler results
'''
(
sample_metadata
,
sampling_metadata
,
greedy_samples
,
multinomial_samples
,
beam_search_logprobs
,
sample_results_dict
,
)
=
(
sample_result_args
.
sample_metadata
,
sample_result_args
.
sampling_metadata
,
sample_result_args
.
greedy_samples
,
sample_result_args
.
multinomial_samples
,
sample_result_args
.
beam_search_logprobs
,
sample_result_args
.
sample_results_dict
,
)
for
sampling_type
in
SamplingType
:
if
sampling_type
not
in
sample_metadata
:
continue
(
seq_group_id
,
seq_groups
)
=
sample_metadata
[
sampling_type
]
if
sampling_type
==
SamplingType
.
GREEDY
:
sample_results
=
_greedy_sample
(
seq_groups
,
greedy_samples
)
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
sample_results
=
_random_sample
(
seq_groups
,
multinomial_samples
[
sampling_type
])
elif
sampling_type
==
SamplingType
.
BEAM
:
sample_results
=
_beam_search_sample
(
seq_groups
,
beam_search_logprobs
)
sample_results_dict
.
update
(
zip
(
seq_group_id
,
sample_results
))
return
[
sample_results_dict
.
get
(
i
,
([],
[]))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
]
def
_sample_with_torch
(
def
_sample_with_torch
(
probs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
...
@@ -550,7 +737,19 @@ def _sample_with_torch(
...
@@ -550,7 +737,19 @@ def _sample_with_torch(
sampling_tensors
:
SamplingTensors
,
sampling_tensors
:
SamplingTensors
,
include_gpu_probs_tensor
:
bool
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
,
modify_greedy_probs
:
bool
,
)
->
Tuple
[
SampleResultType
,
Optional
[
torch
.
Tensor
]]:
)
->
SampleReturnType
:
'''Torch-oriented _sample() implementation.
Single-step scheduling:
* Perform GPU-side sampling computation
* Immediately Pythonize sampling result
Multi-step scheduling:
* Perform GPU-side sampling computation
* Defer Pythonization & preserve GPU-side
tensors required for Pythonization
'''
categorized_seq_group_ids
:
Dict
[
SamplingType
,
categorized_seq_group_ids
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
t
:
[]
List
[
int
]]
=
{
t
:
[]
for
t
in
SamplingType
}
for
t
in
SamplingType
}
...
@@ -560,10 +759,11 @@ def _sample_with_torch(
...
@@ -560,10 +759,11 @@ def _sample_with_torch(
sampling_type
=
sampling_params
.
sampling_type
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
sample_results_dict
:
Dict
[
int
,
Tuple
[
List
[
int
],
List
[
int
]]]
=
{}
sample_results_dict
:
SampleResultsDictType
=
{}
sample_metadata
:
Dict
[
SamplingType
,
sample_metadata
:
SampleMetadataType
=
{}
Tuple
[
List
[
int
],
List
[
SequenceGroupToSample
]]]
=
{}
multinomial_samples
:
MultinomialSamplesType
=
{}
multinomial_samples
:
Dict
[
SamplingType
,
torch
.
Tensor
]
=
{}
greedy_samples
:
Optional
[
torch
.
Tensor
]
=
None
beam_search_logprobs
:
Optional
[
torch
.
Tensor
]
=
None
# Create output tensor for sampled token ids.
# Create output tensor for sampled token ids.
if
include_gpu_probs_tensor
:
if
include_gpu_probs_tensor
:
...
@@ -638,32 +838,29 @@ def _sample_with_torch(
...
@@ -638,32 +838,29 @@ def _sample_with_torch(
else
:
else
:
raise
ValueError
(
f
"Unsupported sampling type:
{
sampling_type
}
"
)
raise
ValueError
(
f
"Unsupported sampling type:
{
sampling_type
}
"
)
# GPU<->CPU sync happens in the loop below.
# Encapsulate arguments for computing Pythonized sampler
# This also converts the sample output to Python objects.
# results, whether deferred or otherwise.
maybe_deferred_args
=
SampleResultArgsType
(
sampling_metadata
=
sampling_metadata
,
sample_metadata
=
sample_metadata
,
multinomial_samples
=
multinomial_samples
,
greedy_samples
=
greedy_samples
,
beam_search_logprobs
=
beam_search_logprobs
,
sample_results_dict
=
sample_results_dict
)
if
not
sampling_metadata
.
skip_sampler_cpu_output
:
if
not
sampling_metadata
.
skip_sampler_cpu_output
:
for
sampling_type
in
SamplingType
:
# GPU<->CPU sync happens here.
if
sampling_type
not
in
sample_metadata
:
# This also converts the sampler output to a Python object.
continue
# Return Pythonized sampler result & sampled token ids
(
seq_group_id
,
seq_groups
)
=
sample_metadata
[
sampling_type
]
return
get_pythonized_sample_results
(
if
sampling_type
==
SamplingType
.
GREEDY
:
maybe_deferred_args
),
sampled_token_ids_tensor
sample_results
=
_greedy_sample
(
seq_groups
,
greedy_samples
)
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
sample_results
=
_random_sample
(
seq_groups
,
multinomial_samples
[
sampling_type
])
elif
sampling_type
==
SamplingType
.
BEAM
:
sample_results
=
_beam_search_sample
(
seq_groups
,
beam_search_logprobs
)
sample_results_dict
.
update
(
zip
(
seq_group_id
,
sample_results
))
sample_results
=
[
sample_results_dict
.
get
(
i
,
([],
[]))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
]
else
:
else
:
sample_results
=
[]
# Defer sampler result Pythonization; return deferred
# Pythonization args & sampled token ids
return
sample_results
,
sampled_token_ids_tensor
return
(
maybe_deferred_args
,
sampled_token_ids_tensor
,
)
def
_sample_with_triton_kernel
(
def
_sample_with_triton_kernel
(
...
@@ -755,7 +952,7 @@ def _sample(
...
@@ -755,7 +952,7 @@ def _sample(
sampling_tensors
:
SamplingTensors
,
sampling_tensors
:
SamplingTensors
,
include_gpu_probs_tensor
:
bool
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
,
modify_greedy_probs
:
bool
,
)
->
Tuple
[
SampleRe
sultType
,
Optional
[
torch
.
Tensor
]]
:
)
->
SampleRe
turnType
:
"""
"""
Args:
Args:
probs: (num_query_tokens_in_batch, num_vocab)
probs: (num_query_tokens_in_batch, num_vocab)
...
@@ -803,7 +1000,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
...
@@ -803,7 +1000,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
return
result
.
sum
(
1
).
add_
(
1
)
return
result
.
sum
(
1
).
add_
(
1
)
def
_
get_logprobs
(
def
get_logprobs
(
logprobs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
sample_results
:
SampleResultType
,
sample_results
:
SampleResultType
,
...
@@ -1126,7 +1323,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
...
@@ -1126,7 +1323,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
def
_build_sampler_output
(
def
_build_sampler_output
(
sample_results
:
SampleResultType
,
maybe_deferred_
sample_results
:
MaybeDeferred
SampleResultType
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
prompt_logprobs
:
Optional
[
List
[
Optional
[
PromptLogprobs
]]],
prompt_logprobs
:
Optional
[
List
[
Optional
[
PromptLogprobs
]]],
sample_logprobs
:
Optional
[
List
[
SampleLogprobs
]],
sample_logprobs
:
Optional
[
List
[
SampleLogprobs
]],
...
@@ -1143,14 +1340,21 @@ def _build_sampler_output(
...
@@ -1143,14 +1340,21 @@ def _build_sampler_output(
speculative decoding rejection sampling.
speculative decoding rejection sampling.
"""
"""
sampler_output
:
List
[
CompletionSequenceGroupOutput
]
=
[]
sampler_output
:
List
[
CompletionSequenceGroupOutput
]
=
[]
if
not
skip_sampler_cpu_output
:
if
skip_sampler_cpu_output
:
assert
isinstance
(
maybe_deferred_sample_results
,
SampleResultArgsType
)
deferred_sample_results_args
=
maybe_deferred_sample_results
else
:
assert
prompt_logprobs
is
not
None
assert
prompt_logprobs
is
not
None
assert
sample_logprobs
is
not
None
assert
sample_logprobs
is
not
None
assert
not
isinstance
(
maybe_deferred_sample_results
,
SampleResultArgsType
)
deferred_sample_results_args
=
None
for
(
seq_group
,
sample_result
,
group_prompt_logprobs
,
for
(
seq_group
,
sample_result
,
group_prompt_logprobs
,
group_sample_logprobs
)
in
zip
(
sampling_metadata
.
seq_groups
,
group_sample_logprobs
)
in
zip
(
sampling_metadata
.
seq_groups
,
sample_results
,
prompt_logprob
s
,
maybe_deferred_sample_result
s
,
sample_logprobs
):
prompt_logprobs
,
sample_logprobs
):
seq_ids
=
seq_group
.
seq_ids
seq_ids
=
seq_group
.
seq_ids
next_token_ids
,
parent_ids
=
sample_result
next_token_ids
,
parent_ids
=
sample_result
seq_outputs
:
List
[
SequenceOutput
]
=
[]
seq_outputs
:
List
[
SequenceOutput
]
=
[]
...
@@ -1176,7 +1380,7 @@ def _build_sampler_output(
...
@@ -1176,7 +1380,7 @@ def _build_sampler_output(
sampled_token_probs
=
sampled_token_probs
,
sampled_token_probs
=
sampled_token_probs
,
sampled_token_ids
=
sampled_token_ids
,
sampled_token_ids
=
sampled_token_ids
,
logprobs
=
logprobs_tensor
,
logprobs
=
logprobs_tensor
,
)
deferred_sample_results_args
=
deferred_sample_results_args
)
def
_get_next_prompt_tokens
(
seq_group
:
SequenceGroupToSample
)
->
List
[
int
]:
def
_get_next_prompt_tokens
(
seq_group
:
SequenceGroupToSample
)
->
List
[
int
]:
...
...
vllm/model_executor/layers/spec_decode_base_sampler.py
View file @
0640f227
...
@@ -130,29 +130,35 @@ class SpecDecodeBaseSampler(nn.Module):
...
@@ -130,29 +130,35 @@ class SpecDecodeBaseSampler(nn.Module):
def
_raise_if_incorrect_input
(
def
_raise_if_incorrect_input
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
target_
with_bonus_
probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
Optional
[
torch
.
Tensor
]
=
None
,
draft_probs
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
)
->
None
:
self
.
_raise_if_incorrect_shape
(
target_probs
,
draft_token_ids
,
self
.
_raise_if_incorrect_shape
(
target_with_bonus_probs
,
bonus_token_ids
,
draft_probs
)
draft_token_ids
,
bonus_token_ids
,
self
.
_raise_if_incorrect_dtype
(
target_probs
,
draft_token_ids
,
draft_probs
)
bonus_token_ids
,
draft_probs
)
self
.
_raise_if_incorrect_dtype
(
target_with_bonus_probs
,
self
.
_raise_if_inconsistent_device
(
target_probs
,
draft_token_ids
,
draft_token_ids
,
bonus_token_ids
,
bonus_token_ids
,
draft_probs
)
draft_probs
)
self
.
_raise_if_out_of_bounds_vocab
(
target_probs
.
shape
[
-
1
],
self
.
_raise_if_inconsistent_device
(
target_with_bonus_probs
,
draft_token_ids
,
bonus_token_ids
,
draft_probs
)
self
.
_raise_if_out_of_bounds_vocab
(
target_with_bonus_probs
.
shape
[
-
1
],
draft_token_ids
,
bonus_token_ids
)
draft_token_ids
,
bonus_token_ids
)
def
_raise_if_incorrect_shape
(
def
_raise_if_incorrect_shape
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
target_
with_bonus_
probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
Optional
[
torch
.
Tensor
]
=
None
,
draft_probs
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
)
->
None
:
(
target_batch_size
,
num_target_probs
,
(
target_batch_size
,
num_target_probs
,
target_vocab_size
)
=
target_probs
.
shape
target_vocab_size
)
=
target_with_bonus_probs
.
shape
# Does not count the extra token
num_target_probs
-=
1
# validate the shape of draft token ids.
# validate the shape of draft token ids.
draft_token_ids_batch_size
,
num_draft_token_ids
=
draft_token_ids
.
shape
draft_token_ids_batch_size
,
num_draft_token_ids
=
draft_token_ids
.
shape
...
@@ -175,12 +181,12 @@ class SpecDecodeBaseSampler(nn.Module):
...
@@ -175,12 +181,12 @@ class SpecDecodeBaseSampler(nn.Module):
def
_raise_if_incorrect_dtype
(
def
_raise_if_incorrect_dtype
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
target_
with_bonus_
probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
Optional
[
torch
.
Tensor
]
=
None
,
draft_probs
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
)
->
None
:
assert
target_probs
.
dtype
==
self
.
probs_dtype
assert
target_
with_bonus_
probs
.
dtype
==
self
.
probs_dtype
assert
draft_token_ids
.
dtype
==
self
.
token_id_dtype
assert
draft_token_ids
.
dtype
==
self
.
token_id_dtype
assert
bonus_token_ids
.
dtype
==
self
.
token_id_dtype
assert
bonus_token_ids
.
dtype
==
self
.
token_id_dtype
if
draft_probs
is
not
None
:
if
draft_probs
is
not
None
:
...
@@ -188,15 +194,16 @@ class SpecDecodeBaseSampler(nn.Module):
...
@@ -188,15 +194,16 @@ class SpecDecodeBaseSampler(nn.Module):
def
_raise_if_inconsistent_device
(
def
_raise_if_inconsistent_device
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
target_
with_bonus_
probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
Optional
[
torch
.
Tensor
]
=
None
,
draft_probs
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
)
->
None
:
devices
=
[
devices
=
[
t
.
device
for
t
in
t
.
device
for
t
in
[
[
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
]
target_with_bonus_probs
,
bonus_token_ids
,
draft_probs
,
if
t
is
not
None
draft_token_ids
]
if
t
is
not
None
]
]
assert
all
([
devices
[
0
]
==
device
for
device
in
devices
])
assert
all
([
devices
[
0
]
==
device
for
device
in
devices
])
...
@@ -220,7 +227,7 @@ class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
...
@@ -220,7 +227,7 @@ class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
@
abstractmethod
@
abstractmethod
def
forward
(
def
forward
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
target_
with_bonus_
probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
...
@@ -236,7 +243,7 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
...
@@ -236,7 +243,7 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
@
abstractmethod
@
abstractmethod
def
forward
(
def
forward
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
target_
with_bonus_
probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/typical_acceptance_sampler.py
View file @
0640f227
...
@@ -41,7 +41,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -41,7 +41,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
def
forward
(
def
forward
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
target_
with_bonus_
probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
...
@@ -80,8 +80,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -80,8 +80,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
# Only perform shape/dtype/device checking in strict mode, as it adds
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
# overhead.
if
self
.
_strict_mode
:
if
self
.
_strict_mode
:
self
.
_raise_if_incorrect_input
(
target_probs
,
draft_token_ids
,
self
.
_raise_if_incorrect_input
(
target_with_bonus_probs
,
bonus_token_ids
)
draft_token_ids
,
bonus_token_ids
)
target_probs
=
target_with_bonus_probs
[:,
:
-
1
]
accepted
=
self
.
_evaluate_accepted_tokens
(
target_probs
,
accepted
=
self
.
_evaluate_accepted_tokens
(
target_probs
,
draft_token_ids
)
draft_token_ids
)
recovered_token_ids
=
self
.
_replacement_token_ids
(
target_probs
)
recovered_token_ids
=
self
.
_replacement_token_ids
(
target_probs
)
...
...
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
0640f227
...
@@ -10,6 +10,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...
@@ -10,6 +10,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
,
method_has_implemented_embedding
)
QuantizationConfig
,
QuantizeMethodBase
,
method_has_implemented_embedding
)
from
vllm.model_executor.parameter
import
BasevLLMParameter
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE
=
64
DEFAULT_VOCAB_PADDING_SIZE
=
64
...
@@ -351,7 +352,10 @@ class VocabParallelEmbedding(torch.nn.Module):
...
@@ -351,7 +352,10 @@ class VocabParallelEmbedding(torch.nn.Module):
param
.
weight_type
=
loaded_weight
.
item
()
param
.
weight_type
=
loaded_weight
.
item
()
return
return
elif
isinstance
(
param
,
UninitializedParameter
):
elif
isinstance
(
param
,
UninitializedParameter
):
param
.
materialize
(
loaded_weight
.
shape
,
dtype
=
loaded_weight
.
dtype
)
shape
=
list
(
loaded_weight
.
shape
)
if
output_dim
is
not
None
:
shape
[
output_dim
]
=
shape
[
output_dim
]
//
self
.
tp_size
param
.
materialize
(
tuple
(
shape
),
dtype
=
loaded_weight
.
dtype
)
# If parameter does not have output dim, then it should
# If parameter does not have output dim, then it should
# be copied onto all gpus (e.g. g_idx for act_order gptq).
# be copied onto all gpus (e.g. g_idx for act_order gptq).
...
@@ -367,10 +371,12 @@ class VocabParallelEmbedding(torch.nn.Module):
...
@@ -367,10 +371,12 @@ class VocabParallelEmbedding(torch.nn.Module):
# If param packed on the same dim we are sharding on, then
# If param packed on the same dim we are sharding on, then
# need to adjust offsets of loaded weight by pack_factor.
# need to adjust offsets of loaded weight by pack_factor.
if
packed_dim
is
not
None
and
packed_dim
==
output_dim
:
if
packed_dim
is
not
None
and
packed_dim
==
output_dim
:
packed_factor
=
param
.
packed_factor
if
isinstance
(
param
,
BasevLLMParameter
)
else
param
.
pack_factor
assert
loaded_weight
.
shape
[
output_dim
]
==
(
self
.
org_vocab_size
//
assert
loaded_weight
.
shape
[
output_dim
]
==
(
self
.
org_vocab_size
//
param
.
pack_factor
)
param
.
pack
ed
_factor
)
start_idx
=
start_idx
//
pa
ram
.
pa
ck_factor
start_idx
=
start_idx
//
pack
ed
_factor
shard_size
=
shard_size
//
pa
ram
.
pa
ck_factor
shard_size
=
shard_size
//
pack
ed
_factor
else
:
else
:
assert
loaded_weight
.
shape
[
output_dim
]
==
self
.
org_vocab_size
assert
loaded_weight
.
shape
[
output_dim
]
==
self
.
org_vocab_size
...
...
vllm/model_executor/model_loader/loader.py
View file @
0640f227
...
@@ -774,7 +774,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -774,7 +774,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return
pt_weights_iterator
(
hf_weights_files
)
return
pt_weights_iterator
(
hf_weights_files
)
def
_get_quantized_weights_iterator
(
def
_get_quantized_weights_iterator
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
],
pre_quant
:
bool
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
],
pre_quant
:
bool
,
load_8bit
:
bool
,
)
->
Tuple
[
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
],
Dict
[
str
,
)
->
Tuple
[
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
],
Dict
[
str
,
Any
]]:
Any
]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
"""Get an iterator to the model weights with bitsandbytes quantization,
...
@@ -783,11 +787,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -783,11 +787,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# only load the bitsandbytes module when needed
# only load the bitsandbytes module when needed
try
:
try
:
import
bitsandbytes
import
bitsandbytes
from
bitsandbytes.functional
import
QuantState
if
bitsandbytes
.
__version__
<
"0.42.0"
:
if
bitsandbytes
.
__version__
<
"0.42.0"
:
raise
ImportError
(
"bitsandbytes version is wrong. Please "
raise
ImportError
(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.42.0."
)
"install bitsandbytes>=0.42.0."
)
from
bitsandbytes.functional
import
quantize_4bit
except
ImportError
as
err
:
except
ImportError
as
err
:
raise
ImportError
(
"Please install bitsandbytes>=0.42.0 via "
raise
ImportError
(
"Please install bitsandbytes>=0.42.0 via "
"`pip install bitsandbytes>=0.42.0` to use "
"`pip install bitsandbytes>=0.42.0` to use "
...
@@ -796,80 +798,111 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -796,80 +798,111 @@ class BitsAndBytesModelLoader(BaseModelLoader):
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
model_name_or_path
,
revision
)
model_name_or_path
,
revision
)
quant_state_dict
=
{}
quant_state_dict
:
Dict
[
str
,
Any
]
=
{}
def
quantized_checkpoint
()
->
Generator
:
# First iterate over all quant state weights
weight_iterator
=
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
)
temp_state_dict
=
{}
for
weight_name
,
weight_tensor
in
weight_iterator
:
if
weight_name
.
endswith
(
".weight"
):
continue
# TODO: only nf4 quantization is supported for now
if
weight_name
.
endswith
(
".quant_state.bitsandbytes__fp4"
):
raise
NotImplementedError
(
"Only bitsandbytes_nf4 quantization"
f
"is supported for now.
{
weight_name
}
is fp4 quantized"
)
temp_state_dict
[
weight_name
]
=
weight_tensor
# Closure to parse quant_state for each prequant weight
if
pre_quant
:
def
_parse_quant_state
(
param_name
:
str
,
if
load_8bit
:
temp_state_dict
:
Dict
)
->
QuantState
:
return
self
.
_quantized_8bit_generator
(
quant_state
=
{}
hf_weights_files
,
use_safetensors
,
for
k
in
temp_state_dict
:
quant_state_dict
),
quant_state_dict
if
param_name
+
"."
in
k
:
else
:
quant_state
[
k
]
=
temp_state_dict
[
k
]
return
self
.
_quantized_4bit_generator
(
# bitsandbytes library requires
hf_weights_files
,
use_safetensors
,
# weight.quant_state.bitsandbytes__nf4 in CPU
quant_state_dict
),
quant_state_dict
quant_state
[
param_name
+
".quant_state.bitsandbytes__nf4"
]
=
quant_state
[
param_name
+
".quant_state.bitsandbytes__nf4"
].
cpu
().
data
return
QuantState
.
from_dict
(
quant_state
,
device
=
"cuda"
)
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
# Filter out all weights whose suffix is not ".weight"
if
not
weight_name
.
endswith
(
".weight"
):
continue
if
weight_name
+
".quant_state.bitsandbytes__nf4"
\
in
temp_state_dict
:
quant_state
=
_parse_quant_state
(
weight_name
,
temp_state_dict
)
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
quant_state_dict
[
weight_name
]
=
quant_state
yield
weight_name
.
replace
(
".weight"
,
".qweight"
),
weight_tensor
else
:
yield
weight_name
,
weight_tensor
def
generator
()
->
Generator
:
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
any
(
target_module
in
weight_name
for
target_module
in
self
.
target_modules
):
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
# bitsandbytes requires data in GPU
loaded_weight
=
weight_tensor
.
cuda
().
data
with
set_default_torch_dtype
(
torch
.
float32
):
processed_weight
,
quant_state
=
quantize_4bit
(
loaded_weight
,
compress_statistics
=
True
,
quant_type
=
"nf4"
)
quant_state_dict
[
weight_name
]
=
quant_state
else
:
processed_weight
=
weight_tensor
yield
weight_name
,
processed_weight
return
self
.
_unquantized_generator
(
hf_weights_files
,
use_safetensors
,
quant_state_dict
),
quant_state_dict
if
pre_quant
:
def
_quantized_8bit_generator
(
self
,
hf_weights_files
,
use_safetensors
,
return
quantized_checkpoint
(),
quant_state_dict
quant_state_dict
)
->
Generator
:
return
generator
(),
quant_state_dict
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
not
weight_name
.
lower
().
endswith
(
".scb"
):
continue
weight_key
=
weight_name
.
lower
().
replace
(
".scb"
,
".qweight"
)
quant_state_dict
[
weight_key
]
=
weight_tensor
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
not
weight_name
.
endswith
(
".weight"
):
continue
qweight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
if
qweight_name
in
quant_state_dict
:
set_weight_attrs
(
weight_tensor
,
{
"load_in_8bit"
:
True
})
yield
qweight_name
,
weight_tensor
else
:
yield
weight_name
,
weight_tensor
def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
                              quant_state_dict) -> Generator:
    """Yield ``(name, tensor)`` pairs from a pre-quantized 4-bit checkpoint.

    The checkpoint is walked twice through ``self._hf_weight_iter``:

    1. every tensor whose name does not end in ``".weight"`` is stashed in
       a side table (these carry the per-weight quantization state);
    2. every ``".weight"`` tensor is yielded. When the side table holds an
       nf4 or fp4 quant-state entry for it, a ``QuantState`` is rebuilt,
       stored into ``quant_state_dict`` under the ``".qweight"`` name, and
       the tensor is yielded under that ``".qweight"`` name instead.

    ``quant_state_dict`` is filled in lazily, as the generator is consumed.
    """
    from bitsandbytes.functional import QuantState

    # Pass 1: gather all non-".weight" tensors.
    side_tensors = {}
    for name, tensor in self._hf_weight_iter(hf_weights_files,
                                             use_safetensors):
        if name.endswith(".weight"):
            continue
        # bitsandbytes library requires
        # weight.quant_state.bitsandbytes__* in CPU
        if "quant_state.bitsandbytes" in name:
            tensor = tensor.cpu().data
        side_tensors[name] = tensor

    def _parse_quant_state(param_name: str,
                           temp_state_dict: Dict) -> QuantState:
        # Collect every side tensor whose key contains "<param_name>."
        # and let bitsandbytes rebuild the state object from them.
        needle = param_name + "."
        pieces = {
            key: value
            for key, value in temp_state_dict.items() if needle in key
        }
        return QuantState.from_dict(pieces, device="cuda")

    state_suffixes = (".quant_state.bitsandbytes__nf4",
                      ".quant_state.bitsandbytes__fp4")

    # Pass 2: emit the ".weight" tensors; pre-quantized ones are renamed.
    for name, tensor in self._hf_weight_iter(hf_weights_files,
                                             use_safetensors):
        if not name.endswith(".weight"):
            continue

        is_prequantized = any(name + suffix in side_tensors
                              for suffix in state_suffixes)
        if not is_prequantized:
            yield name, tensor
            continue

        state = _parse_quant_state(name, side_tensors)
        qweight_name = name.replace(".weight", ".qweight")
        quant_state_dict[qweight_name] = state
        yield qweight_name, tensor
def _unquantized_generator(self, hf_weights_files, use_safetensors,
                           quant_state_dict) -> Generator:
    """Yield ``(name, tensor)`` pairs, quantizing targeted weights to nf4.

    A weight is quantized when any entry of ``self.target_modules`` occurs
    as a substring of its name. Such a weight is renamed from ``".weight"``
    to ``".qweight"``, moved to CUDA, passed through ``quantize_4bit``, and
    its quant state is recorded in ``quant_state_dict`` under the new name.
    All other weights are yielded untouched.
    """
    from bitsandbytes.functional import quantize_4bit

    for name, tensor in self._hf_weight_iter(hf_weights_files,
                                             use_safetensors):
        targeted = any(module in name for module in self.target_modules)
        if not targeted:
            yield name, tensor
            continue

        qweight_name = name.replace(".weight", ".qweight")
        # bitsandbytes requires data in GPU
        gpu_tensor = tensor.cuda().data
        with set_default_torch_dtype(torch.float32):
            packed, state = quantize_4bit(gpu_tensor,
                                          compress_statistics=True,
                                          quant_type="nf4")
        quant_state_dict[qweight_name] = state
        yield qweight_name, packed
def
_load_weights
(
self
,
model_config
:
ModelConfig
,
def
_load_weights
(
self
,
model_config
:
ModelConfig
,
model
:
nn
.
Module
)
->
None
:
model
:
nn
.
Module
)
->
None
:
...
@@ -886,16 +919,26 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -886,16 +919,26 @@ class BitsAndBytesModelLoader(BaseModelLoader):
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
" May take a while ..."
)
" May take a while ..."
)
is_quantized_checkpoint
=
False
quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
None
)
None
)
if
quant_config
is
not
None
and
quant_config
.
get
(
'quant_method'
)
==
"bitsandbytes"
:
pre_quant
=
False
is_quantized_checkpoint
=
True
if
quant_config
is
not
None
:
quant_method
=
quant_config
.
get
(
'quant_method'
)
if
quant_method
==
"bitsandbytes"
:
pre_quant
=
True
else
:
raise
ValueError
(
f
"BitsAndBytes loader does not support
{
quant_method
}
"
"quantization"
)
load_8bit
=
False
if
pre_quant
:
load_8bit
=
quant_config
.
get
(
'load_in_8bit'
,
False
)
qweight_iterator
,
quant_state_dict
=
\
qweight_iterator
,
quant_state_dict
=
\
self
.
_get_quantized_weights_iterator
(
self
.
_get_quantized_weights_iterator
(
model_config
.
model
,
model_config
.
revision
,
is
_quant
ized_checkpoin
t
)
model_config
.
model
,
model_config
.
revision
,
pre
_quant
,
load_8bi
t
)
model
.
load_weights
(
qweight_iterator
)
model
.
load_weights
(
qweight_iterator
)
...
@@ -945,6 +988,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -945,6 +988,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
offsets
=
np
.
concatenate
(([
0
],
np
.
cumsum
(
num_elements
)))
offsets
=
np
.
concatenate
(([
0
],
np
.
cumsum
(
num_elements
)))
set_weight_attrs
(
param
,
{
"bnb_shard_offsets"
:
offsets
})
set_weight_attrs
(
param
,
{
"bnb_shard_offsets"
:
offsets
})
if
load_8bit
:
set_weight_attrs
(
param
,
{
"matmul_state"
:
[
None
]
*
len
(
quant_states
)})
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
...
...
vllm/model_executor/model_loader/neuron.py
View file @
0640f227
"""Utilities for selecting and loading neuron models."""
"""Utilities for selecting and loading neuron models."""
import
importlib
import
importlib
import
os
import
os
from
typing
import
Dict
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -10,9 +10,9 @@ from transformers import PretrainedConfig
...
@@ -10,9 +10,9 @@ from transformers import PretrainedConfig
from
vllm.config
import
ModelConfig
,
ParallelConfig
,
SchedulerConfig
from
vllm.config
import
ModelConfig
,
ParallelConfig
,
SchedulerConfig
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.quantization
import
get_quantization_config
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
TORCH_DTYPE_TO_NEURON_AMP
=
{
TORCH_DTYPE_TO_NEURON_AMP
=
{
"auto"
:
"f32"
,
"auto"
:
"f32"
,
...
@@ -82,8 +82,7 @@ class NeuronCasualLM(nn.Module):
...
@@ -82,8 +82,7 @@ class NeuronCasualLM(nn.Module):
neuronx_model_cls
=
getattr
(
neuronx_module
,
neuronx_model_cls_name
)
neuronx_model_cls
=
getattr
(
neuronx_module
,
neuronx_model_cls_name
)
split_model_dir
=
f
"
{
model_name_or_path
}
-split"
split_model_dir
=
f
"
{
model_name_or_path
}
-split"
if
os
.
path
.
isdir
(
os
.
path
.
join
(
model_name_or_path
,
if
_is_pretrained_neuron_checkpoint
(
model_name_or_path
):
"pytorch_model.bin"
)):
split_model_dir
=
model_name_or_path
split_model_dir
=
model_name_or_path
elif
not
os
.
path
.
exists
(
f
"
{
model_name_or_path
}
-split"
):
elif
not
os
.
path
.
exists
(
f
"
{
model_name_or_path
}
-split"
):
hf_model_cls
=
getattr
(
transformers
,
hf_model_cls_name
)
hf_model_cls
=
getattr
(
transformers
,
hf_model_cls_name
)
...
@@ -98,6 +97,23 @@ class NeuronCasualLM(nn.Module):
...
@@ -98,6 +97,23 @@ class NeuronCasualLM(nn.Module):
self
.
model
.
to_neuron
()
self
.
model
.
to_neuron
()
def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool:
    """Return True when ``model_name_or_path`` looks like a neuron checkpoint.

    Two layouts are recognised:

    * old format: ``<path>/pytorch_model.bin`` exists and is a directory;
    * new format: ``config.json`` and ``generation_config.json`` are both
      regular files in ``<path>`` and at least one directory entry ends
      with ``.safetensors``.
    """
    # Old format check.
    if os.path.isdir(os.path.join(model_name_or_path, "pytorch_model.bin")):
        return True

    # New format check: both metadata files must be present ...
    required_files = ("config.json", "generation_config.json")
    has_metadata = all(
        os.path.isfile(os.path.join(model_name_or_path, required))
        for required in required_files)
    if not has_metadata:
        return False

    # ... together with at least one safetensors shard.
    return any(
        entry.endswith(".safetensors")
        for entry in os.listdir(model_name_or_path))
def
_get_model_architecture
(
config
:
PretrainedConfig
)
->
str
:
def
_get_model_architecture
(
config
:
PretrainedConfig
)
->
str
:
architectures
=
getattr
(
config
,
"architectures"
,
[])
architectures
=
getattr
(
config
,
"architectures"
,
[])
for
arch
in
architectures
:
for
arch
in
architectures
:
...
@@ -109,28 +125,75 @@ def _get_model_architecture(config: PretrainedConfig) -> str:
...
@@ -109,28 +125,75 @@ def _get_model_architecture(config: PretrainedConfig) -> str:
f
"
{
list
(
_NEURON_SUPPORTED_MODELS
.
keys
())
}
"
)
f
"
{
list
(
_NEURON_SUPPORTED_MODELS
.
keys
())
}
"
)
def _get_buckets(env: str, default_value: List[int]) -> List[int]:
    """Read a comma-separated list of integers from an environment variable.

    Args:
        env: Name of the environment variable to read.
        default_value: Returned unchanged when the variable is not set.

    Returns:
        The parsed integers in the order given. Empty or whitespace-only
        entries (e.g. from a trailing comma) are skipped, so a variable
        that is set but blank yields an empty list.

    Raises:
        ValueError: If an entry is not a valid integer. The message names
            the environment variable and the offending entry.
    """
    env_value = os.getenv(env)
    if env_value is None:
        return default_value

    buckets: List[int] = []
    for raw_entry in env_value.split(","):
        entry = raw_entry.strip()
        if not entry:
            continue
        try:
            buckets.append(int(entry))
        except ValueError as err:
            # Same exception type as a bare int() failure, but the message
            # tells the user which variable to fix.
            raise ValueError(
                f"Invalid bucket value {entry!r} in environment variable "
                f"{env}: expected a comma-separated list of integers."
            ) from err
    return buckets
def _get_default_neuron_config(model_config: ModelConfig,
                               parallel_config: ParallelConfig,
                               scheduler_config: SchedulerConfig):
    """Build the default keyword arguments for a ``NeuronConfig``.

    Returns a plain dict with the layout settings, fused-QKV flag,
    continuous-batching config (sized by ``scheduler_config.max_num_seqs``)
    and, when ``model_config.quantization`` is set, a quantization method
    object plus ``weight_tiling=True``. ``parallel_config`` is accepted but
    not read here.
    """
    from transformers_neuronx.config import ContinuousBatchingConfig
    from transformers_neuronx.constants import LAYOUT_BSH

    batching = ContinuousBatchingConfig(
        batch_size_for_shared_caches=scheduler_config.max_num_seqs)

    # Settings handed to the quantization config's ``from_config``.
    quant_settings = {
        "dequant_dtype": TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        "quantize_method": "vector_dynamic",
    }

    quantization = model_config.quantization
    if quantization:
        quant_method = get_quantization_config(quantization).from_config(
            quant_settings).get_quant_method(None, "")
    else:
        quant_method = None

    # TODO: Add Paged attention config to the default neuron arguments.
    return {
        "collectives_layout": LAYOUT_BSH,
        "attention_layout": LAYOUT_BSH,
        "fuse_qkv": True,
        "quant": quant_method,
        "continuous_batching": batching,
        "weight_tiling": bool(quantization),
    }
def _get_neuron_config_after_override(default_neuron_config,
                                      overridden_neuron_config):
    """Merge user overrides into the defaults and build a ``NeuronConfig``.

    ``default_neuron_config`` is updated in place with the entries of
    ``overridden_neuron_config`` (a falsy value means "no overrides"), and
    the merged mapping is expanded as keyword arguments to ``NeuronConfig``.
    """
    from transformers_neuronx.config import NeuronConfig

    if overridden_neuron_config:
        # Overrides win over defaults; note this mutates the caller's dict.
        default_neuron_config.update(overridden_neuron_config)
    return NeuronConfig(**default_neuron_config)
def
get_neuron_model
(
model_config
:
ModelConfig
,
def
get_neuron_model
(
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
)
->
nn
.
Module
:
scheduler_config
:
SchedulerConfig
)
->
nn
.
Module
:
from
transformers_neuronx.config
import
(
ContinuousBatchingConfig
,
NeuronConfig
)
# Create a model instance.
# Create a model instance.
model
=
NeuronCasualLM
(
model_config
.
hf_config
)
model
=
NeuronCasualLM
(
model_config
.
hf_config
)
continuous_batching_config
=
ContinuousBatchingConfig
(
default_neuron_config_args
=
_get_default_neuron_config
(
batch_size_for_shared_caches
=
scheduler_config
.
max_num_seqs
)
model_config
,
parallel_config
,
scheduler_config
)
neuron_config
=
NeuronConfig
(
continuous_batching
=
continuous_batching_config
)
neuron_config
=
_get_neuron_config_after_override
(
default_neuron_config_args
,
model_config
.
override_neuron_config
)
context_length_estimates
=
_get_buckets
(
"NEURON_CONTEXT_LENGTH_BUCKETS"
,
[
scheduler_config
.
max_model_len
])
n_positions
=
_get_buckets
(
"NEURON_TOKEN_GEN_BUCKETS"
,
[
scheduler_config
.
max_model_len
])
# Load the weights from the cached or downloaded files.
# Load the weights from the cached or downloaded files.
model
.
load_weights
(
model
.
load_weights
(
model_config
.
model
,
model_config
.
model
,
tp_degree
=
parallel_config
.
tensor_parallel_size
,
tp_degree
=
parallel_config
.
tensor_parallel_size
,
amp
=
TORCH_DTYPE_TO_NEURON_AMP
[
model_config
.
dtype
],
amp
=
TORCH_DTYPE_TO_NEURON_AMP
[
model_config
.
dtype
],
neuron_config
=
neuron_config
,
neuron_config
=
neuron_config
,
context_length_estimate
=
context_length_estimates
,
context_length_estimate
=
[
scheduler_config
.
max_model_len
],
n_positions
=
n_positions
,
n_positions
=
[
scheduler_config
.
max_model_len
],
batch_size
=
scheduler_config
.
max_num_seqs
)
batch_size
=
scheduler_config
.
max_num_seqs
)
return
model
.
eval
()
return
model
.
eval
()
vllm/model_executor/model_loader/openvino.py
View file @
0640f227
...
@@ -15,9 +15,8 @@ from vllm.config import DeviceConfig, ModelConfig
...
@@ -15,9 +15,8 @@ from vllm.config import DeviceConfig, ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.logits_processor
import
(
LogitsProcessor
,
from
vllm.model_executor.layers.logits_processor
import
(
LogitsProcessor
,
_prune_hidden_states
)
_prune_hidden_states
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/model_loader/utils.py
View file @
0640f227
...
@@ -42,11 +42,11 @@ def get_model_architecture(
...
@@ -42,11 +42,11 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported
=
[
"fp8"
,
"compressed-tensors"
]
if
(
model_config
.
quantization
is
not
None
if
(
model_config
.
quantization
is
not
None
and
model_config
.
quantization
!=
"fp8"
and
model_config
.
quantization
not
in
mixtral_supported
and
"MixtralForCausalLM"
in
architectures
):
and
"MixtralForCausalLM"
in
architectures
):
architectures
=
[
"QuantMixtralForCausalLM"
]
architectures
=
[
"QuantMixtralForCausalLM"
]
return
ModelRegistry
.
resolve_model_cls
(
architectures
)
return
ModelRegistry
.
resolve_model_cls
(
architectures
)
...
...
vllm/model_executor/models/__init__.py
View file @
0640f227
...
@@ -22,6 +22,7 @@ _GENERATION_MODELS = {
...
@@ -22,6 +22,7 @@ _GENERATION_MODELS = {
"DeciLMForCausalLM"
:
(
"decilm"
,
"DeciLMForCausalLM"
),
"DeciLMForCausalLM"
:
(
"decilm"
,
"DeciLMForCausalLM"
),
"DeepseekForCausalLM"
:
(
"deepseek"
,
"DeepseekForCausalLM"
),
"DeepseekForCausalLM"
:
(
"deepseek"
,
"DeepseekForCausalLM"
),
"DeepseekV2ForCausalLM"
:
(
"deepseek_v2"
,
"DeepseekV2ForCausalLM"
),
"DeepseekV2ForCausalLM"
:
(
"deepseek_v2"
,
"DeepseekV2ForCausalLM"
),
"ExaoneForCausalLM"
:
(
"exaone"
,
"ExaoneForCausalLM"
),
"FalconForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"FalconForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"GemmaForCausalLM"
:
(
"gemma"
,
"GemmaForCausalLM"
),
"GemmaForCausalLM"
:
(
"gemma"
,
"GemmaForCausalLM"
),
"Gemma2ForCausalLM"
:
(
"gemma2"
,
"Gemma2ForCausalLM"
),
"Gemma2ForCausalLM"
:
(
"gemma2"
,
"Gemma2ForCausalLM"
),
...
@@ -49,6 +50,7 @@ _GENERATION_MODELS = {
...
@@ -49,6 +50,7 @@ _GENERATION_MODELS = {
"PersimmonForCausalLM"
:
(
"persimmon"
,
"PersimmonForCausalLM"
),
"PersimmonForCausalLM"
:
(
"persimmon"
,
"PersimmonForCausalLM"
),
"PhiForCausalLM"
:
(
"phi"
,
"PhiForCausalLM"
),
"PhiForCausalLM"
:
(
"phi"
,
"PhiForCausalLM"
),
"Phi3ForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"Phi3ForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"PhiMoEForCausalLM"
:
(
"phimoe"
,
"PhiMoEForCausalLM"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
"Qwen2MoeForCausalLM"
:
(
"qwen2_moe"
,
"Qwen2MoeForCausalLM"
),
"Qwen2MoeForCausalLM"
:
(
"qwen2_moe"
,
"Qwen2MoeForCausalLM"
),
...
@@ -63,6 +65,7 @@ _GENERATION_MODELS = {
...
@@ -63,6 +65,7 @@ _GENERATION_MODELS = {
"EAGLEModel"
:
(
"eagle"
,
"EAGLE"
),
"EAGLEModel"
:
(
"eagle"
,
"EAGLE"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
"JambaForCausalLM"
:
(
"jamba"
,
"JambaForCausalLM"
),
"JambaForCausalLM"
:
(
"jamba"
,
"JambaForCausalLM"
),
"GraniteForCausalLM"
:
(
"granite"
,
"GraniteForCausalLM"
)
}
}
_EMBEDDING_MODELS
=
{
_EMBEDDING_MODELS
=
{
...
...
vllm/model_executor/models/arctic.py
View file @
0640f227
...
@@ -23,13 +23,13 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -23,13 +23,13 @@ from vllm.model_executor.layers.quantization.base_config import (
from
vllm.model_executor.layers.quantization.deepspeedfp
import
(
from
vllm.model_executor.layers.quantization.deepspeedfp
import
(
DeepSpeedFPConfig
,
DeepSpeedFPParameter
)
DeepSpeedFPConfig
,
DeepSpeedFPParameter
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.arctic
import
ArcticConfig
from
vllm.transformers_utils.configs.arctic
import
ArcticConfig
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/models/baichuan.py
View file @
0640f227
...
@@ -40,12 +40,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
...
@@ -40,12 +40,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
...
...
vllm/model_executor/models/bart.py
View file @
0640f227
...
@@ -34,12 +34,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -34,12 +34,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
...
...
vllm/model_executor/models/blip.py
View file @
0640f227
...
@@ -10,15 +10,23 @@ from transformers import Blip2VisionConfig, BlipVisionConfig
...
@@ -10,15 +10,23 @@ from transformers import Blip2VisionConfig, BlipVisionConfig
from
transformers.models.blip.modeling_blip
import
BlipAttention
from
transformers.models.blip.modeling_blip
import
BlipAttention
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
LLMInputs
from
vllm.inputs
import
LLMInputs
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
repeat_and_pad_placeholder_tokens
)
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
try
:
from
xformers
import
ops
as
xops
USE_XFORMERS_OPS
=
True
except
ImportError
:
USE_XFORMERS_OPS
=
False
def
get_blip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_blip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
assert
image_size
%
patch_size
==
0
assert
image_size
%
patch_size
==
0
...
@@ -154,6 +162,77 @@ class BlipVisionEmbeddings(nn.Module):
...
@@ -154,6 +162,77 @@ class BlipVisionEmbeddings(nn.Module):
return
embeddings
return
embeddings
class BlipParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    The QKV and output projections are built from ``QKVParallelLinear`` and
    ``RowParallelLinear``; the attention itself is computed with xformers'
    ``memory_efficient_attention_forward`` over
    ``num_heads / tensor-parallel world size`` heads.
    """

    def __init__(
        self,
        config: BlipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        head_count = config.num_attention_heads
        per_head = hidden_size // head_count

        self.config = config
        self.embed_dim = hidden_size
        self.num_heads = head_count
        self.head_dim = per_head
        if per_head * head_count != hidden_size:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = per_head**-0.5
        self.dropout = config.attention_dropout

        # Submodules are registered in the same order as before:
        # fused QKV projection first, then the output projection.
        self.qkv = QKVParallelLinear(
            hidden_size,
            per_head,
            head_count,
            bias=config.qkv_bias,
            quant_config=quant_config,
        )
        self.projection = RowParallelLinear(
            hidden_size,
            hidden_size,
            quant_config=quant_config,
        )

        world_size = get_tensor_model_parallel_world_size()
        self.tp_size = world_size
        self.num_heads_per_partition = divide(head_count, world_size)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """View ``tensor`` as (bsz, seq_len, num_heads, head_dim), then swap
        axes 1 and 2 and return a contiguous copy."""
        per_head_view = tensor.view(bsz, seq_len, self.num_heads,
                                    self.head_dim)
        return per_head_view.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel"""
        batch, seq_len, _ = hidden_states.size()

        fused_qkv, _ = self.qkv(hidden_states)

        # Split the fused projection into q/k/v along the last axis and
        # expose the per-partition head axis for xformers.
        head_view = (batch, seq_len, self.num_heads_per_partition,
                     self.head_dim)
        query, key, value = (part.view(*head_view)
                             for part in fused_qkv.chunk(3, dim=-1))

        attended = xops.memory_efficient_attention_forward(query,
                                                           key,
                                                           value,
                                                           p=self.dropout,
                                                           scale=self.scale)
        merged = attended.view(batch, seq_len, -1)
        attn_output, _ = self.projection(merged)

        # Second element is always None in this implementation.
        return attn_output, None
class
BlipMLP
(
nn
.
Module
):
class
BlipMLP
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -188,7 +267,16 @@ class BlipEncoderLayer(nn.Module):
...
@@ -188,7 +267,16 @@ class BlipEncoderLayer(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
self_attn
=
BlipAttention
(
config
)
# fallback to sdpa attention if tp unavailable
num_heads
=
config
.
num_attention_heads
tp_size
=
get_tensor_model_parallel_world_size
()
if
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
:
self
.
self_attn
=
BlipParallelAttention
(
config
,
quant_config
=
quant_config
)
else
:
# Blip doesn't have SDPA attention implemented in transformers
# use eager attention instead for cpu backend
self
.
self_attn
=
BlipAttention
(
config
)
self
.
layer_norm1
=
nn
.
LayerNorm
(
config
.
hidden_size
,
self
.
layer_norm1
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
BlipMLP
(
config
,
quant_config
=
quant_config
)
self
.
mlp
=
BlipMLP
(
config
,
quant_config
=
quant_config
)
...
...
vllm/model_executor/models/blip2.py
View file @
0640f227
...
@@ -13,13 +13,13 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
...
@@ -13,13 +13,13 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.opt
import
OPTModel
from
vllm.model_executor.models.opt
import
OPTModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SamplerOutput
,
SequenceData
)
SequenceData
)
from
.blip
import
(
BlipVisionModel
,
dummy_image_for_blip
,
from
.blip
import
(
BlipVisionModel
,
dummy_image_for_blip
,
get_max_blip_image_tokens
)
get_max_blip_image_tokens
)
...
@@ -40,13 +40,13 @@ BLIP2_IMAGE_TOKEN_ID = 50265
...
@@ -40,13 +40,13 @@ BLIP2_IMAGE_TOKEN_ID = 50265
class
Blip2ImagePixelInputs
(
TypedDict
):
class
Blip2ImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
data
:
torch
.
Tensor
"""Shape: (batch_size, num_channels, height, width)"""
"""Shape:
`
(batch_size
* num_images
, num_channels, height, width)
`
"""
class
Blip2ImageEmbeddingInputs
(
TypedDict
):
class
Blip2ImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
type
:
Literal
[
"image_embeds"
]
data
:
torch
.
Tensor
data
:
torch
.
Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
"""Shape: `(batch_size
* num_images
, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
`hidden_size` must match the hidden size of language model backbone.
"""
"""
...
@@ -555,6 +555,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -555,6 +555,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
raise
ValueError
(
"Incorrect type of pixel values. "
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
f
"Got type:
{
type
(
pixel_values
)
}
"
)
# Remove the N dimension until multiple images are supported.
pixel_values
=
pixel_values
.
squeeze
(
1
)
return
Blip2ImagePixelInputs
(
return
Blip2ImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
data
=
self
.
_validate_pixel_values
(
pixel_values
),
...
@@ -564,6 +567,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -564,6 +567,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image embeddings. "
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
f
"Got type:
{
type
(
image_embeds
)
}
"
)
# Remove the N dimension until multiple images are supported.
image_embeds
=
image_embeds
.
squeeze
(
1
)
return
Blip2ImageEmbeddingInputs
(
return
Blip2ImageEmbeddingInputs
(
type
=
"image_embeds"
,
type
=
"image_embeds"
,
data
=
image_embeds
,
data
=
image_embeds
,
...
@@ -707,8 +714,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -707,8 +714,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
use_default_weight_loading
=
False
use_default_weight_loading
=
False
if
"vision"
in
name
:
if
"vision"
in
name
:
if
self
.
vision_model
is
not
None
:
if
self
.
vision_model
is
not
None
:
# We only do sharding for language model and
# BlipVisionModel does not need sharding
# not vision model for now.
use_default_weight_loading
=
True
use_default_weight_loading
=
True
else
:
else
:
for
(
param_name
,
weight_name
,
for
(
param_name
,
weight_name
,
...
...
vllm/model_executor/models/bloom.py
View file @
0640f227
...
@@ -36,12 +36,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -36,12 +36,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
...
...
vllm/model_executor/models/chameleon.py
View file @
0640f227
...
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
...
@@ -33,7 +33,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
...
@@ -33,7 +33,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
repeat_and_pad_placeholder_tokens
)
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SamplerOutput
,
SequenceData
)
SequenceData
)
from
vllm.utils
import
print_warning_once
from
vllm.utils
import
print_warning_once
from
.interfaces
import
SupportsMultiModal
from
.interfaces
import
SupportsMultiModal
...
@@ -53,7 +53,7 @@ CHAMELEON_SEP_TOKEN_ID = 8710
...
@@ -53,7 +53,7 @@ CHAMELEON_SEP_TOKEN_ID = 8710
class
ChameleonImagePixelInputs
(
TypedDict
):
class
ChameleonImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
data
:
torch
.
Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
"""Shape: `(batch_size
* num_images
, num_channels, height, width)`"""
def
get_max_chameleon_image_tokens
(
ctx
:
InputContext
):
def
get_max_chameleon_image_tokens
(
ctx
:
InputContext
):
...
@@ -946,6 +946,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -946,6 +946,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
raise
ValueError
(
"Incorrect type of pixel values. "
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
f
"Got type:
{
type
(
pixel_values
)
}
"
)
# Remove the N dimension until multiple images are supported.
pixel_values
=
pixel_values
.
squeeze
(
1
)
return
ChameleonImagePixelInputs
(
return
ChameleonImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
data
=
self
.
_validate_pixel_values
(
pixel_values
),
...
...
vllm/model_executor/models/chatglm.py
View file @
0640f227
...
@@ -22,12 +22,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
...
@@ -22,12 +22,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
ChatGLMConfig
from
vllm.transformers_utils.configs
import
ChatGLMConfig
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
...
...
vllm/model_executor/models/clip.py
View file @
0640f227
"""Minimal implementation of CLIPVisionModel intended to be only used
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
within a vision language model."""
from
array
import
array
from
array
import
array
from
typing
import
Iterable
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
CLIPVisionConfig
from
transformers
import
CLIPVisionConfig
from
transformers.models.clip.modeling_clip
import
CLIPAttention
from
transformers.models.clip.modeling_clip
import
CLIP
Sdpa
Attention
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
LLMInputs
from
vllm.inputs
import
LLMInputs
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -20,6 +22,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
...
@@ -20,6 +22,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens
)
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
try
:
from
xformers
import
ops
as
xops
USE_XFORMERS_OPS
=
True
except
ImportError
:
USE_XFORMERS_OPS
=
False
def
get_clip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_clip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
assert
image_size
%
patch_size
==
0
assert
image_size
%
patch_size
==
0
...
@@ -84,7 +92,7 @@ def input_processor_for_clip(
...
@@ -84,7 +92,7 @@ def input_processor_for_clip(
llm_inputs
:
LLMInputs
,
llm_inputs
:
LLMInputs
,
*
,
*
,
image_token_id
:
int
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
image_feature_size_override
:
Optional
[
Union
[
int
,
List
[
int
]]
]
=
None
,
):
):
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
...
@@ -160,6 +168,78 @@ class CLIPVisionEmbeddings(nn.Module):
...
@@ -160,6 +168,78 @@ class CLIPVisionEmbeddings(nn.Module):
return
embeddings
return
embeddings
class CLIPParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    This variant builds its projections from ``QKVParallelLinear`` and
    ``RowParallelLinear`` and computes attention with
    ``xops.memory_efficient_attention_forward``, using
    ``num_heads / tensor-parallel world size`` heads per partition.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the attention layer.

        Args:
            config: Supplies ``hidden_size``, ``num_attention_heads`` and
                ``attention_dropout``.
            quant_config: Forwarded unchanged to both linear layers.

        Raises:
            ValueError: If ``hidden_size`` is not a multiple of
                ``num_attention_heads``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        # NOTE(review): `divide` presumably requires an exact division;
        # confirm against its definition in vllm.distributed.
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Return ``tensor`` as a contiguous
        ``(bsz, self.num_heads, seq_len, self.head_dim)`` tensor.

        This uses the total ``self.num_heads``, not
        ``self.num_heads_per_partition``, and ``forward`` does not call it.
        """
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Run self-attention over ``hidden_states``.

        Input shape: Batch x Time x Channel

        Returns:
            A ``(attn_output, None)`` tuple. The second slot is always
            ``None``.
        """
        bsz, tgt_len, _ = hidden_states.size()

        # The fused projection returns a tuple; only its first element is
        # used.
        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        # View each of q/k/v as
        # (batch, time, heads on this partition, head_dim).
        head_shape = (bsz, tgt_len, self.num_heads_per_partition,
                      self.head_dim)
        query_states = query_states.view(*head_shape)
        key_states = key_states.view(*head_shape)
        value_states = value_states.view(*head_shape)

        # Only apply attention dropout while the module is in training
        # mode. Passing the configured probability unconditionally would
        # make inference non-deterministic for any checkpoint whose config
        # sets a non-zero `attention_dropout`.
        dropout_p = self.dropout if self.training else 0.0

        out = xops.memory_efficient_attention_forward(query_states,
                                                      key_states,
                                                      value_states,
                                                      p=dropout_p,
                                                      scale=self.scale)
        # Merge the head and head_dim axes back into a single trailing axis.
        out = out.view(bsz, tgt_len, -1)
        attn_output, _ = self.out_proj(out)

        return attn_output, None
class
CLIPMLP
(
nn
.
Module
):
class
CLIPMLP
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -192,7 +272,13 @@ class CLIPEncoderLayer(nn.Module):
...
@@ -192,7 +272,13 @@ class CLIPEncoderLayer(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
self_attn
=
CLIPAttention
(
config
)
num_heads
=
config
.
num_attention_heads
tp_size
=
get_tensor_model_parallel_world_size
()
if
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
:
self
.
self_attn
=
CLIPParallelAttention
(
config
,
quant_config
=
quant_config
)
else
:
self
.
self_attn
=
CLIPSdpaAttention
(
config
)
self
.
layer_norm1
=
nn
.
LayerNorm
(
config
.
hidden_size
,
self
.
layer_norm1
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
CLIPMLP
(
config
,
quant_config
=
quant_config
)
self
.
mlp
=
CLIPMLP
(
config
,
quant_config
=
quant_config
)
...
@@ -217,7 +303,7 @@ class CLIPEncoderLayer(nn.Module):
...
@@ -217,7 +303,7 @@ class CLIPEncoderLayer(nn.Module):
class
CLIPEncoder
(
nn
.
Module
):
class
CLIPEncoder
(
nn
.
Module
):
"""
"""
Transformer encoder consisting of `config.num_hidden_layers` self
Transformer encoder consisting of `config.num_hidden_layers` self
attention layers. Each layer is a [`CLIPEncoderLayer`].
attention layers. Each layer is a [`CLIPEncoderLayer`].
Args:
Args:
...
@@ -291,6 +377,10 @@ class CLIPVisionModel(nn.Module):
...
@@ -291,6 +377,10 @@ class CLIPVisionModel(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
num_hidden_layers_override
:
Optional
[
int
]
=
None
):
num_hidden_layers_override
:
Optional
[
int
]
=
None
):
super
().
__init__
()
super
().
__init__
()
tp_size
=
get_tensor_model_parallel_world_size
()
num_heads
=
config
.
num_attention_heads
self
.
shard_weight
=
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
self
.
vision_model
=
CLIPVisionTransformer
(
self
.
vision_model
=
CLIPVisionTransformer
(
config
=
config
,
config
=
config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
...
@@ -304,7 +394,15 @@ class CLIPVisionModel(nn.Module):
...
@@ -304,7 +394,15 @@ class CLIPVisionModel(nn.Module):
def
device
(
self
):
def
device
(
self
):
return
next
(
self
.
parameters
()).
device
return
next
(
self
.
parameters
()).
device
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
if
self
.
shard_weight
else
[]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
layer_count
=
len
(
self
.
vision_model
.
encoder
.
layers
)
layer_count
=
len
(
self
.
vision_model
.
encoder
.
layers
)
...
@@ -318,7 +416,16 @@ class CLIPVisionModel(nn.Module):
...
@@ -318,7 +416,16 @@ class CLIPVisionModel(nn.Module):
if
layer_idx
>=
layer_count
:
if
layer_idx
>=
layer_count
:
continue
continue
param
=
params_dict
[
name
]
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
weight_loader
=
getattr
(
param
,
"weight_loader"
,
if
weight_name
not
in
name
:
default_weight_loader
)
continue
weight_loader
(
param
,
loaded_weight
)
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/commandr.py
View file @
0640f227
...
@@ -38,14 +38,14 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
...
@@ -38,14 +38,14 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
row_parallel_weight_loader
)
default_weight_loader
,
row_parallel_weight_loader
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
@
torch
.
compile
@
torch
.
compile
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment