Unverified commit dc188132 authored by Lianmin Zheng, committed by GitHub

Fix perf regression on small batch sizes (#3008)

parent 10bfce71
@@ -47,8 +47,8 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
-        self.k_scale = 1.0
-        self.v_scale = 1.0
+        self.k_scale = None
+        self.v_scale = None
 
     def forward(
         self,
@@ -27,7 +27,7 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import psutil
@@ -270,13 +270,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
+        k_scale: Optional[float] = None,
+        v_scale: Optional[float] = None,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
-            cache_k = (cache_k / k_scale).to(self.dtype)
-            cache_v = (cache_v / v_scale).to(self.dtype)
+            if k_scale is not None:
+                cache_k.div_(k_scale)
+            if v_scale is not None:
+                cache_v.div_(v_scale)
+            cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
             self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
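
The change above replaces the neutral default of 1.0 for the KV-cache scales with None and only performs the division when a scale is actually set, so the common unquantized path skips that work. Below is a minimal, self-contained sketch of the same pattern; ToyKVPool and set_kv are hypothetical names used only for illustration, not the real SGLang classes.

# Minimal sketch of the pattern applied in this commit (hypothetical names):
# treat "no scaling" as None instead of 1.0, so the unquantized path only
# pays for the dtype cast and never for elementwise divisions.
from typing import Optional

import torch


class ToyKVPool:
    def __init__(self, dtype: torch.dtype = torch.float16):
        self.dtype = dtype
        self.k_buffer = torch.zeros(16, 8, dtype=dtype)
        self.v_buffer = torch.zeros(16, 8, dtype=dtype)

    def set_kv(
        self,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
    ):
        if cache_k.dtype != self.dtype:
            # Divide only when a scale is configured (e.g. for FP8 KV-cache
            # quantization); None means "store as-is after the cast".
            if k_scale is not None:
                cache_k = cache_k / k_scale
            if v_scale is not None:
                cache_v = cache_v / v_scale
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)
        self.k_buffer[loc] = cache_k
        self.v_buffer[loc] = cache_v


pool = ToyKVPool()
loc = torch.tensor([0, 1])
k = torch.randn(2, 8)
v = torch.randn(2, 8)
pool.set_kv(loc, k, v)                              # fast path: no division
pool.set_kv(loc, k, v, k_scale=2.0, v_scale=2.0)    # quantized path: scaled

Using None as the sentinel also lets callers distinguish "no scaling configured" with an identity check rather than comparing floats against 1.0; per the commit title, the saved per-call work matters most at small batch sizes, where such constant overhead is proportionally largest.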