Commit 771ed2c4 authored by maxiang
Browse files

[fix] connector: new_block_ids = None causes OOM

parent aef3c487
......@@ -48,6 +48,7 @@ class ReqMeta:
block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size)
slot_mapping = block_offsets.reshape((1, block_size)) + \
block_ids_tensor.reshape((num_blocks, 1)) * block_size
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
......@@ -70,7 +71,7 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata):
self,
request_id: str,
token_ids: list[int],
block_ids: list[int],
block_ids: list[int], #这里为None ??
block_size: int,
) -> None:
self.requests.append(
......@@ -619,12 +620,15 @@ class DuSwiftConnector(KVConnectorBase_V1):
num_scheduled_tokens = (
scheduler_output.num_scheduled_tokens)[req_id]
num_tokens = (num_scheduled_tokens + num_computed_tokens)
# assert req_id in self.chunked_prefill
if req_id not in self.chunked_prefill:
if req_id not in self.chunked_prefill:
continue
block_ids = new_block_ids[0]
delta_block_ids = (
[] if new_block_ids is None else new_block_ids[0])
if not resumed_from_preemption:
block_ids = (self.chunked_prefill[req_id][0] + block_ids)
block_ids = (
self.chunked_prefill[req_id][0] + delta_block_ids)
else:
block_ids = delta_block_ids
prompt_token_ids = self.chunked_prefill[req_id][1]
# the request's prompt is chunked prefill again
if num_tokens < len(prompt_token_ids):
......@@ -644,13 +648,23 @@ class DuSwiftConnector(KVConnectorBase_V1):
if not resumed_from_preemption:
break
if req_id in self._requests_need_load:
request, _ = self._requests_need_load.pop(req_id)
request, fallback_block_ids = (
self._requests_need_load.pop(req_id))
total_tokens = num_computed_tokens + 1
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
block_ids = new_block_ids[0]
if new_block_ids is not None:
block_ids = new_block_ids[0]
elif fallback_block_ids:
block_ids = fallback_block_ids
logger.warning(
"Using fallback block_ids for resumed request "
"%s: new_block_ids is None.", req_id)
else:
logger.warning(
"Skip KV load meta for resumed request %s: "
"no block_ids available.", req_id)
continue
meta.add_request(request_id=req_id,
token_ids=token_ids,
......
......@@ -477,10 +477,11 @@ class P2pNcclConnector(KVConnectorBase_V1):
"""
Update KVConnector state after block allocation.
"""
# Store the request's full block list in the dictionary.
if not self.is_producer and num_external_tokens > 0:
self._requests_need_load[request.request_id] = (
request,
blocks.get_block_ids()[0],
blocks.get_block_ids()[0], #转换为block ID 列表 req的全量blocks
)
def build_connector_meta(
......@@ -520,6 +521,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
block_size=self._block_size,
)
continue
# New request.
if new_req.req_id in self._requests_need_load:
meta.add_request(
request_id=new_req.req_id,
......@@ -538,16 +540,19 @@ class P2pNcclConnector(KVConnectorBase_V1):
if self.is_producer:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_tokens = num_scheduled_tokens + num_computed_tokens
# assert req_id in self.chunked_prefill
if req_id not in self.chunked_prefill:
if req_id not in self.chunked_prefill:
continue
assert new_block_ids is not None
block_ids = new_block_ids[0]
delta_block_ids = (
[] if new_block_ids is None else new_block_ids[0])
if not resumed_from_preemption:
block_ids = self.chunked_prefill[req_id][0] + block_ids
block_ids = (
self.chunked_prefill[req_id][0] + delta_block_ids)
else:
block_ids = delta_block_ids
prompt_token_ids = self.chunked_prefill[req_id][1]
assert prompt_token_ids is not None
# the request's prompt is chunked prefill again
# TODO(review): block_ids appear to keep accumulating across chunks here — confirm this is intended.
if num_tokens < len(prompt_token_ids):
self.chunked_prefill[req_id] = (block_ids, prompt_token_ids)
continue
......@@ -563,17 +568,27 @@ class P2pNcclConnector(KVConnectorBase_V1):
# NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cache_reqs.
if not resumed_from_preemption:
break
if req_id in self._requests_need_load:
request, _ = self._requests_need_load.pop(req_id)
request, fallback_block_ids = (
self._requests_need_load.pop(req_id))
total_tokens = num_computed_tokens + 1
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
assert new_block_ids is not None
block_ids = new_block_ids[0]
if new_block_ids is not None:
block_ids = new_block_ids[0]
elif fallback_block_ids:
block_ids = fallback_block_ids
logger.warning(
"Using fallback block_ids for resumed request "
"%s: new_block_ids is None.", req_id)
else:
logger.warning(
"Skip KV load meta for resumed request %s: "
"no block_ids available.", req_id)
continue
meta.add_request(
request_id=req_id,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment