Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
63575bc2
"tests/vscode:/vscode.git/clone" did not exist on "46cdd59577978f893dbf9c733cacd920011fc7fd"
Unverified
Commit
63575bc2
authored
May 06, 2024
by
youkaichao
Committed by
GitHub
May 06, 2024
Browse files
[Core][Optimization] change python dict to pytorch tensor (#4607)
parent
a98187cf
Changes
19
Show whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
77 additions
and
81 deletions
+77
-81
csrc/cache.h
csrc/cache.h
+1
-1
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+5
-15
csrc/cpu/cache.cpp
csrc/cpu/cache.cpp
+6
-14
tests/core/test_scheduler.py
tests/core/test_scheduler.py
+4
-4
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+12
-9
tests/worker/test_swap.py
tests/worker/test_swap.py
+1
-1
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+1
-1
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+1
-1
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+1
-1
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+1
-1
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+1
-1
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+1
-1
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+1
-1
vllm/core/scheduler.py
vllm/core/scheduler.py
+20
-21
vllm/distributed/communication_op.py
vllm/distributed/communication_op.py
+7
-0
vllm/sequence.py
vllm/sequence.py
+3
-3
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+1
-1
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+5
-2
vllm/worker/worker.py
vllm/worker/worker.py
+5
-3
No files found.
csrc/cache.h
View file @
63575bc2
...
@@ -13,7 +13,7 @@ void swap_blocks(
...
@@ -13,7 +13,7 @@ void swap_blocks(
void
copy_blocks
(
void
copy_blocks
(
std
::
vector
<
torch
::
Tensor
>&
key_caches
,
std
::
vector
<
torch
::
Tensor
>&
key_caches
,
std
::
vector
<
torch
::
Tensor
>&
value_caches
,
std
::
vector
<
torch
::
Tensor
>&
value_caches
,
const
std
::
map
<
int64_t
,
std
::
vector
<
int64_t
>>
&
block_mapping
);
torch
::
Tensor
&
block_mapping
);
void
reshape_and_cache
(
void
reshape_and_cache
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
key
,
...
...
csrc/cache_kernels.cu
View file @
63575bc2
...
@@ -97,7 +97,7 @@ __global__ void copy_blocks_kernel(
...
@@ -97,7 +97,7 @@ __global__ void copy_blocks_kernel(
void
copy_blocks
(
void
copy_blocks
(
std
::
vector
<
torch
::
Tensor
>&
key_caches
,
std
::
vector
<
torch
::
Tensor
>&
key_caches
,
std
::
vector
<
torch
::
Tensor
>&
value_caches
,
std
::
vector
<
torch
::
Tensor
>&
value_caches
,
const
std
::
map
<
int64_t
,
std
::
vector
<
int64_t
>>
&
block_mapping
)
{
torch
::
Tensor
&
block_mapping
)
{
int
num_layers
=
key_caches
.
size
();
int
num_layers
=
key_caches
.
size
();
TORCH_CHECK
(
num_layers
==
value_caches
.
size
());
TORCH_CHECK
(
num_layers
==
value_caches
.
size
());
if
(
num_layers
==
0
)
{
if
(
num_layers
==
0
)
{
...
@@ -114,17 +114,9 @@ void copy_blocks(
...
@@ -114,17 +114,9 @@ void copy_blocks(
key_cache_ptrs
[
layer_idx
]
=
reinterpret_cast
<
int64_t
>
(
key_caches
[
layer_idx
].
data_ptr
());
key_cache_ptrs
[
layer_idx
]
=
reinterpret_cast
<
int64_t
>
(
key_caches
[
layer_idx
].
data_ptr
());
value_cache_ptrs
[
layer_idx
]
=
reinterpret_cast
<
int64_t
>
(
value_caches
[
layer_idx
].
data_ptr
());
value_cache_ptrs
[
layer_idx
]
=
reinterpret_cast
<
int64_t
>
(
value_caches
[
layer_idx
].
data_ptr
());
}
}
// Create block mapping array.
std
::
vector
<
int64_t
>
block_mapping_vec
;
// block_mapping is a 2D tensor with shape (num_pairs, 2).
for
(
const
auto
&
pair
:
block_mapping
)
{
int
num_pairs
=
block_mapping
.
size
(
0
);
int64_t
src_block_number
=
pair
.
first
;
for
(
int64_t
dst_block_number
:
pair
.
second
)
{
block_mapping_vec
.
push_back
(
src_block_number
);
block_mapping_vec
.
push_back
(
dst_block_number
);
}
}
int64_t
*
block_mapping_array
=
block_mapping_vec
.
data
();
int
num_pairs
=
block_mapping_vec
.
size
()
/
2
;
// Move the data structures to the GPU.
// Move the data structures to the GPU.
// NOTE: This synchronizes the CPU and GPU.
// NOTE: This synchronizes the CPU and GPU.
...
@@ -132,8 +124,6 @@ void copy_blocks(
...
@@ -132,8 +124,6 @@ void copy_blocks(
key_cache_ptrs
,
{
num_layers
},
torch
::
kInt64
).
to
(
cache_device
);
key_cache_ptrs
,
{
num_layers
},
torch
::
kInt64
).
to
(
cache_device
);
torch
::
Tensor
value_cache_ptrs_tensor
=
torch
::
from_blob
(
torch
::
Tensor
value_cache_ptrs_tensor
=
torch
::
from_blob
(
value_cache_ptrs
,
{
num_layers
},
torch
::
kInt64
).
to
(
cache_device
);
value_cache_ptrs
,
{
num_layers
},
torch
::
kInt64
).
to
(
cache_device
);
torch
::
Tensor
block_mapping_tensor
=
torch
::
from_blob
(
block_mapping_array
,
{
2
*
num_pairs
},
torch
::
kInt64
).
to
(
cache_device
);
// Launch the kernel.
// Launch the kernel.
const
int
numel_per_block
=
key_caches
[
0
][
0
].
numel
();
const
int
numel_per_block
=
key_caches
[
0
][
0
].
numel
();
...
@@ -146,7 +136,7 @@ void copy_blocks(
...
@@ -146,7 +136,7 @@ void copy_blocks(
vllm
::
copy_blocks_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
vllm
::
copy_blocks_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
key_cache_ptrs_tensor
.
data_ptr
<
int64_t
>
(),
key_cache_ptrs_tensor
.
data_ptr
<
int64_t
>
(),
value_cache_ptrs_tensor
.
data_ptr
<
int64_t
>
(),
value_cache_ptrs_tensor
.
data_ptr
<
int64_t
>
(),
block_mapping
_tensor
.
data_ptr
<
int64_t
>
(),
block_mapping
.
data_ptr
<
int64_t
>
(),
numel_per_block
);
numel_per_block
);
}));
}));
}
}
...
...
csrc/cpu/cache.cpp
View file @
63575bc2
...
@@ -8,16 +8,16 @@ template <typename scalar_t>
...
@@ -8,16 +8,16 @@ template <typename scalar_t>
void
copy_blocks_cpu_impl
(
void
copy_blocks_cpu_impl
(
std
::
vector
<
torch
::
Tensor
>
&
key_caches
,
std
::
vector
<
torch
::
Tensor
>
&
key_caches
,
std
::
vector
<
torch
::
Tensor
>
&
value_caches
,
std
::
vector
<
torch
::
Tensor
>
&
value_caches
,
const
std
::
vector
<
std
::
pair
<
int64_t
,
int64_t
>>
mapping_pairs
,
const
torch
::
Tensor
&
mapping_pairs
,
const
int
element_num_per_block
,
const
int
layer_num
)
{
const
int
element_num_per_block
,
const
int
layer_num
)
{
const
size_t
pair_num
=
mapping_pairs
.
size
();
const
size_t
pair_num
=
mapping_pairs
.
size
(
0
);
const
size_t
block_bytes
=
sizeof
(
scalar_t
)
*
element_num_per_block
;
const
size_t
block_bytes
=
sizeof
(
scalar_t
)
*
element_num_per_block
;
#pragma omp parallel for collapse(2)
#pragma omp parallel for collapse(2)
for
(
int
layer
=
0
;
layer
<
layer_num
;
++
layer
)
{
for
(
int
layer
=
0
;
layer
<
layer_num
;
++
layer
)
{
for
(
size_t
pair
=
0
;
pair
<
pair_num
;
++
pair
)
{
for
(
size_t
pair
=
0
;
pair
<
pair_num
;
++
pair
)
{
int64_t
source_offset
=
element_num_per_block
*
mapping_pairs
[
pair
]
.
first
;
int64_t
source_offset
=
element_num_per_block
*
mapping_pairs
[
pair
]
[
0
].
item
<
int64_t
>
()
;
int64_t
target_offset
=
int64_t
target_offset
=
element_num_per_block
*
mapping_pairs
[
pair
]
.
second
;
element_num_per_block
*
mapping_pairs
[
pair
]
[
1
].
item
<
int64_t
>
()
;
scalar_t
*
key_cache_ptr
=
key_caches
[
layer
].
data_ptr
<
scalar_t
>
();
scalar_t
*
key_cache_ptr
=
key_caches
[
layer
].
data_ptr
<
scalar_t
>
();
scalar_t
*
source_ptr
=
key_cache_ptr
+
source_offset
;
scalar_t
*
source_ptr
=
key_cache_ptr
+
source_offset
;
scalar_t
*
target_ptr
=
key_cache_ptr
+
target_offset
;
scalar_t
*
target_ptr
=
key_cache_ptr
+
target_offset
;
...
@@ -83,26 +83,18 @@ void reshape_and_cache_cpu_impl(
...
@@ -83,26 +83,18 @@ void reshape_and_cache_cpu_impl(
void
copy_blocks
(
std
::
vector
<
torch
::
Tensor
>
&
key_caches
,
void
copy_blocks
(
std
::
vector
<
torch
::
Tensor
>
&
key_caches
,
std
::
vector
<
torch
::
Tensor
>
&
value_caches
,
std
::
vector
<
torch
::
Tensor
>
&
value_caches
,
const
std
::
map
<
int64_t
,
std
::
vector
<
int64_t
>>
&
block_mapping
)
{
torch
::
Tensor
&
block_mapping
)
{
int
num_layers
=
key_caches
.
size
();
int
num_layers
=
key_caches
.
size
();
TORCH_CHECK
(
num_layers
==
value_caches
.
size
());
TORCH_CHECK
(
num_layers
==
value_caches
.
size
());
if
(
num_layers
==
0
)
{
if
(
num_layers
==
0
)
{
return
;
return
;
}
}
std
::
vector
<
std
::
pair
<
int64_t
,
int64_t
>>
mapping_pairs
;
mapping_pairs
.
reserve
(
block_mapping
.
size
());
for
(
const
auto
&
pair
:
block_mapping
)
{
for
(
const
auto
&
dst
:
pair
.
second
)
{
mapping_pairs
.
emplace_back
(
pair
.
first
,
dst
);
}
}
const
int
element_num_per_block
=
key_caches
[
0
][
0
].
numel
();
const
int
element_num_per_block
=
key_caches
[
0
][
0
].
numel
();
VLLM_DISPATCH_FLOATING_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
key_caches
[
0
].
scalar_type
(),
"copy_blocks_cpu_impl"
,
[
&
]
{
key_caches
[
0
].
scalar_type
(),
"copy_blocks_cpu_impl"
,
[
&
]
{
CPU_KERNEL_GUARD_IN
(
copy_blocks_cpu_impl
)
CPU_KERNEL_GUARD_IN
(
copy_blocks_cpu_impl
)
copy_blocks_cpu_impl
<
scalar_t
>
(
key_caches
,
value_caches
,
mapping
_pairs
,
copy_blocks_cpu_impl
<
scalar_t
>
(
key_caches
,
value_caches
,
block_
mapping
,
element_num_per_block
,
num_layers
);
element_num_per_block
,
num_layers
);
CPU_KERNEL_GUARD_OUT
(
copy_blocks_cpu_impl
)
CPU_KERNEL_GUARD_OUT
(
copy_blocks_cpu_impl
)
});
});
...
...
tests/core/test_scheduler.py
View file @
63575bc2
...
@@ -568,7 +568,7 @@ def test_decode_schedule_preempted():
...
@@ -568,7 +568,7 @@ def test_decode_schedule_preempted():
# Both should be preempted, not swapped.
# Both should be preempted, not swapped.
assert
output
.
blocks_to_swap_out
==
{}
assert
output
.
blocks_to_swap_out
==
{}
# Nothing is copied.
# Nothing is copied.
assert
output
.
blocks_to_copy
==
{}
assert
output
.
blocks_to_copy
==
[]
def
test_decode_swap_beam_search
():
def
test_decode_swap_beam_search
():
...
@@ -618,7 +618,7 @@ def test_decode_swap_beam_search():
...
@@ -618,7 +618,7 @@ def test_decode_swap_beam_search():
# Both should be preempted, not swapped.
# Both should be preempted, not swapped.
assert
output
.
blocks_to_swap_out
==
expected_swap_mapping
assert
output
.
blocks_to_swap_out
==
expected_swap_mapping
# Nothing is copied.
# Nothing is copied.
assert
output
.
blocks_to_copy
==
{}
assert
output
.
blocks_to_copy
==
[]
def
test_schedule_decode_blocks_to_copy_update
():
def
test_schedule_decode_blocks_to_copy_update
():
...
@@ -650,7 +650,7 @@ def test_schedule_decode_blocks_to_copy_update():
...
@@ -650,7 +650,7 @@ def test_schedule_decode_blocks_to_copy_update():
assert
output
.
blocks_to_swap_out
==
{}
assert
output
.
blocks_to_swap_out
==
{}
# Since append_slot returns the source -> dist mapping, it should
# Since append_slot returns the source -> dist mapping, it should
# applied.
# applied.
assert
output
.
blocks_to_copy
==
{
2
:
[
3
]}
assert
output
.
blocks_to_copy
==
[(
2
,
3
)]
def
test_schedule_swapped_simple
():
def
test_schedule_swapped_simple
():
...
@@ -853,7 +853,7 @@ def test_schedule_swapped_blocks_to_copy():
...
@@ -853,7 +853,7 @@ def test_schedule_swapped_blocks_to_copy():
assert
len
(
remaining_swapped
)
==
0
assert
len
(
remaining_swapped
)
==
0
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
prefill_seq_groups
)
==
0
assert
len
(
output
.
prefill_seq_groups
)
==
0
assert
output
.
blocks_to_copy
==
{
2
:
[
3
]}
assert
output
.
blocks_to_copy
==
[(
2
,
3
)]
def
test_scheduling_budget
():
def
test_scheduling_budget
():
...
...
tests/kernels/test_cache.py
View file @
63575bc2
...
@@ -63,12 +63,13 @@ def test_copy_blocks(
...
@@ -63,12 +63,13 @@ def test_copy_blocks(
src_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
src_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
remainig_blocks
=
list
(
set
(
range
(
num_blocks
))
-
set
(
src_blocks
))
remainig_blocks
=
list
(
set
(
range
(
num_blocks
))
-
set
(
src_blocks
))
dst_blocks
=
random
.
sample
(
remainig_blocks
,
2
*
num_mappings
)
dst_blocks
=
random
.
sample
(
remainig_blocks
,
2
*
num_mappings
)
block_mapping
=
{}
block_mapping
=
[]
for
i
in
range
(
num_mappings
):
for
i
in
range
(
num_mappings
):
src
=
src_blocks
[
i
]
src
=
src_blocks
[
i
]
dst1
=
dst_blocks
[
2
*
i
]
dst1
=
dst_blocks
[
2
*
i
]
dst2
=
dst_blocks
[
2
*
i
+
1
]
dst2
=
dst_blocks
[
2
*
i
+
1
]
block_mapping
[
src
]
=
[
dst1
,
dst2
]
block_mapping
.
append
((
src
,
dst1
))
block_mapping
.
append
((
src
,
dst2
))
# Create the KV caches.
# Create the KV caches.
key_caches
,
value_caches
=
kv_cache_factory
(
num_blocks
,
block_size
,
key_caches
,
value_caches
=
kv_cache_factory
(
num_blocks
,
block_size
,
...
@@ -81,11 +82,13 @@ def test_copy_blocks(
...
@@ -81,11 +82,13 @@ def test_copy_blocks(
cloned_value_caches
=
[
value_cache
.
clone
()
for
value_cache
in
value_caches
]
cloned_value_caches
=
[
value_cache
.
clone
()
for
value_cache
in
value_caches
]
# Call the copy blocks kernel.
# Call the copy blocks kernel.
ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping
)
block_mapping_tensor
=
torch
.
tensor
(
block_mapping
,
dtype
=
torch
.
int64
,
device
=
device
).
view
(
-
1
,
2
)
ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping_tensor
)
# Run the reference implementation.
# Run the reference implementation.
for
src
,
dsts
in
block_mapping
.
items
():
for
src
,
dst
in
block_mapping
:
for
dst
in
dsts
:
for
cloned_key_cache
in
cloned_key_caches
:
for
cloned_key_cache
in
cloned_key_caches
:
cloned_key_cache
[
dst
].
copy_
(
cloned_key_cache
[
src
])
cloned_key_cache
[
dst
].
copy_
(
cloned_key_cache
[
src
])
for
cloned_value_cache
in
cloned_value_caches
:
for
cloned_value_cache
in
cloned_value_caches
:
...
...
tests/worker/test_swap.py
View file @
63575bc2
...
@@ -59,7 +59,7 @@ def test_swap() -> None:
...
@@ -59,7 +59,7 @@ def test_swap() -> None:
seq_group_metadata_list
=
[],
seq_group_metadata_list
=
[],
blocks_to_swap_in
=
{},
blocks_to_swap_in
=
{},
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
{}
,
blocks_to_copy
=
[]
,
)
)
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
...
...
vllm/attention/backends/abstract.py
View file @
63575bc2
...
@@ -42,7 +42,7 @@ class AttentionBackend(ABC):
...
@@ -42,7 +42,7 @@ class AttentionBackend(ABC):
@
abstractmethod
@
abstractmethod
def
copy_blocks
(
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]]
,
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/attention/backends/flash_attn.py
View file @
63575bc2
...
@@ -48,7 +48,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -48,7 +48,7 @@ class FlashAttentionBackend(AttentionBackend):
@
staticmethod
@
staticmethod
def
copy_blocks
(
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]]
,
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
...
...
vllm/attention/backends/flashinfer.py
View file @
63575bc2
...
@@ -48,7 +48,7 @@ class FlashInferBackend(AttentionBackend):
...
@@ -48,7 +48,7 @@ class FlashInferBackend(AttentionBackend):
@
staticmethod
@
staticmethod
def
copy_blocks
(
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]]
,
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
63575bc2
...
@@ -46,7 +46,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
...
@@ -46,7 +46,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
@
staticmethod
@
staticmethod
def
copy_blocks
(
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]]
,
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
...
...
vllm/attention/backends/torch_sdpa.py
View file @
63575bc2
...
@@ -44,7 +44,7 @@ class TorchSDPABackend(AttentionBackend):
...
@@ -44,7 +44,7 @@ class TorchSDPABackend(AttentionBackend):
@
staticmethod
@
staticmethod
def
copy_blocks
(
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]]
,
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
...
...
vllm/attention/backends/xformers.py
View file @
63575bc2
...
@@ -49,7 +49,7 @@ class XFormersBackend(AttentionBackend):
...
@@ -49,7 +49,7 @@ class XFormersBackend(AttentionBackend):
@
staticmethod
@
staticmethod
def
copy_blocks
(
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]]
,
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
...
...
vllm/attention/ops/paged_attn.py
View file @
63575bc2
...
@@ -209,7 +209,7 @@ class PagedAttention:
...
@@ -209,7 +209,7 @@ class PagedAttention:
@
staticmethod
@
staticmethod
def
copy_blocks
(
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]]
,
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
...
...
vllm/core/scheduler.py
View file @
63575bc2
...
@@ -13,7 +13,6 @@ from vllm.logger import init_logger
...
@@ -13,7 +13,6 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
(
Sequence
,
SequenceData
,
SequenceGroup
,
from
vllm.sequence
import
(
Sequence
,
SequenceData
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
SequenceGroupMetadata
,
SequenceStatus
)
from
vllm.utils
import
merge_dicts
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -122,8 +121,8 @@ class SchedulerOutputs:
...
@@ -122,8 +121,8 @@ class SchedulerOutputs:
blocks_to_swap_in
:
Dict
[
int
,
int
]
blocks_to_swap_in
:
Dict
[
int
,
int
]
# Blocks to swap out. Dict of GPU -> CPU block number.
# Blocks to swap out. Dict of GPU -> CPU block number.
blocks_to_swap_out
:
Dict
[
int
,
int
]
blocks_to_swap_out
:
Dict
[
int
,
int
]
# Blocks to copy. Source to
a list of
dest block
s
.
# Blocks to copy. Source to dest block.
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
# Sequence groups that are going to be ignored.
# Sequence groups that are going to be ignored.
ignored_seq_groups
:
List
[
SequenceGroup
]
ignored_seq_groups
:
List
[
SequenceGroup
]
# The number of slots for lookahead decoding.
# The number of slots for lookahead decoding.
...
@@ -177,7 +176,7 @@ class SchedulerRunningOutputs:
...
@@ -177,7 +176,7 @@ class SchedulerRunningOutputs:
# The blocks to swap out.
# The blocks to swap out.
blocks_to_swap_out
:
Dict
[
int
,
int
]
blocks_to_swap_out
:
Dict
[
int
,
int
]
# The blocks to copy.
# The blocks to copy.
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
# The number of slots for lookahead decoding.
# The number of slots for lookahead decoding.
num_lookahead_slots
:
int
num_lookahead_slots
:
int
...
@@ -189,7 +188,7 @@ class SchedulerRunningOutputs:
...
@@ -189,7 +188,7 @@ class SchedulerRunningOutputs:
preempted
=
[],
preempted
=
[],
swapped_out
=
[],
swapped_out
=
[],
blocks_to_swap_out
=
{},
blocks_to_swap_out
=
{},
blocks_to_copy
=
{}
,
blocks_to_copy
=
[]
,
num_lookahead_slots
=
0
,
num_lookahead_slots
=
0
,
)
)
...
@@ -209,7 +208,7 @@ class SchedulerSwappedInOutputs:
...
@@ -209,7 +208,7 @@ class SchedulerSwappedInOutputs:
# The blocks to swap in.
# The blocks to swap in.
blocks_to_swap_in
:
Dict
[
int
,
int
]
blocks_to_swap_in
:
Dict
[
int
,
int
]
# The blocks to copy.
# The blocks to copy.
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
# The number of slots for lookahead decoding.
# The number of slots for lookahead decoding.
num_lookahead_slots
:
int
num_lookahead_slots
:
int
# Infeasible sequence groups.
# Infeasible sequence groups.
...
@@ -221,7 +220,7 @@ class SchedulerSwappedInOutputs:
...
@@ -221,7 +220,7 @@ class SchedulerSwappedInOutputs:
decode_seq_groups
=
[],
decode_seq_groups
=
[],
prefill_seq_groups
=
[],
prefill_seq_groups
=
[],
blocks_to_swap_in
=
{},
blocks_to_swap_in
=
{},
blocks_to_copy
=
{}
,
blocks_to_copy
=
[]
,
num_lookahead_slots
=
0
,
num_lookahead_slots
=
0
,
infeasible_seq_groups
=
[],
infeasible_seq_groups
=
[],
)
)
...
@@ -394,7 +393,7 @@ class Scheduler:
...
@@ -394,7 +393,7 @@ class Scheduler:
"""
"""
# Blocks that need to be swapped or copied before model execution.
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out
:
Dict
[
int
,
int
]
=
{}
blocks_to_swap_out
:
Dict
[
int
,
int
]
=
{}
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
=
{}
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
...
@@ -511,7 +510,7 @@ class Scheduler:
...
@@ -511,7 +510,7 @@ class Scheduler:
"""
"""
# Blocks that need to be swapped or copied before model execution.
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in
:
Dict
[
int
,
int
]
=
{}
blocks_to_swap_in
:
Dict
[
int
,
int
]
=
{}
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
=
{}
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
now
=
time
.
time
()
now
=
time
.
time
()
...
@@ -794,8 +793,8 @@ class Scheduler:
...
@@ -794,8 +793,8 @@ class Scheduler:
num_batched_tokens
=
budget
.
num_batched_tokens
,
num_batched_tokens
=
budget
.
num_batched_tokens
,
blocks_to_swap_in
=
swapped_in
.
blocks_to_swap_in
,
blocks_to_swap_in
=
swapped_in
.
blocks_to_swap_in
,
blocks_to_swap_out
=
running_scheduled
.
blocks_to_swap_out
,
blocks_to_swap_out
=
running_scheduled
.
blocks_to_swap_out
,
blocks_to_copy
=
merge_dicts
(
running_scheduled
.
blocks_to_copy
,
blocks_to_copy
=
running_scheduled
.
blocks_to_copy
+
swapped_in
.
blocks_to_copy
)
,
swapped_in
.
blocks_to_copy
,
ignored_seq_groups
=
prefills
.
ignored_seq_groups
+
ignored_seq_groups
=
prefills
.
ignored_seq_groups
+
swapped_in
.
infeasible_seq_groups
,
swapped_in
.
infeasible_seq_groups
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
...
@@ -882,8 +881,8 @@ class Scheduler:
...
@@ -882,8 +881,8 @@ class Scheduler:
num_batched_tokens
=
budget
.
num_batched_tokens
,
num_batched_tokens
=
budget
.
num_batched_tokens
,
blocks_to_swap_in
=
swapped_in
.
blocks_to_swap_in
,
blocks_to_swap_in
=
swapped_in
.
blocks_to_swap_in
,
blocks_to_swap_out
=
running_scheduled
.
blocks_to_swap_out
,
blocks_to_swap_out
=
running_scheduled
.
blocks_to_swap_out
,
blocks_to_copy
=
merge_dicts
(
running_scheduled
.
blocks_to_copy
,
blocks_to_copy
=
running_scheduled
.
blocks_to_copy
+
swapped_in
.
blocks_to_copy
)
,
swapped_in
.
blocks_to_copy
,
ignored_seq_groups
=
prefills
.
ignored_seq_groups
,
ignored_seq_groups
=
prefills
.
ignored_seq_groups
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
running_queue_size
=
len
(
self
.
running
),
running_queue_size
=
len
(
self
.
running
),
...
@@ -1011,17 +1010,18 @@ class Scheduler:
...
@@ -1011,17 +1010,18 @@ class Scheduler:
def
_append_slots
(
def
_append_slots
(
self
,
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]],
)
->
None
:
)
->
None
:
"""Appends new slots to the sequences in the given sequence group.
"""Appends new slots to the sequences in the given sequence group.
Args:
Args:
seq_group (SequenceGroup): The sequence group containing the
seq_group (SequenceGroup): The sequence group containing the
sequences to append slots to.
sequences to append slots to.
blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source
blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two
block indices to lists of destination block indices. This
ints, the first int is the source block index, and the second
dictionary is updated with the new source and destination block
int is the destination block index. This list is updated with
indices for the appended slots.
the new source and destination block indices for the appended
slots.
"""
"""
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
False
)
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
False
)
...
@@ -1029,9 +1029,8 @@ class Scheduler:
...
@@ -1029,9 +1029,8 @@ class Scheduler:
cows
=
self
.
block_manager
.
append_slots
(
seq
,
num_lookahead_slots
)
cows
=
self
.
block_manager
.
append_slots
(
seq
,
num_lookahead_slots
)
for
src
,
dests
in
cows
.
items
():
for
src
,
dests
in
cows
.
items
():
if
src
not
in
blocks_to_copy
:
for
dest
in
dests
:
blocks_to_copy
[
src
]
=
[]
blocks_to_copy
.
append
((
src
,
dest
))
blocks_to_copy
[
src
].
extend
(
dests
)
def
_preempt
(
def
_preempt
(
self
,
self
,
...
...
vllm/distributed/communication_op.py
View file @
63575bc2
...
@@ -203,6 +203,9 @@ def broadcast_tensor_dict(
...
@@ -203,6 +203,9 @@ def broadcast_tensor_dict(
group
=
metadata_group
)
group
=
metadata_group
)
async_handles
=
[]
async_handles
=
[]
for
tensor
in
tensor_list
:
for
tensor
in
tensor_list
:
if
tensor
.
numel
()
==
0
:
# Skip broadcasting empty tensors.
continue
async_handles
.
append
(
async_handles
.
append
(
torch
.
distributed
.
broadcast
(
tensor
,
torch
.
distributed
.
broadcast
(
tensor
,
src
=
src
,
src
=
src
,
...
@@ -224,6 +227,10 @@ def broadcast_tensor_dict(
...
@@ -224,6 +227,10 @@ def broadcast_tensor_dict(
tensor
=
torch
.
empty
(
value
.
size
,
tensor
=
torch
.
empty
(
value
.
size
,
dtype
=
value
.
dtype
,
dtype
=
value
.
dtype
,
device
=
"cuda"
)
device
=
"cuda"
)
if
tensor
.
numel
()
==
0
:
# Skip broadcasting empty tensors.
tensor_dict
[
key
]
=
tensor
continue
async_handle
=
torch
.
distributed
.
broadcast
(
tensor
,
async_handle
=
torch
.
distributed
.
broadcast
(
tensor
,
src
=
src
,
src
=
src
,
async_op
=
True
,
async_op
=
True
,
...
...
vllm/sequence.py
View file @
63575bc2
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
import
copy
import
copy
import
enum
import
enum
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
vllm.block
import
LogicalTokenBlock
from
vllm.block
import
LogicalTokenBlock
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -745,8 +745,8 @@ class ExecuteModelRequest:
...
@@ -745,8 +745,8 @@ class ExecuteModelRequest:
blocks_to_swap_in
:
Dict
[
int
,
int
]
=
field
(
default_factory
=
dict
)
blocks_to_swap_in
:
Dict
[
int
,
int
]
=
field
(
default_factory
=
dict
)
# Blocks to swap out. Dict of GPU -> CPU block number.
# Blocks to swap out. Dict of GPU -> CPU block number.
blocks_to_swap_out
:
Dict
[
int
,
int
]
=
field
(
default_factory
=
dict
)
blocks_to_swap_out
:
Dict
[
int
,
int
]
=
field
(
default_factory
=
dict
)
# Blocks to copy. Source to
a list of
dest block
s
.
# Blocks to copy. Source to dest block.
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
=
field
(
default_factory
=
dic
t
)
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
field
(
default_factory
=
lis
t
)
# The number of slots for lookahead decoding.
# The number of slots for lookahead decoding.
num_lookahead_slots
:
int
=
0
num_lookahead_slots
:
int
=
0
# The number of requests in the running queue.
# The number of requests in the running queue.
...
...
vllm/worker/cache_engine.py
View file @
63575bc2
...
@@ -77,7 +77,7 @@ class CacheEngine:
...
@@ -77,7 +77,7 @@ class CacheEngine:
self
.
attn_backend
.
swap_blocks
(
self
.
gpu_cache
[
i
],
self
.
cpu_cache
[
i
],
self
.
attn_backend
.
swap_blocks
(
self
.
gpu_cache
[
i
],
self
.
cpu_cache
[
i
],
src_to_dst
)
src_to_dst
)
def
copy
(
self
,
src_to_dsts
:
Dict
[
int
,
List
[
int
]]
)
->
None
:
def
copy
(
self
,
src_to_dsts
:
torch
.
Tensor
)
->
None
:
self
.
attn_backend
.
copy_blocks
(
self
.
gpu_cache
,
src_to_dsts
)
self
.
attn_backend
.
copy_blocks
(
self
.
gpu_cache
,
src_to_dsts
)
@
staticmethod
@
staticmethod
...
...
vllm/worker/cpu_worker.py
View file @
63575bc2
...
@@ -248,9 +248,9 @@ class CPUWorker(LoraNotSupportedWorkerBase):
...
@@ -248,9 +248,9 @@ class CPUWorker(LoraNotSupportedWorkerBase):
def
cache_copy
(
def
cache_copy
(
self
,
self
,
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
,
blocks_to_copy
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
if
blocks_to_copy
:
if
blocks_to_copy
.
numel
()
>
0
:
self
.
cache_engine
.
copy
(
blocks_to_copy
)
self
.
cache_engine
.
copy
(
blocks_to_copy
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
@@ -269,6 +269,9 @@ class CPUWorker(LoraNotSupportedWorkerBase):
...
@@ -269,6 +269,9 @@ class CPUWorker(LoraNotSupportedWorkerBase):
num_seq_groups
:
int
=
len
(
seq_group_metadata_list
)
num_seq_groups
:
int
=
len
(
seq_group_metadata_list
)
assert
execute_model_req
is
not
None
assert
execute_model_req
is
not
None
blocks_to_copy
=
execute_model_req
.
blocks_to_copy
blocks_to_copy
=
execute_model_req
.
blocks_to_copy
blocks_to_copy
=
torch
.
tensor
(
execute_model_req
.
blocks_to_copy
,
device
=
"cpu"
,
dtype
=
torch
.
int64
).
view
(
-
1
,
2
)
assert
len
(
execute_model_req
.
blocks_to_swap_in
)
==
0
assert
len
(
execute_model_req
.
blocks_to_swap_in
)
==
0
assert
len
(
execute_model_req
.
blocks_to_swap_out
)
==
0
assert
len
(
execute_model_req
.
blocks_to_swap_out
)
==
0
data
:
Dict
[
str
,
Any
]
=
{
data
:
Dict
[
str
,
Any
]
=
{
...
...
vllm/worker/worker.py
View file @
63575bc2
...
@@ -197,7 +197,7 @@ class Worker(WorkerBase):
...
@@ -197,7 +197,7 @@ class Worker(WorkerBase):
self
,
self
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
,
blocks_to_copy
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
# Issue cache operations.
# Issue cache operations.
# TODO(woosuk): Profile swapping overhead and optimize if needed.
# TODO(woosuk): Profile swapping overhead and optimize if needed.
...
@@ -205,7 +205,7 @@ class Worker(WorkerBase):
...
@@ -205,7 +205,7 @@ class Worker(WorkerBase):
self
.
cache_engine
.
swap_in
(
blocks_to_swap_in
)
self
.
cache_engine
.
swap_in
(
blocks_to_swap_in
)
if
blocks_to_swap_out
:
if
blocks_to_swap_out
:
self
.
cache_engine
.
swap_out
(
blocks_to_swap_out
)
self
.
cache_engine
.
swap_out
(
blocks_to_swap_out
)
if
blocks_to_copy
:
if
blocks_to_copy
.
numel
()
>
0
:
self
.
cache_engine
.
copy
(
blocks_to_copy
)
self
.
cache_engine
.
copy
(
blocks_to_copy
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
@@ -225,7 +225,9 @@ class Worker(WorkerBase):
...
@@ -225,7 +225,9 @@ class Worker(WorkerBase):
num_seq_groups
=
len
(
seq_group_metadata_list
)
num_seq_groups
=
len
(
seq_group_metadata_list
)
blocks_to_swap_in
=
execute_model_req
.
blocks_to_swap_in
blocks_to_swap_in
=
execute_model_req
.
blocks_to_swap_in
blocks_to_swap_out
=
execute_model_req
.
blocks_to_swap_out
blocks_to_swap_out
=
execute_model_req
.
blocks_to_swap_out
blocks_to_copy
=
execute_model_req
.
blocks_to_copy
blocks_to_copy
=
torch
.
tensor
(
execute_model_req
.
blocks_to_copy
,
device
=
self
.
device
,
dtype
=
torch
.
int64
).
view
(
-
1
,
2
)
data
:
Dict
[
str
,
Any
]
=
{
data
:
Dict
[
str
,
Any
]
=
{
"num_seq_groups"
:
num_seq_groups
,
"num_seq_groups"
:
num_seq_groups
,
"blocks_to_swap_in"
:
blocks_to_swap_in
,
"blocks_to_swap_in"
:
blocks_to_swap_in
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment