Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c2395367
Unverified
Commit
c2395367
authored
Sep 25, 2024
by
Isotr0py
Committed by
GitHub
Sep 24, 2024
Browse files
[Hardware][CPU] Enable mrope and support Qwen2-VL on CPU backend (#8770)
parent
e3dd0692
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
99 additions
and
9 deletions
+99
-9
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+16
-0
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+83
-9
No files found.
vllm/model_executor/models/qwen2_vl.py
View file @
c2395367
...
...
@@ -67,6 +67,7 @@ from vllm.multimodal.image import cached_get_image_processor
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
vllm.transformers_utils.processor
import
get_processor
from
vllm.utils
import
is_cpu
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
)
...
...
@@ -281,6 +282,21 @@ class Qwen2VisionAttention(nn.Module):
context_layer
=
rearrange
(
output
,
"(b s) ... -> b s ..."
,
b
=
batch_size
)
elif
is_cpu
():
seq_length
=
q
.
size
(
1
)
q
,
k
,
v
=
[
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
[
q
,
k
,
v
]]
attention_mask
=
torch
.
zeros
([
1
,
seq_length
,
seq_length
],
device
=
q
.
device
,
dtype
=
torch
.
bool
)
for
i
in
range
(
1
,
len
(
cu_seqlens
)):
attention_mask
[...,
cu_seqlens
[
i
-
1
]:
cu_seqlens
[
i
],
cu_seqlens
[
i
-
1
]:
cu_seqlens
[
i
]]
=
True
output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
attention_mask
,
dropout_p
=
0.0
)
context_layer
=
rearrange
(
output
,
"b h s d -> b s h d "
)
else
:
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
...
...
vllm/worker/cpu_model_runner.py
View file @
c2395367
...
...
@@ -12,11 +12,13 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalInputs
)
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.sequence
import
(
IntermediateTensors
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
STR_NOT_IMPL_ENC_DEC_ERR_STRS
,
make_tensor_with_pad
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
...
...
@@ -145,6 +147,38 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
query_lens
=
seq_lens
,
)
def
_compute_multi_modal_input
(
self
,
seq_data
:
SequenceData
,
mm_data
,
computed_len
:
int
):
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
# special processing for mrope position deltas.
mrope_positions
=
None
if
self
.
runner
.
model_is_mrope
:
image_grid_thw
=
mm_kwargs
.
get
(
"image_grid_thw"
,
None
)
video_grid_thw
=
mm_kwargs
.
get
(
"video_grid_thw"
,
None
)
assert
image_grid_thw
is
not
None
or
video_grid_thw
is
not
None
,
(
"mrope embedding type requires multi-modal input mapper "
"returns 'image_grid_thw' or 'video_grid_thw'."
)
hf_config
=
self
.
runner
.
model_config
.
hf_config
token_ids
=
seq_data
.
get_token_ids
()
mrope_positions
,
mrope_position_delta
=
\
MRotaryEmbedding
.
get_input_positions
(
token_ids
,
image_grid_thw
=
image_grid_thw
,
video_grid_thw
=
video_grid_thw
,
image_token_id
=
hf_config
.
image_token_id
,
video_token_id
=
hf_config
.
video_token_id
,
vision_start_token_id
=
hf_config
.
vision_start_token_id
,
vision_end_token_id
=
hf_config
.
vision_end_token_id
,
spatial_merge_size
=
hf_config
.
vision_config
.
spatial_merge_size
,
context_len
=
computed_len
,
)
seq_data
.
mrope_position_delta
=
mrope_position_delta
return
mm_kwargs
,
mrope_positions
def
_prepare_prompt
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
...
...
@@ -153,6 +187,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
input_mrope_positions
:
List
[
List
[
int
]]
=
[[]
for
_
in
range
(
3
)]
slot_mapping
:
List
[
int
]
=
[]
seq_lens
:
List
[
int
]
=
[]
multi_modal_inputs_list
:
List
[
MultiModalInputs
]
=
[]
...
...
@@ -171,15 +207,21 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
seq_lens
.
append
(
seq_len
)
# Prompt token num
input_tokens
.
extend
(
prompt_tokens
)
# Token ids
mrope_positions
=
None
if
(
mm_data
:
=
seq_group_metadata
.
multi_modal_data
):
mm_kwargs
,
mrope_positions
=
self
.
_compute_multi_modal_input
(
seq_data
,
mm_data
,
computed_len
)
multi_modal_inputs_list
.
append
(
mm_kwargs
)
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
if
mrope_positions
:
for
idx
in
range
(
3
):
input_mrope_positions
[
idx
].
extend
(
mrope_positions
[
idx
])
else
:
input_positions
.
extend
(
list
(
range
(
computed_len
,
seq_len
)))
if
(
mm_data
:
=
seq_group_metadata
.
multi_modal_data
):
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
multi_modal_inputs_list
.
append
(
mm_kwargs
)
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
...
...
@@ -202,12 +244,18 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
if
any
(
input_mrope_positions
):
input_positions
=
None
# type: ignore
else
:
input_mrope_positions
=
None
# type: ignore
num_prompt_tokens
=
len
(
input_tokens
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
# type: ignore
input_positions
=
torch
.
tensor
(
input_positions
,
input_positions
=
torch
.
tensor
(
input_positions
or
input_mrope_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
# type: ignore
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
...
...
@@ -238,6 +286,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
input_mrope_positions
:
List
[
List
[
int
]]
=
[[]
for
_
in
range
(
3
)]
slot_mapping
:
List
[
int
]
=
[]
seq_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
...
...
@@ -255,6 +304,16 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
seq_len
=
seq_data
.
get_len
()
position
=
seq_len
-
1
if
seq_data
.
mrope_position_delta
is
not
None
:
context_len
=
seq_data
.
get_num_computed_tokens
()
next_pos
=
MRotaryEmbedding
.
get_next_input_positions
(
seq_data
.
mrope_position_delta
,
context_len
,
seq_len
,
)
for
idx
in
range
(
3
):
input_mrope_positions
[
idx
].
extend
(
next_pos
[
idx
])
else
:
input_positions
.
append
(
position
)
seq_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
...
...
@@ -273,12 +332,18 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_tables
.
append
(
block_table
)
if
any
(
input_mrope_positions
):
input_positions
=
None
# type: ignore
else
:
input_mrope_positions
=
None
# type: ignore
max_decode_seq_len
=
max
(
seq_lens
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
torch
.
tensor
(
input_positions
,
input_positions
=
torch
.
tensor
(
input_positions
or
input_mrope_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
...
...
@@ -373,6 +438,15 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
raise
NotImplementedError
(
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_CPU'
])
@
property
def
model_is_mrope
(
self
)
->
bool
:
"""Detect if the model has "mrope" rope_scaling type.
mrope requires keep "rope_deltas" between prompt and decoding phases."""
rope_scaling
=
getattr
(
self
.
model_config
.
hf_config
,
"rope_scaling"
,
{})
if
rope_scaling
is
None
:
return
False
return
rope_scaling
.
get
(
"type"
,
None
)
==
"mrope"
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
load_config
=
self
.
load_config
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment