Unverified Commit bc12d403 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Add grouped free operations (#1706)

parent 392f2863
...@@ -834,6 +834,8 @@ class Scheduler: ...@@ -834,6 +834,8 @@ class Scheduler:
next_token_ids = self.resolve_next_token_ids(bid, next_token_ids) next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
self.token_to_kv_pool.free_group_begin()
# Check finish condition # Check finish condition
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
if self.server_args.enable_overlap_schedule and req.finished(): if self.server_args.enable_overlap_schedule and req.finished():
...@@ -860,6 +862,8 @@ class Scheduler: ...@@ -860,6 +862,8 @@ class Scheduler:
self.stream_output(batch.reqs) self.stream_output(batch.reqs)
self.token_to_kv_pool.free_group_end()
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30) self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0: if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
self.print_decode_stats() self.print_decode_stats()
......
...@@ -18,7 +18,6 @@ limitations under the License. ...@@ -18,7 +18,6 @@ limitations under the License.
import logging import logging
from typing import List, Tuple, Union from typing import List, Tuple, Union
import numpy as np
import torch import torch
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -77,6 +76,8 @@ class BaseTokenToKVPool: ...@@ -77,6 +76,8 @@ class BaseTokenToKVPool:
self.store_dtype = dtype self.store_dtype = dtype
self.free_slots = None self.free_slots = None
self.is_not_in_free_group = True
self.free_group = []
self.clear() self.clear()
def available_size(self): def available_size(self):
...@@ -89,14 +90,28 @@ class BaseTokenToKVPool: ...@@ -89,14 +90,28 @@ class BaseTokenToKVPool:
select_index = self.free_slots[:need_size] select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:] self.free_slots = self.free_slots[need_size:]
return torch.tensor(select_index, dtype=torch.int32, device=self.device) return select_index.to(self.device)
def free(self, free_index: torch.Tensor): def free(self, free_index: torch.Tensor):
self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy())) if self.is_not_in_free_group:
self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
else:
self.free_group.append(free_index)
def free_group_begin(self):
self.is_not_in_free_group = False
self.free_group = []
def free_group_end(self):
self.is_not_in_free_group = True
if self.free_group:
self.free(torch.concat(self.free_group))
def clear(self): def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens. # The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_slots = np.arange(1, self.size + 1) self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
self.is_in_free_group = False
self.free_group = []
def get_key_buffer(self, layer_id: int) -> torch.Tensor: def get_key_buffer(self, layer_id: int) -> torch.Tensor:
raise NotImplementedError() raise NotImplementedError()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment