"vscode:/vscode.git/clone" did not exist on "1d50dfa018f15678b2a46afb663c379079b75f26"
Unverified commit dc188132, authored by Lianmin Zheng, committed by GitHub
Browse files

Fix perf regression on small batch sizes (#3008)

parent 10bfce71
...@@ -47,8 +47,8 @@ class RadixAttention(nn.Module): ...@@ -47,8 +47,8 @@ class RadixAttention(nn.Module):
self.logit_cap = logit_cap self.logit_cap = logit_cap
self.sliding_window_size = sliding_window_size or -1 self.sliding_window_size = sliding_window_size or -1
self.is_cross_attention = is_cross_attention self.is_cross_attention = is_cross_attention
self.k_scale = 1.0 self.k_scale = None
self.v_scale = 1.0 self.v_scale = None
def forward( def forward(
self, self,
......
...@@ -27,7 +27,7 @@ import logging ...@@ -27,7 +27,7 @@ import logging
import threading import threading
from enum import IntEnum from enum import IntEnum
from functools import wraps from functools import wraps
from typing import List, Tuple, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import psutil import psutil
...@@ -270,13 +270,17 @@ class MHATokenToKVPool(BaseTokenToKVPool): ...@@ -270,13 +270,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
loc: torch.Tensor, loc: torch.Tensor,
cache_k: torch.Tensor, cache_k: torch.Tensor,
cache_v: torch.Tensor, cache_v: torch.Tensor,
k_scale: float = 1.0, k_scale: Optional[float] = None,
v_scale: float = 1.0, v_scale: Optional[float] = None,
): ):
layer_id = layer.layer_id layer_id = layer.layer_id
if cache_k.dtype != self.dtype: if cache_k.dtype != self.dtype:
cache_k = (cache_k / k_scale).to(self.dtype) if k_scale is not None:
cache_v = (cache_v / v_scale).to(self.dtype) cache_k.div_(k_scale)
if v_scale is not None:
cache_v.div_(v_scale)
cache_k = cache_k.to(self.dtype)
cache_v = cache_v.to(self.dtype)
if self.store_dtype != self.dtype: if self.store_dtype != self.dtype:
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype) self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype) self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment