Commit 18459e7a authored by zhangshao's avatar zhangshao
Browse files

Merge branch '015-fp8-kvscale' into 'v0.15.1-dev'

优化015 fp8 kvscale

See merge request dcutoolkit/deeplearing/vllm!473
parents cca00f5c 2bbe4385
...@@ -210,6 +210,10 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -210,6 +210,10 @@ class Attention(nn.Module, AttentionLayerBase):
) )
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales self.calculate_kv_scales = calculate_kv_scales
if self.kv_cache_dtype in {"fp8", "fp8_e4m3","fp8_e5m2"} :
self.check_fp8_overflow = True
else:
self.check_fp8_overflow = False
if num_kv_heads is None: if num_kv_heads is None:
num_kv_heads = num_heads num_kv_heads = num_heads
assert num_heads % num_kv_heads == 0, ( assert num_heads % num_kv_heads == 0, (
...@@ -359,8 +363,9 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -359,8 +363,9 @@ class Attention(nn.Module, AttentionLayerBase):
context using context using
`vllm.forward_context.get_forward_context().attn_metadata`. `vllm.forward_context.get_forward_context().attn_metadata`.
""" """
if self.calculate_kv_scales: if self.calculate_kv_scales or self.check_fp8_overflow:
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name) torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
self.check_fp8_overflow = False
output_dtype = query.dtype output_dtype = query.dtype
if self.query_quant is not None: if self.query_quant is not None:
# quantizing with a simple torch operation enables # quantizing with a simple torch operation enables
...@@ -437,9 +442,16 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -437,9 +442,16 @@ class Attention(nn.Module, AttentionLayerBase):
) )
def calc_kv_scales(self, query, key, value): def calc_kv_scales(self, query, key, value):
self._q_scale.copy_(torch.abs(query).max() / self.q_range) bias=0.0 # add bias to avoid q values are too small(or zeros) and scales are not correct
self._k_scale.copy_(torch.abs(key).max() / self.k_range) if torch.abs(query).max().item() < 0.01:
self._v_scale.copy_(torch.abs(value).max() / self.v_range) if self.kv_cache_dtype in {"fp8_e5m2"}:
bias = 0.1
else :
bias = 1.0
self._q_scale.copy_(torch.abs(query).max() / self.q_range+bias)
self._k_scale.copy_(torch.abs(key).max() / self.k_range+bias)
self._v_scale.copy_(torch.abs(value).max() / self.v_range+bias)
self._q_scale_float = self._q_scale.item() self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item() self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item() self._v_scale_float = self._v_scale.item()
...@@ -833,6 +845,11 @@ def maybe_calc_kv_scales( ...@@ -833,6 +845,11 @@ def maybe_calc_kv_scales(
# Only calculate if the layer's calculate_kv_scales flag is True # Only calculate if the layer's calculate_kv_scales flag is True
# This flag gets set to False after the first forward pass # This flag gets set to False after the first forward pass
if self.check_fp8_overflow :
if self.kv_cache_dtype in {"fp8", "fp8_e4m3"} and torch.abs(query).max().item()>200 : #check fp8 overflow
self.calculate_kv_scales = True
if self.kv_cache_dtype in {"fp8_e5m2"} and torch.abs(query).max().item()<0.01 : #check fp8 too small
self.calculate_kv_scales = True
if not self.calculate_kv_scales: if not self.calculate_kv_scales:
return return
......
...@@ -132,9 +132,9 @@ if TYPE_CHECKING: ...@@ -132,9 +132,9 @@ if TYPE_CHECKING:
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False VLLM_DISABLE_COMPILE_CACHE: bool = False
Q_SCALE_CONSTANT: int = 200 Q_SCALE_CONSTANT: int = 10
K_SCALE_CONSTANT: int = 200 K_SCALE_CONSTANT: int = 10
V_SCALE_CONSTANT: int = 100 V_SCALE_CONSTANT: int = 10
VLLM_SERVER_DEV_MODE: bool = False VLLM_SERVER_DEV_MODE: bool = False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
VLLM_MLA_DISABLE: bool = False VLLM_MLA_DISABLE: bool = False
...@@ -1115,11 +1115,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1115,11 +1115,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
os.environ.get("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", None) os.environ.get("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", None)
), ),
# Divisor for dynamic query scale factor calculation for FP8 KV Cache # Divisor for dynamic query scale factor calculation for FP8 KV Cache
"Q_SCALE_CONSTANT": lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")), "Q_SCALE_CONSTANT": lambda: int(os.getenv("Q_SCALE_CONSTANT", "10")),
# Divisor for dynamic key scale factor calculation for FP8 KV Cache # Divisor for dynamic key scale factor calculation for FP8 KV Cache
"K_SCALE_CONSTANT": lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), "K_SCALE_CONSTANT": lambda: int(os.getenv("K_SCALE_CONSTANT", "10")),
# Divisor for dynamic value scale factor calculation for FP8 KV Cache # Divisor for dynamic value scale factor calculation for FP8 KV Cache
"V_SCALE_CONSTANT": lambda: int(os.getenv("V_SCALE_CONSTANT", "100")), "V_SCALE_CONSTANT": lambda: int(os.getenv("V_SCALE_CONSTANT", "10")),
# If set, enable multiprocessing in LLM for the V1 code path. # If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool( "VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool(
int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")) int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment