"tests/vscode:/vscode.git/clone" did not exist on "46cdd59577978f893dbf9c733cacd920011fc7fd"
Unverified Commit 63575bc2 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[Core][Optimization] change python dict to pytorch tensor (#4607)

parent a98187cf
...@@ -13,7 +13,7 @@ void swap_blocks( ...@@ -13,7 +13,7 @@ void swap_blocks(
void copy_blocks( void copy_blocks(
std::vector<torch::Tensor>& key_caches, std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches, std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>>& block_mapping); torch::Tensor& block_mapping);
void reshape_and_cache( void reshape_and_cache(
torch::Tensor& key, torch::Tensor& key,
......
...@@ -97,7 +97,7 @@ __global__ void copy_blocks_kernel( ...@@ -97,7 +97,7 @@ __global__ void copy_blocks_kernel(
void copy_blocks( void copy_blocks(
std::vector<torch::Tensor>& key_caches, std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches, std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>>& block_mapping) { torch::Tensor& block_mapping) {
int num_layers = key_caches.size(); int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size()); TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) { if (num_layers == 0) {
...@@ -114,17 +114,9 @@ void copy_blocks( ...@@ -114,17 +114,9 @@ void copy_blocks(
key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr()); key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr()); value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
} }
// Create block mapping array.
std::vector<int64_t> block_mapping_vec; // block_mapping is a 2D tensor with shape (num_pairs, 2).
for (const auto& pair : block_mapping) { int num_pairs = block_mapping.size(0);
int64_t src_block_number = pair.first;
for (int64_t dst_block_number : pair.second) {
block_mapping_vec.push_back(src_block_number);
block_mapping_vec.push_back(dst_block_number);
}
}
int64_t* block_mapping_array = block_mapping_vec.data();
int num_pairs = block_mapping_vec.size() / 2;
// Move the data structures to the GPU. // Move the data structures to the GPU.
// NOTE: This synchronizes the CPU and GPU. // NOTE: This synchronizes the CPU and GPU.
...@@ -132,8 +124,6 @@ void copy_blocks( ...@@ -132,8 +124,6 @@ void copy_blocks(
key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
torch::Tensor value_cache_ptrs_tensor = torch::from_blob( torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
torch::Tensor block_mapping_tensor = torch::from_blob(
block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
// Launch the kernel. // Launch the kernel.
const int numel_per_block = key_caches[0][0].numel(); const int numel_per_block = key_caches[0][0].numel();
...@@ -146,7 +136,7 @@ void copy_blocks( ...@@ -146,7 +136,7 @@ void copy_blocks(
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>( vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(), key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(), value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping_tensor.data_ptr<int64_t>(), block_mapping.data_ptr<int64_t>(),
numel_per_block); numel_per_block);
})); }));
} }
......
...@@ -8,16 +8,16 @@ template <typename scalar_t> ...@@ -8,16 +8,16 @@ template <typename scalar_t>
void copy_blocks_cpu_impl( void copy_blocks_cpu_impl(
std::vector<torch::Tensor> &key_caches, std::vector<torch::Tensor> &key_caches,
std::vector<torch::Tensor> &value_caches, std::vector<torch::Tensor> &value_caches,
const std::vector<std::pair<int64_t, int64_t>> mapping_pairs, const torch::Tensor& mapping_pairs,
const int element_num_per_block, const int layer_num) { const int element_num_per_block, const int layer_num) {
const size_t pair_num = mapping_pairs.size(); const size_t pair_num = mapping_pairs.size(0);
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int layer = 0; layer < layer_num; ++layer) { for (int layer = 0; layer < layer_num; ++layer) {
for (size_t pair = 0; pair < pair_num; ++pair) { for (size_t pair = 0; pair < pair_num; ++pair) {
int64_t source_offset = element_num_per_block * mapping_pairs[pair].first; int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
int64_t target_offset = int64_t target_offset =
element_num_per_block * mapping_pairs[pair].second; element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>(); scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
scalar_t *source_ptr = key_cache_ptr + source_offset; scalar_t *source_ptr = key_cache_ptr + source_offset;
scalar_t *target_ptr = key_cache_ptr + target_offset; scalar_t *target_ptr = key_cache_ptr + target_offset;
...@@ -83,26 +83,18 @@ void reshape_and_cache_cpu_impl( ...@@ -83,26 +83,18 @@ void reshape_and_cache_cpu_impl(
void copy_blocks(std::vector<torch::Tensor> &key_caches, void copy_blocks(std::vector<torch::Tensor> &key_caches,
std::vector<torch::Tensor> &value_caches, std::vector<torch::Tensor> &value_caches,
const std::map<int64_t, std::vector<int64_t>> &block_mapping) { torch::Tensor& block_mapping) {
int num_layers = key_caches.size(); int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size()); TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) { if (num_layers == 0) {
return; return;
} }
std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
mapping_pairs.reserve(block_mapping.size());
for (const auto &pair : block_mapping) {
for (const auto &dst : pair.second) {
mapping_pairs.emplace_back(pair.first, dst);
}
}
const int element_num_per_block = key_caches[0][0].numel(); const int element_num_per_block = key_caches[0][0].numel();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs, copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
element_num_per_block, num_layers); element_num_per_block, num_layers);
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
}); });
......
...@@ -568,7 +568,7 @@ def test_decode_schedule_preempted(): ...@@ -568,7 +568,7 @@ def test_decode_schedule_preempted():
# Both should be preempted, not swapped. # Both should be preempted, not swapped.
assert output.blocks_to_swap_out == {} assert output.blocks_to_swap_out == {}
# Nothing is copied. # Nothing is copied.
assert output.blocks_to_copy == {} assert output.blocks_to_copy == []
def test_decode_swap_beam_search(): def test_decode_swap_beam_search():
...@@ -618,7 +618,7 @@ def test_decode_swap_beam_search(): ...@@ -618,7 +618,7 @@ def test_decode_swap_beam_search():
# Both should be preempted, not swapped. # Both should be preempted, not swapped.
assert output.blocks_to_swap_out == expected_swap_mapping assert output.blocks_to_swap_out == expected_swap_mapping
# Nothing is copied. # Nothing is copied.
assert output.blocks_to_copy == {} assert output.blocks_to_copy == []
def test_schedule_decode_blocks_to_copy_update(): def test_schedule_decode_blocks_to_copy_update():
...@@ -650,7 +650,7 @@ def test_schedule_decode_blocks_to_copy_update(): ...@@ -650,7 +650,7 @@ def test_schedule_decode_blocks_to_copy_update():
assert output.blocks_to_swap_out == {} assert output.blocks_to_swap_out == {}
# Since append_slot returns the source -> dist mapping, it should # Since append_slot returns the source -> dist mapping, it should
# applied. # applied.
assert output.blocks_to_copy == {2: [3]} assert output.blocks_to_copy == [(2, 3)]
def test_schedule_swapped_simple(): def test_schedule_swapped_simple():
...@@ -853,7 +853,7 @@ def test_schedule_swapped_blocks_to_copy(): ...@@ -853,7 +853,7 @@ def test_schedule_swapped_blocks_to_copy():
assert len(remaining_swapped) == 0 assert len(remaining_swapped) == 0
assert len(output.decode_seq_groups) == 1 assert len(output.decode_seq_groups) == 1
assert len(output.prefill_seq_groups) == 0 assert len(output.prefill_seq_groups) == 0
assert output.blocks_to_copy == {2: [3]} assert output.blocks_to_copy == [(2, 3)]
def test_scheduling_budget(): def test_scheduling_budget():
......
...@@ -63,12 +63,13 @@ def test_copy_blocks( ...@@ -63,12 +63,13 @@ def test_copy_blocks(
src_blocks = random.sample(range(num_blocks), num_mappings) src_blocks = random.sample(range(num_blocks), num_mappings)
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
block_mapping = {} block_mapping = []
for i in range(num_mappings): for i in range(num_mappings):
src = src_blocks[i] src = src_blocks[i]
dst1 = dst_blocks[2 * i] dst1 = dst_blocks[2 * i]
dst2 = dst_blocks[2 * i + 1] dst2 = dst_blocks[2 * i + 1]
block_mapping[src] = [dst1, dst2] block_mapping.append((src, dst1))
block_mapping.append((src, dst2))
# Create the KV caches. # Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
...@@ -81,11 +82,13 @@ def test_copy_blocks( ...@@ -81,11 +82,13 @@ def test_copy_blocks(
cloned_value_caches = [value_cache.clone() for value_cache in value_caches] cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
# Call the copy blocks kernel. # Call the copy blocks kernel.
ops.copy_blocks(key_caches, value_caches, block_mapping) block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device=device).view(-1, 2)
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
# Run the reference implementation. # Run the reference implementation.
for src, dsts in block_mapping.items(): for src, dst in block_mapping:
for dst in dsts:
for cloned_key_cache in cloned_key_caches: for cloned_key_cache in cloned_key_caches:
cloned_key_cache[dst].copy_(cloned_key_cache[src]) cloned_key_cache[dst].copy_(cloned_key_cache[src])
for cloned_value_cache in cloned_value_caches: for cloned_value_cache in cloned_value_caches:
......
...@@ -59,7 +59,7 @@ def test_swap() -> None: ...@@ -59,7 +59,7 @@ def test_swap() -> None:
seq_group_metadata_list=[], seq_group_metadata_list=[],
blocks_to_swap_in={}, blocks_to_swap_in={},
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy={}, blocks_to_copy=[],
) )
worker.execute_model(execute_model_req=execute_model_req) worker.execute_model(execute_model_req=execute_model_req)
......
...@@ -42,7 +42,7 @@ class AttentionBackend(ABC): ...@@ -42,7 +42,7 @@ class AttentionBackend(ABC):
@abstractmethod @abstractmethod
def copy_blocks( def copy_blocks(
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]], src_to_dists: torch.Tensor,
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
......
...@@ -48,7 +48,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -48,7 +48,7 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]], src_to_dists: torch.Tensor,
) -> None: ) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists) PagedAttention.copy_blocks(kv_caches, src_to_dists)
......
...@@ -48,7 +48,7 @@ class FlashInferBackend(AttentionBackend): ...@@ -48,7 +48,7 @@ class FlashInferBackend(AttentionBackend):
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]], src_to_dists: torch.Tensor,
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
......
...@@ -46,7 +46,7 @@ class ROCmFlashAttentionBackend(AttentionBackend): ...@@ -46,7 +46,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]], src_to_dists: torch.Tensor,
) -> None: ) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists) PagedAttention.copy_blocks(kv_caches, src_to_dists)
......
...@@ -44,7 +44,7 @@ class TorchSDPABackend(AttentionBackend): ...@@ -44,7 +44,7 @@ class TorchSDPABackend(AttentionBackend):
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]], src_to_dists: torch.Tensor,
) -> None: ) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists) PagedAttention.copy_blocks(kv_caches, src_to_dists)
......
...@@ -49,7 +49,7 @@ class XFormersBackend(AttentionBackend): ...@@ -49,7 +49,7 @@ class XFormersBackend(AttentionBackend):
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]], src_to_dists: torch.Tensor,
) -> None: ) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists) PagedAttention.copy_blocks(kv_caches, src_to_dists)
......
...@@ -209,7 +209,7 @@ class PagedAttention: ...@@ -209,7 +209,7 @@ class PagedAttention:
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]], src_to_dists: torch.Tensor,
) -> None: ) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches] key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches]
......
...@@ -13,7 +13,6 @@ from vllm.logger import init_logger ...@@ -13,7 +13,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup, from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus) SequenceGroupMetadata, SequenceStatus)
from vllm.utils import merge_dicts
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -122,8 +121,8 @@ class SchedulerOutputs: ...@@ -122,8 +121,8 @@ class SchedulerOutputs:
blocks_to_swap_in: Dict[int, int] blocks_to_swap_in: Dict[int, int]
# Blocks to swap out. Dict of GPU -> CPU block number. # Blocks to swap out. Dict of GPU -> CPU block number.
blocks_to_swap_out: Dict[int, int] blocks_to_swap_out: Dict[int, int]
# Blocks to copy. Source to a list of dest blocks. # Blocks to copy. Source to dest block.
blocks_to_copy: Dict[int, List[int]] blocks_to_copy: List[Tuple[int, int]]
# Sequence groups that are going to be ignored. # Sequence groups that are going to be ignored.
ignored_seq_groups: List[SequenceGroup] ignored_seq_groups: List[SequenceGroup]
# The number of slots for lookahead decoding. # The number of slots for lookahead decoding.
...@@ -177,7 +176,7 @@ class SchedulerRunningOutputs: ...@@ -177,7 +176,7 @@ class SchedulerRunningOutputs:
# The blocks to swap out. # The blocks to swap out.
blocks_to_swap_out: Dict[int, int] blocks_to_swap_out: Dict[int, int]
# The blocks to copy. # The blocks to copy.
blocks_to_copy: Dict[int, List[int]] blocks_to_copy: List[Tuple[int, int]]
# The number of slots for lookahead decoding. # The number of slots for lookahead decoding.
num_lookahead_slots: int num_lookahead_slots: int
...@@ -189,7 +188,7 @@ class SchedulerRunningOutputs: ...@@ -189,7 +188,7 @@ class SchedulerRunningOutputs:
preempted=[], preempted=[],
swapped_out=[], swapped_out=[],
blocks_to_swap_out={}, blocks_to_swap_out={},
blocks_to_copy={}, blocks_to_copy=[],
num_lookahead_slots=0, num_lookahead_slots=0,
) )
...@@ -209,7 +208,7 @@ class SchedulerSwappedInOutputs: ...@@ -209,7 +208,7 @@ class SchedulerSwappedInOutputs:
# The blocks to swap in. # The blocks to swap in.
blocks_to_swap_in: Dict[int, int] blocks_to_swap_in: Dict[int, int]
# The blocks to copy. # The blocks to copy.
blocks_to_copy: Dict[int, List[int]] blocks_to_copy: List[Tuple[int, int]]
# The number of slots for lookahead decoding. # The number of slots for lookahead decoding.
num_lookahead_slots: int num_lookahead_slots: int
# Infeasible sequence groups. # Infeasible sequence groups.
...@@ -221,7 +220,7 @@ class SchedulerSwappedInOutputs: ...@@ -221,7 +220,7 @@ class SchedulerSwappedInOutputs:
decode_seq_groups=[], decode_seq_groups=[],
prefill_seq_groups=[], prefill_seq_groups=[],
blocks_to_swap_in={}, blocks_to_swap_in={},
blocks_to_copy={}, blocks_to_copy=[],
num_lookahead_slots=0, num_lookahead_slots=0,
infeasible_seq_groups=[], infeasible_seq_groups=[],
) )
...@@ -394,7 +393,7 @@ class Scheduler: ...@@ -394,7 +393,7 @@ class Scheduler:
""" """
# Blocks that need to be swapped or copied before model execution. # Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {} blocks_to_copy: List[Tuple[int, int]] = []
decode_seq_groups: List[ScheduledSequenceGroup] = [] decode_seq_groups: List[ScheduledSequenceGroup] = []
prefill_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = []
...@@ -511,7 +510,7 @@ class Scheduler: ...@@ -511,7 +510,7 @@ class Scheduler:
""" """
# Blocks that need to be swapped or copied before model execution. # Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_in: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {} blocks_to_copy: List[Tuple[int, int]] = []
decode_seq_groups: List[ScheduledSequenceGroup] = [] decode_seq_groups: List[ScheduledSequenceGroup] = []
prefill_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = []
now = time.time() now = time.time()
...@@ -794,8 +793,8 @@ class Scheduler: ...@@ -794,8 +793,8 @@ class Scheduler:
num_batched_tokens=budget.num_batched_tokens, num_batched_tokens=budget.num_batched_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out, blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, blocks_to_copy=running_scheduled.blocks_to_copy +
swapped_in.blocks_to_copy), swapped_in.blocks_to_copy,
ignored_seq_groups=prefills.ignored_seq_groups + ignored_seq_groups=prefills.ignored_seq_groups +
swapped_in.infeasible_seq_groups, swapped_in.infeasible_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots, num_lookahead_slots=running_scheduled.num_lookahead_slots,
...@@ -882,8 +881,8 @@ class Scheduler: ...@@ -882,8 +881,8 @@ class Scheduler:
num_batched_tokens=budget.num_batched_tokens, num_batched_tokens=budget.num_batched_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out, blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, blocks_to_copy=running_scheduled.blocks_to_copy +
swapped_in.blocks_to_copy), swapped_in.blocks_to_copy,
ignored_seq_groups=prefills.ignored_seq_groups, ignored_seq_groups=prefills.ignored_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots, num_lookahead_slots=running_scheduled.num_lookahead_slots,
running_queue_size=len(self.running), running_queue_size=len(self.running),
...@@ -1011,17 +1010,18 @@ class Scheduler: ...@@ -1011,17 +1010,18 @@ class Scheduler:
def _append_slots( def _append_slots(
self, self,
seq_group: SequenceGroup, seq_group: SequenceGroup,
blocks_to_copy: Dict[int, List[int]], blocks_to_copy: List[Tuple[int, int]],
) -> None: ) -> None:
"""Appends new slots to the sequences in the given sequence group. """Appends new slots to the sequences in the given sequence group.
Args: Args:
seq_group (SequenceGroup): The sequence group containing the seq_group (SequenceGroup): The sequence group containing the
sequences to append slots to. sequences to append slots to.
blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two
block indices to lists of destination block indices. This ints, the first int is the source block index, and the second
dictionary is updated with the new source and destination block int is the destination block index. This list is updated with
indices for the appended slots. the new source and destination block indices for the appended
slots.
""" """
num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
...@@ -1029,9 +1029,8 @@ class Scheduler: ...@@ -1029,9 +1029,8 @@ class Scheduler:
cows = self.block_manager.append_slots(seq, num_lookahead_slots) cows = self.block_manager.append_slots(seq, num_lookahead_slots)
for src, dests in cows.items(): for src, dests in cows.items():
if src not in blocks_to_copy: for dest in dests:
blocks_to_copy[src] = [] blocks_to_copy.append((src, dest))
blocks_to_copy[src].extend(dests)
def _preempt( def _preempt(
self, self,
......
...@@ -203,6 +203,9 @@ def broadcast_tensor_dict( ...@@ -203,6 +203,9 @@ def broadcast_tensor_dict(
group=metadata_group) group=metadata_group)
async_handles = [] async_handles = []
for tensor in tensor_list: for tensor in tensor_list:
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
continue
async_handles.append( async_handles.append(
torch.distributed.broadcast(tensor, torch.distributed.broadcast(tensor,
src=src, src=src,
...@@ -224,6 +227,10 @@ def broadcast_tensor_dict( ...@@ -224,6 +227,10 @@ def broadcast_tensor_dict(
tensor = torch.empty(value.size, tensor = torch.empty(value.size,
dtype=value.dtype, dtype=value.dtype,
device="cuda") device="cuda")
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
async_handle = torch.distributed.broadcast(tensor, async_handle = torch.distributed.broadcast(tensor,
src=src, src=src,
async_op=True, async_op=True,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import copy import copy
import enum import enum
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from vllm.block import LogicalTokenBlock from vllm.block import LogicalTokenBlock
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -745,8 +745,8 @@ class ExecuteModelRequest: ...@@ -745,8 +745,8 @@ class ExecuteModelRequest:
blocks_to_swap_in: Dict[int, int] = field(default_factory=dict) blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
# Blocks to swap out. Dict of GPU -> CPU block number. # Blocks to swap out. Dict of GPU -> CPU block number.
blocks_to_swap_out: Dict[int, int] = field(default_factory=dict) blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
# Blocks to copy. Source to a list of dest blocks. # Blocks to copy. Source to dest block.
blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict) blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
# The number of slots for lookahead decoding. # The number of slots for lookahead decoding.
num_lookahead_slots: int = 0 num_lookahead_slots: int = 0
# The number of requests in the running queue. # The number of requests in the running queue.
......
...@@ -77,7 +77,7 @@ class CacheEngine: ...@@ -77,7 +77,7 @@ class CacheEngine:
self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
src_to_dst) src_to_dst)
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: def copy(self, src_to_dsts: torch.Tensor) -> None:
self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
@staticmethod @staticmethod
......
...@@ -248,9 +248,9 @@ class CPUWorker(LoraNotSupportedWorkerBase): ...@@ -248,9 +248,9 @@ class CPUWorker(LoraNotSupportedWorkerBase):
def cache_copy( def cache_copy(
self, self,
blocks_to_copy: Dict[int, List[int]], blocks_to_copy: torch.Tensor,
) -> None: ) -> None:
if blocks_to_copy: if blocks_to_copy.numel() > 0:
self.cache_engine.copy(blocks_to_copy) self.cache_engine.copy(blocks_to_copy)
@torch.inference_mode() @torch.inference_mode()
...@@ -269,6 +269,9 @@ class CPUWorker(LoraNotSupportedWorkerBase): ...@@ -269,6 +269,9 @@ class CPUWorker(LoraNotSupportedWorkerBase):
num_seq_groups: int = len(seq_group_metadata_list) num_seq_groups: int = len(seq_group_metadata_list)
assert execute_model_req is not None assert execute_model_req is not None
blocks_to_copy = execute_model_req.blocks_to_copy blocks_to_copy = execute_model_req.blocks_to_copy
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device="cpu",
dtype=torch.int64).view(-1, 2)
assert len(execute_model_req.blocks_to_swap_in) == 0 assert len(execute_model_req.blocks_to_swap_in) == 0
assert len(execute_model_req.blocks_to_swap_out) == 0 assert len(execute_model_req.blocks_to_swap_out) == 0
data: Dict[str, Any] = { data: Dict[str, Any] = {
......
...@@ -197,7 +197,7 @@ class Worker(WorkerBase): ...@@ -197,7 +197,7 @@ class Worker(WorkerBase):
self, self,
blocks_to_swap_in: Dict[int, int], blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int], blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]], blocks_to_copy: torch.Tensor,
) -> None: ) -> None:
# Issue cache operations. # Issue cache operations.
# TODO(woosuk): Profile swapping overhead and optimize if needed. # TODO(woosuk): Profile swapping overhead and optimize if needed.
...@@ -205,7 +205,7 @@ class Worker(WorkerBase): ...@@ -205,7 +205,7 @@ class Worker(WorkerBase):
self.cache_engine.swap_in(blocks_to_swap_in) self.cache_engine.swap_in(blocks_to_swap_in)
if blocks_to_swap_out: if blocks_to_swap_out:
self.cache_engine.swap_out(blocks_to_swap_out) self.cache_engine.swap_out(blocks_to_swap_out)
if blocks_to_copy: if blocks_to_copy.numel() > 0:
self.cache_engine.copy(blocks_to_copy) self.cache_engine.copy(blocks_to_copy)
@torch.inference_mode() @torch.inference_mode()
...@@ -225,7 +225,9 @@ class Worker(WorkerBase): ...@@ -225,7 +225,9 @@ class Worker(WorkerBase):
num_seq_groups = len(seq_group_metadata_list) num_seq_groups = len(seq_group_metadata_list)
blocks_to_swap_in = execute_model_req.blocks_to_swap_in blocks_to_swap_in = execute_model_req.blocks_to_swap_in
blocks_to_swap_out = execute_model_req.blocks_to_swap_out blocks_to_swap_out = execute_model_req.blocks_to_swap_out
blocks_to_copy = execute_model_req.blocks_to_copy blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device,
dtype=torch.int64).view(-1, 2)
data: Dict[str, Any] = { data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups, "num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in, "blocks_to_swap_in": blocks_to_swap_in,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment