Unverified Commit c4bb5264 authored by OlivierDehaene, committed by GitHub

fix(server): decrease memory fragmentation (#557)

parent 6f429427
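The diff below follows one pattern throughout: wherever a batch is dropped (deleted from the cache, merged during concatenation, or abandoned after a failed forward pass), the server now deletes the Python references and calls torch.cuda.empty_cache() so PyTorch's caching allocator returns its unused blocks instead of holding on to a fragmented pool. A small, self-contained probe (not part of the commit; the tensor and its size are illustrative) showing what empty_cache() affects:

import torch

# memory_allocated() counts memory held by live tensors; memory_reserved()
# counts blocks the caching allocator keeps around for reuse. empty_cache()
# releases the unused reserved blocks back to the driver, which is what this
# commit leans on to keep the pool from fragmenting as batches come and go.
if torch.cuda.is_available():
    x = torch.empty(1024, 1024, device="cuda")  # reserve some allocator blocks
    del x                                       # tensor is gone, blocks stay cached
    before = torch.cuda.memory_reserved()
    torch.cuda.empty_cache()                    # hand the cached blocks back
    after = torch.cuda.memory_reserved()
    print(f"reserved bytes before/after empty_cache: {before} -> {after}")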
+import torch
+
 from typing import Dict, Optional, TypeVar

 from text_generation_server.models.types import Batch
...
@@ -20,6 +22,8 @@ class Cache:
         batch = self.pop(batch_id)
         if batch is not None:
             del batch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()

     def clear(self):
         keys = list(self.cache.keys())
...
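The hunk above adds exactly this release step to the cache's delete path: pop the batch, drop the last reference, then empty the CUDA cache. A minimal sketch of the same logic with a plain dict-backed cache (class and field names here are illustrative, not the server's actual API):

import torch
from typing import Dict

class SimpleCache:
    # Illustrative stand-in for the server's batch cache (batch ids -> batches).
    def __init__(self) -> None:
        self.cache: Dict[int, object] = {}

    def delete(self, batch_id: int) -> None:
        batch = self.cache.pop(batch_id, None)
        if batch is not None:
            # Drop the last reference so the batch's CUDA tensors can be freed,
            # then release the allocator's unused cached blocks, mirroring the
            # change in the hunk above.
            del batch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()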
@@ -638,6 +638,8 @@ class FlashCausalLMBatch(Batch):
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
             b.block_tables = None
+            del b
+        torch.cuda.empty_cache()

         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,
...
@@ -732,6 +734,7 @@ class FlashCausalLM(Model):
             )
             raise e
         del batch
+        torch.cuda.empty_cache()

     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
         return self.tokenizer.decode(
...
@@ -775,16 +778,21 @@ class FlashCausalLM(Model):
             # Allocate blocks to this batch
             CACHE_MANAGER.allocate(batch)

-        out = self.forward(
-            batch.input_ids,
-            batch.position_ids,
-            batch.cu_seqlen_prefill,
-            batch.block_tables_tensor,
-            batch.slots[batch.slot_indices],
-            batch.input_lengths_tensor,
-            batch.max_seqlen,
-            batch.prefill_head_indices,
-        )
+        try:
+            out = self.forward(
+                batch.input_ids,
+                batch.position_ids,
+                batch.cu_seqlen_prefill,
+                batch.block_tables_tensor,
+                batch.slots[batch.slot_indices],
+                batch.input_lengths_tensor,
+                batch.max_seqlen,
+                batch.prefill_head_indices,
+            )
+        except Exception as e:
+            del batch
+            torch.cuda.empty_cache()
+            raise e

         if prefill:
             next_token_logits = (
...
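The last hunk applies the same cleanup when the forward pass itself fails (typically a CUDA out-of-memory error during prefill): the local reference to the batch is dropped and the allocator cache is emptied before the exception is re-raised, so a failed request does not leave fragmented memory behind. A standalone sketch of that guard, assuming a generic model and batch (the names are illustrative, not the server's API):

import torch

def guarded_forward(model, batch):
    # Illustrative only: run the forward pass, but clean up on failure,
    # mirroring the try/except added in the hunk above.
    try:
        return model.forward(batch)
    except Exception:
        # Drop this function's reference to the batch (as the commit does) and
        # ask the allocator to release its unused cached blocks before
        # propagating the error, so the server can keep serving other batches.
        del batch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise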