Unverified Commit 3eda6562 authored by Casper, committed by GitHub

Fix performance regression (#148)

parent bf64abd8
@@ -107,7 +107,7 @@ class QuantAttentionFused(nn.Module):
         )
         # cache store that rolls cache
         self.cache = WindowedCache(
-            self.attention_shapes["cache_v"], self.attention_shapes["cache_k"], dev
+            self.attention_shapes["cache_v"], self.attention_shapes["cache_k"], self.max_seq_len, dev
         )
         if use_alibi:
@@ -128,9 +128,14 @@ class QuantAttentionFused(nn.Module):
                 f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
             )

-        if self.start_pos > self.max_seq_len or self.start_pos + seqlen > self.max_seq_len:
-            excess_length = self.start_pos + seqlen - self.max_seq_len
-            self.start_pos = self.cache.roll_kv(excess_length, self.start_pos)
+        will_cache_be_exceeded = self.start_pos + seqlen > self.max_seq_len
+
+        # Reset and avoid retaining state when processing context
+        if will_cache_be_exceeded and seqlen > 1:
+            self.start_pos = self.cache.roll_kv_n_steps(self.start_pos, n=self.start_pos)
+        # Slowly roll out old tokens without performance hit if exceeded during decoding
+        elif will_cache_be_exceeded and seqlen == 1:
+            self.start_pos = self.cache.roll_kv_n_steps(self.start_pos, n=100)

         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
@@ -158,6 +163,7 @@ class QuantAttentionFused(nn.Module):
         self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)

+        # Only necessary to retrieve from cache when we are not processing context
         if seqlen == 1:
             xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)
......
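The rewritten branch above handles cache overflow in two ways: when a new prompt (seqlen > 1) would overflow the window, the whole cache is rolled out so the state is reset, and when a single decoded token (seqlen == 1) would overflow it, the cache is rolled 100 positions at once rather than one position per generated token. A minimal standalone sketch of that policy, using a toy cache class and made-up sizes rather than the library's real API:

```python
import torch

class ToyCache:
    """Toy stand-in for the windowed KV cache; shapes and sizes are illustrative."""
    def __init__(self, max_seq_len=8, dim=4):
        self.max_seq_len = max_seq_len
        self.k = torch.zeros(max_seq_len, dim)

    def roll_n_steps(self, start_pos, n):
        n = min(n, self.max_seq_len)
        self.k = torch.roll(self.k, shifts=-n, dims=0)  # shift entries left by n slots
        self.k[-n:] = 0                                  # clear the freed tail
        return start_pos - n

def maybe_roll(cache, start_pos, seqlen, decode_chunk=4):
    """Mirror of the overflow policy above: reset on prefill, roll a chunk on decode."""
    will_exceed = start_pos + seqlen > cache.max_seq_len
    if will_exceed and seqlen > 1:
        # A new prompt would overflow: roll out everything currently cached.
        start_pos = cache.roll_n_steps(start_pos, n=start_pos)
    elif will_exceed and seqlen == 1:
        # Decoding past the window: roll a whole chunk once, not one slot per token.
        start_pos = cache.roll_n_steps(start_pos, n=decode_chunk)
    return start_pos

print(maybe_roll(ToyCache(), start_pos=6, seqlen=4))  # prefill overflow -> 0
print(maybe_roll(ToyCache(), start_pos=8, seqlen=1))  # decode overflow  -> 4
```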
 import torch

 class WindowedCache:
-    def __init__(self, cache_v_shape, cache_k_shape, device):
+    def __init__(self, cache_v_shape, cache_k_shape, max_seq_len, device):
         """
         The window size is the same as the max_new_tokens. The window will
         automatically roll once max_new_tokens is exceeded.
@@ -10,8 +10,12 @@ class WindowedCache:
         self.v = torch.zeros(cache_v_shape).to(device).half()
         # [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
         self.k = torch.zeros(cache_k_shape).to(device).half()
+        self.max_seq_len = max_seq_len

     def get_kv(self, batch_size, start_pos, seqlen, head_dim):
+        """
+        Gets the key-value store in correct shapes.
+        """
         xv = self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous()
         xk = self.k[:batch_size, :, :, : start_pos + seqlen, :].transpose(2, 3).contiguous()
         xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()
@@ -19,19 +23,26 @@ class WindowedCache:
         return xv, xk

     def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen):
+        """
+        Updates the values in the key-value store.
+        """
         self.v[:batch_size, :, start_pos : start_pos + seqlen, :] = values_store
         self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store

-    def roll_kv(self, roll_len, start_pos):
-        # Roll only the necessary part of the cache to the left
-        self.v[:, :, :-roll_len, :] = self.v[:, :, roll_len:, :]
-        self.k[:, :, :, :-roll_len, :] = self.k[:, :, :, roll_len:, :]
+    def roll_kv_n_steps(self, start_pos, n=100):
+        """
+        Roll cache n to the left.
+        """
+        n = min(n, self.max_seq_len)
+        # Roll cache to the left
+        self.v = torch.roll(self.v, shifts=-n, dims=2)
+        self.k = torch.roll(self.k, shifts=-n, dims=3)

         # Zero out the new part
-        self.v[:, :, -roll_len:, :] = 0
-        self.k[:, :, :, -roll_len:, :] = 0
+        self.v[:, :, -n:, :] = 0
+        self.k[:, :, :, -n:, :] = 0

-        return start_pos - roll_len
+        return start_pos - n

     def to(self, device):
         self.k = self.k.to(device)
......
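For reference, a tiny standalone example of the torch.roll pattern used by roll_kv_n_steps: shifting left by n along the sequence dimension and zeroing the freed tail keeps the most recent entries at the start of the window (the tensor shape below is made up for illustration):

```python
import torch

# Pretend cache of shape [batch, heads, seq, head_dim] holding positions 0..7.
cache = torch.arange(8, dtype=torch.float32).view(1, 1, 8, 1).expand(1, 1, 8, 2).clone()

n = 3
cache = torch.roll(cache, shifts=-n, dims=2)  # positions 3..7 move to the front
cache[:, :, -n:, :] = 0                       # freed slots at the end are cleared

print(cache[0, 0, :, 0].tolist())  # [3.0, 4.0, 5.0, 6.0, 7.0, 0.0, 0.0, 0.0]
```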
@@ -85,7 +85,7 @@ def run_round(model_path, quant_file, n_generate, input_ids, batch_size, safeten
         "Prefill tokens/s": prefill_tokens_per_second,
         "Decode tokens/s": decode_tokens_per_second,
         "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)"
-    }, model.quant_config["version"]
+    }, model.quant_config.version

 def main(args):
     rounds = [
......
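The benchmark now reads the quantization version with attribute access, which implies quant_config is a config object rather than a plain dict. A minimal sketch of the difference, using a hypothetical stand-in class (the class and field names are assumptions, not the library's actual config):

```python
from dataclasses import dataclass

@dataclass
class QuantConfig:
    """Hypothetical stand-in for the quantization config object."""
    version: str = "GEMM"
    w_bit: int = 4
    q_group_size: int = 128

quant_config = QuantConfig()
print(quant_config.version)       # attribute access, as used in the benchmark above
# quant_config["version"]         # the old dict-style access would now raise TypeError
```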
@@ -13,7 +13,7 @@ PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1"
 if not PYPI_BUILD:
     try:
         CUDA_VERSION = "".join(os.environ.get("CUDA_VERSION", torch.version.cuda).split("."))[:3]
-        AUTOAWQ_VERSION += f"cu+{CUDA_VERSION}"
+        AUTOAWQ_VERSION += f"+cu{CUDA_VERSION}"
     except Exception as ex:
         raise RuntimeError("Your system must have an Nvidia GPU for installing AutoAWQ")
......
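As a quick check of the version-suffix logic above, the same transformation applied to example inputs (the base version number is only an example):

```python
# Standalone illustration of the version-suffix construction in setup.py above.
AUTOAWQ_VERSION = "0.1.6"   # example base version, not necessarily the real one
cuda_version = "11.8.0"     # e.g. torch.version.cuda or the CUDA_VERSION env var

CUDA_VERSION = "".join(cuda_version.split("."))[:3]  # "11.8.0" -> "1180" -> "118"
AUTOAWQ_VERSION += f"+cu{CUDA_VERSION}"              # valid PEP 440 local version suffix

print(AUTOAWQ_VERSION)  # 0.1.6+cu118
```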