Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e09ce759
Unverified
Commit
e09ce759
authored
Jul 17, 2024
by
Woosuk Kwon
Committed by
GitHub
Jul 17, 2024
Browse files
[TPU] Remove multi-modal args in TPU backend (#6504)
parent
5fa6e987
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
40 deletions
+6
-40
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+6
-40
No files found.
vllm/worker/tpu_model_runner.py
View file @
e09ce759
import
time
import
time
from
typing
import
List
,
Mapping
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -12,8 +12,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
...
@@ -12,8 +12,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensors
,
MultiModalInputs
)
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceOutput
)
SequenceOutput
)
...
@@ -68,10 +66,6 @@ class TPUModelRunner:
...
@@ -68,10 +66,6 @@ class TPUModelRunner:
False
,
False
,
)
)
# Multi-modal data support
self
.
multi_modal_input_mapper
=
MULTIMODAL_REGISTRY
\
.
create_input_mapper
(
self
.
model_config
)
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
self
.
device
=
self
.
device_config
.
device
self
.
device
=
self
.
device_config
.
device
...
@@ -154,7 +148,7 @@ class TPUModelRunner:
...
@@ -154,7 +148,7 @@ class TPUModelRunner:
# Dummy run.
# Dummy run.
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
self
.
model
(
token_ids
,
position_ids
,
kv_caches
,
attn_metadata
,
self
.
model
(
token_ids
,
position_ids
,
kv_caches
,
attn_metadata
,
input_lens
,
None
,
t
,
p
,
num_samples
)
input_lens
,
t
,
p
,
num_samples
)
def
warmup_model
(
def
warmup_model
(
self
,
self
,
...
@@ -199,14 +193,12 @@ class TPUModelRunner:
...
@@ -199,14 +193,12 @@ class TPUModelRunner:
def
_prepare_prompt
(
def
_prepare_prompt
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
torch
.
Tensor
]:
Mapping
[
str
,
BatchedTensors
]]:
assert
len
(
seq_group_metadata_list
)
>
0
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
List
[
int
]]
=
[]
input_tokens
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
prompt_lens
:
List
[
int
]
=
[]
prompt_lens
:
List
[
int
]
=
[]
slot_mapping
:
List
[
List
[
int
]]
=
[]
slot_mapping
:
List
[
List
[
int
]]
=
[]
multi_modal_inputs_list
:
List
[
MultiModalInputs
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
assert
seq_group_metadata
.
is_prompt
...
@@ -232,11 +224,6 @@ class TPUModelRunner:
...
@@ -232,11 +224,6 @@ class TPUModelRunner:
slot
=
block_number
*
self
.
block_size
+
block_offset
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
[
-
1
].
append
(
slot
)
slot_mapping
[
-
1
].
append
(
slot
)
mm_data
=
seq_group_metadata
.
multi_modal_data
if
mm_data
:
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
multi_modal_inputs_list
.
append
(
mm_kwargs
)
assert
len
(
prompt_lens
)
>
0
assert
len
(
prompt_lens
)
>
0
num_prefills
=
len
(
prompt_lens
)
num_prefills
=
len
(
prompt_lens
)
num_prefill_tokens
=
sum
(
prompt_lens
)
num_prefill_tokens
=
sum
(
prompt_lens
)
...
@@ -274,24 +261,17 @@ class TPUModelRunner:
...
@@ -274,24 +261,17 @@ class TPUModelRunner:
block_tables
=
None
,
block_tables
=
None
,
context_lens
=
None
,
context_lens
=
None
,
)
)
return
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
multi_modal_kwargs
=
MultiModalInputs
.
batch
(
multi_modal_inputs_list
,
device
=
self
.
device
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
,
multi_modal_kwargs
)
def
_prepare_decode
(
def
_prepare_decode
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
torch
.
Tensor
]:
Mapping
[
str
,
BatchedTensors
]]:
assert
len
(
seq_group_metadata_list
)
>
0
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
List
[
int
]]
=
[]
input_tokens
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
slot_mapping
:
List
[
List
[
int
]]
=
[]
slot_mapping
:
List
[
List
[
int
]]
=
[]
context_lens
:
List
[
int
]
=
[]
context_lens
:
List
[
int
]
=
[]
multi_modal_inputs_list
:
List
[
MultiModalInputs
]
=
[]
batch_idx
=
0
batch_idx
=
0
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
...
@@ -317,11 +297,6 @@ class TPUModelRunner:
...
@@ -317,11 +297,6 @@ class TPUModelRunner:
slot
=
block_number
*
self
.
block_size
+
block_offset
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
([
slot
])
slot_mapping
.
append
([
slot
])
mm_data
=
seq_group_metadata
.
multi_modal_data
if
mm_data
:
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
multi_modal_inputs_list
.
append
(
mm_kwargs
)
batch_size
=
_get_padded_batch_size
(
batch_idx
)
batch_size
=
_get_padded_batch_size
(
batch_idx
)
num_paddings
=
batch_size
-
batch_idx
num_paddings
=
batch_size
-
batch_idx
input_tokens
=
input_tokens
+
[[
0
]]
*
num_paddings
input_tokens
=
input_tokens
+
[[
0
]]
*
num_paddings
...
@@ -355,12 +330,7 @@ class TPUModelRunner:
...
@@ -355,12 +330,7 @@ class TPUModelRunner:
block_tables
=
block_tables
,
block_tables
=
block_tables
,
context_lens
=
context_lens
,
context_lens
=
context_lens
,
)
)
return
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
multi_modal_kwargs
=
MultiModalInputs
.
batch
(
multi_modal_inputs_list
,
device
=
self
.
device
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
,
multi_modal_kwargs
)
def
_prepare_sample
(
def
_prepare_sample
(
self
,
self
,
...
@@ -513,7 +483,6 @@ class ModelWrapper(nn.Module):
...
@@ -513,7 +483,6 @@ class ModelWrapper(nn.Module):
kv_caches
:
List
[
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]],
kv_caches
:
List
[
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
input_lens
:
torch
.
Tensor
,
input_lens
:
torch
.
Tensor
,
multi_modal_kwargs
:
Optional
[
Mapping
[
str
,
BatchedTensors
]],
t
:
torch
.
Tensor
,
t
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
num_samples
:
int
,
num_samples
:
int
,
...
@@ -527,8 +496,6 @@ class ModelWrapper(nn.Module):
...
@@ -527,8 +496,6 @@ class ModelWrapper(nn.Module):
memory profiling at initialization.
memory profiling at initialization.
attn_metadata: The Pallas attention metadata.
attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size].
input_lens: The actual input lengths of shape [batch_size].
multi_modal_kwargs: Keyword arguments from multi-modal data to
pass to the model.
t: The sampling temperature of shape [batch_size].
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
p: The top-p probability of shape [batch_size].
"""
"""
...
@@ -573,7 +540,6 @@ class ModelWrapper(nn.Module):
...
@@ -573,7 +540,6 @@ class ModelWrapper(nn.Module):
position_ids
,
position_ids
,
kv_caches
,
kv_caches
,
attn_metadata
,
attn_metadata
,
**
(
multi_modal_kwargs
or
{}),
)
)
hidden_states
=
hidden_states
.
flatten
(
0
,
1
)
hidden_states
=
hidden_states
.
flatten
(
0
,
1
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment