Unverified commit c5581b27 authored by Casper, committed by GitHub

Adaptive batch sizing (#181)

parent df909e83
@@ -120,7 +120,7 @@ Fused modules are a large part of the speedup you get from AutoAWQ. The idea is
 - Fused modules are activated when you use `fuse_layers=True`.
 - A custom cache is implemented. It preallocates based on batch size and sequence length.
-- You cannot change the sequence length or batch size after you have created your model.
+- You cannot change the sequence length after you have created your model.
 - Reference: `AutoAWQForCausalLM.from_quantized(max_new_tokens=seq_len, batch_size=batch_size)`
 - The main accelerator in the fused modules comes from FasterTransformer, which is only compatible with Linux.
 - The `past_key_values` from `model.generate()` are only dummy values, so they cannot be used after generation.
...
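For reference, a minimal loading sketch built only from the options mentioned in these notes (`fuse_layers=True`, `max_new_tokens`, `batch_size`); the model path, tokenizer, prompt, and concrete values are placeholders, not part of this commit:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "path/to/awq-quantized-model"  # placeholder path

# Preallocate the fused-module cache for a fixed sequence length and an
# initial kv-cache batch size.
model = AutoAWQForCausalLM.from_quantized(
    quant_path,
    fuse_layers=True,
    max_new_tokens=512,
    batch_size=1,
)
tokenizer = AutoTokenizer.from_pretrained(quant_path)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=64)
# Note: the past_key_values from generate() are dummy values and cannot be reused.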
@@ -123,11 +123,14 @@ class QuantAttentionFused(nn.Module):
     def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwargs):
         bsz, seqlen, _ = hidden_states.shape
+        # Reallocate cache if batch size changes
         if bsz != self.cache_batch_size:
-            raise RuntimeError(
-                f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
-                f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
-            )
+            if bsz > self.cache_batch_size:
+                self.cache.increase_batch_size(bsz)
+                self.cache_batch_size = bsz
+            elif bsz < self.cache_batch_size:
+                self.cache.decrease_batch_size(bsz)
+                self.cache_batch_size = bsz

         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
...
@@ -47,4 +47,13 @@ class WindowedCache:
     def to(self, device):
         self.k = self.k.to(device)
-        self.v = self.v.to(device)
\ No newline at end of file
+        self.v = self.v.to(device)
+
+    def increase_batch_size(self, to_bsz):
+        """Dynamically allocate new kv when batch size changes."""
+        self.v = torch.zeros(to_bsz, *self.v.shape[1:], dtype=self.v.dtype, device=self.v.device)
+        self.k = torch.zeros(to_bsz, *self.k.shape[1:], dtype=self.k.dtype, device=self.k.device)
+
+    def decrease_batch_size(self, to_bsz):
+        """Dynamically remove part of cache if batch size changes."""
+        self.v = self.v[:to_bsz, :, :, :]
+        self.k = self.k[:to_bsz, :, :, :, :]
\ No newline at end of file
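A standalone illustration of the resize semantics above (not the `WindowedCache` class itself), using made-up cache shapes: growing the batch dimension allocates fresh zero tensors, so prior cache contents are discarded, while shrinking slices off the trailing batch entries and keeps the rest:

import torch

# Illustrative shapes only: a 4-D value cache and a 5-D key cache,
# mirroring the layouts sliced in decrease_batch_size above.
v = torch.randn(2, 8, 64, 16)
k = torch.randn(2, 8, 64, 2, 8)

# Grow: allocate zeroed tensors with the new batch dimension (contents reset).
to_bsz = 4
v = torch.zeros(to_bsz, *v.shape[1:], dtype=v.dtype, device=v.device)
k = torch.zeros(to_bsz, *k.shape[1:], dtype=k.dtype, device=k.device)
print(v.shape, k.shape)  # torch.Size([4, 8, 64, 16]) torch.Size([4, 8, 64, 2, 8])

# Shrink: keep only the first to_bsz batch entries (contents preserved).
to_bsz = 1
v, k = v[:to_bsz], k[:to_bsz]
print(v.shape, k.shape)  # torch.Size([1, 8, 64, 16]) torch.Size([1, 8, 64, 2, 8])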