"Deepspeed/vscode:/vscode.git/clone" did not exist on "316d3f9088d2479a79ff4251090254be628ebdc8"
Commit 771ed2c4 authored by maxiang
Browse files

[fix] connector: new_block_ids = None 导致 OOM (new_block_ids being None led to an OOM)

parent aef3c487
...@@ -48,6 +48,7 @@ class ReqMeta: ...@@ -48,6 +48,7 @@ class ReqMeta:
block_ids_tensor = torch.tensor(block_ids) block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0] num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size) block_offsets = torch.arange(0, block_size)
slot_mapping = block_offsets.reshape((1, block_size)) + \ slot_mapping = block_offsets.reshape((1, block_size)) + \
block_ids_tensor.reshape((num_blocks, 1)) * block_size block_ids_tensor.reshape((num_blocks, 1)) * block_size
slot_mapping = slot_mapping.flatten()[:valid_num_tokens] slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
...@@ -70,7 +71,7 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata): ...@@ -70,7 +71,7 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata):
self, self,
request_id: str, request_id: str,
token_ids: list[int], token_ids: list[int],
block_ids: list[int], block_ids: list[int], #这里为None ??
block_size: int, block_size: int,
) -> None: ) -> None:
self.requests.append( self.requests.append(
...@@ -619,12 +620,15 @@ class DuSwiftConnector(KVConnectorBase_V1): ...@@ -619,12 +620,15 @@ class DuSwiftConnector(KVConnectorBase_V1):
num_scheduled_tokens = ( num_scheduled_tokens = (
scheduler_output.num_scheduled_tokens)[req_id] scheduler_output.num_scheduled_tokens)[req_id]
num_tokens = (num_scheduled_tokens + num_computed_tokens) num_tokens = (num_scheduled_tokens + num_computed_tokens)
# assert req_id in self.chunked_prefill
if req_id not in self.chunked_prefill: if req_id not in self.chunked_prefill:
continue continue
block_ids = new_block_ids[0] delta_block_ids = (
[] if new_block_ids is None else new_block_ids[0])
if not resumed_from_preemption: if not resumed_from_preemption:
block_ids = (self.chunked_prefill[req_id][0] + block_ids) block_ids = (
self.chunked_prefill[req_id][0] + delta_block_ids)
else:
block_ids = delta_block_ids
prompt_token_ids = self.chunked_prefill[req_id][1] prompt_token_ids = self.chunked_prefill[req_id][1]
# the request's prompt is chunked prefill again # the request's prompt is chunked prefill again
if num_tokens < len(prompt_token_ids): if num_tokens < len(prompt_token_ids):
...@@ -644,13 +648,23 @@ class DuSwiftConnector(KVConnectorBase_V1): ...@@ -644,13 +648,23 @@ class DuSwiftConnector(KVConnectorBase_V1):
if not resumed_from_preemption: if not resumed_from_preemption:
break break
if req_id in self._requests_need_load: if req_id in self._requests_need_load:
request, _ = self._requests_need_load.pop(req_id) request, fallback_block_ids = (
self._requests_need_load.pop(req_id))
total_tokens = num_computed_tokens + 1 total_tokens = num_computed_tokens + 1
token_ids = request.all_token_ids[:total_tokens] token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all if new_block_ids is not None:
# of the block_ids for the request.
block_ids = new_block_ids[0] block_ids = new_block_ids[0]
elif fallback_block_ids:
block_ids = fallback_block_ids
logger.warning(
"Using fallback block_ids for resumed request "
"%s: new_block_ids is None.", req_id)
else:
logger.warning(
"Skip KV load meta for resumed request %s: "
"no block_ids available.", req_id)
continue
meta.add_request(request_id=req_id, meta.add_request(request_id=req_id,
token_ids=token_ids, token_ids=token_ids,
......
...@@ -477,10 +477,11 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -477,10 +477,11 @@ class P2pNcclConnector(KVConnectorBase_V1):
""" """
Update KVConnector state after block allocation. Update KVConnector state after block allocation.
""" """
#将全量blocks存入字典
if not self.is_producer and num_external_tokens > 0: if not self.is_producer and num_external_tokens > 0:
self._requests_need_load[request.request_id] = ( self._requests_need_load[request.request_id] = (
request, request,
blocks.get_block_ids()[0], blocks.get_block_ids()[0], #转换为block ID 列表 req的全量blocks
) )
def build_connector_meta( def build_connector_meta(
...@@ -520,6 +521,7 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -520,6 +521,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
block_size=self._block_size, block_size=self._block_size,
) )
continue continue
#新请求
if new_req.req_id in self._requests_need_load: if new_req.req_id in self._requests_need_load:
meta.add_request( meta.add_request(
request_id=new_req.req_id, request_id=new_req.req_id,
...@@ -538,16 +540,19 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -538,16 +540,19 @@ class P2pNcclConnector(KVConnectorBase_V1):
if self.is_producer: if self.is_producer:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_tokens = num_scheduled_tokens + num_computed_tokens num_tokens = num_scheduled_tokens + num_computed_tokens
# assert req_id in self.chunked_prefill
if req_id not in self.chunked_prefill: if req_id not in self.chunked_prefill:
continue continue
assert new_block_ids is not None delta_block_ids = (
block_ids = new_block_ids[0] [] if new_block_ids is None else new_block_ids[0])
if not resumed_from_preemption: if not resumed_from_preemption:
block_ids = self.chunked_prefill[req_id][0] + block_ids block_ids = (
self.chunked_prefill[req_id][0] + delta_block_ids)
else:
block_ids = delta_block_ids
prompt_token_ids = self.chunked_prefill[req_id][1] prompt_token_ids = self.chunked_prefill[req_id][1]
assert prompt_token_ids is not None assert prompt_token_ids is not None
# the request's prompt is chunked prefill again # the request's prompt is chunked prefill again
# ???? 一直累积
if num_tokens < len(prompt_token_ids): if num_tokens < len(prompt_token_ids):
self.chunked_prefill[req_id] = (block_ids, prompt_token_ids) self.chunked_prefill[req_id] = (block_ids, prompt_token_ids)
continue continue
...@@ -563,17 +568,27 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -563,17 +568,27 @@ class P2pNcclConnector(KVConnectorBase_V1):
# NOTE(rob): here we rely on the resumed requests being # NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cache_reqs. # the first N requests in the list scheduled_cache_reqs.
if not resumed_from_preemption: if not resumed_from_preemption:
break break
if req_id in self._requests_need_load: if req_id in self._requests_need_load:
request, _ = self._requests_need_load.pop(req_id) request, fallback_block_ids = (
self._requests_need_load.pop(req_id))
total_tokens = num_computed_tokens + 1 total_tokens = num_computed_tokens + 1
token_ids = request.all_token_ids[:total_tokens] token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all if new_block_ids is not None:
# of the block_ids for the request.
assert new_block_ids is not None
block_ids = new_block_ids[0] block_ids = new_block_ids[0]
elif fallback_block_ids:
block_ids = fallback_block_ids
logger.warning(
"Using fallback block_ids for resumed request "
"%s: new_block_ids is None.", req_id)
else:
logger.warning(
"Skip KV load meta for resumed request %s: "
"no block_ids available.", req_id)
continue
meta.add_request( meta.add_request(
request_id=req_id, request_id=req_id,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment