Commit 2fc9d1a2 authored by zhangshao
Browse files

Merge branch 'fp8e5m2_bw' into 'v0.11.0-dev'

解决TORCHDYNAMO跟fp8 scale冲突的问题

See merge request dcutoolkit/deeplearing/vllm!458
parents 2830c329 4a0dfa15
......@@ -290,7 +290,8 @@ class Attention(nn.Module, AttentionLayerBase):
# if attn_metadata.enable_kv_scales_calculation:
# self.calc_kv_scales(query, key, value)
torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
self.layer_name)
self.layer_name)
self.check_fp8_overflow = False
output_dtype = query.dtype
if self.query_quant is not None:
......@@ -358,12 +359,6 @@ class Attention(nn.Module, AttentionLayerBase):
query, key, value, self.layer_name)
def calc_kv_scales(self, query, key, value):
self.check_fp8_overflow = False
if self.calculate_kv_scales == False :
if self.kv_cache_dtype in {"fp8", "fp8_e4m3"} and torch.abs(query).max().item()<=200 : #check fp8 overflow
return
if self.kv_cache_dtype in {"fp8_e5m2"} and torch.abs(query).max().item()>=0.01 : #check fp8 too small
return
bias=0.0 # add bias to avoid q values are too small(or zeros) and scales are not correct
if torch.abs(query).max().item() < 0.01:
if self.kv_cache_dtype in {"fp8_e5m2"}:
......@@ -601,7 +596,13 @@ def maybe_calc_kv_scales(
# Only calculate if the layer's calculate_kv_scales flag is True
# This flag gets set to False after the first forward pass
if self.check_fp8_overflow :
if self.kv_cache_dtype in {"fp8", "fp8_e4m3"} and torch.abs(query).max().item()>200 : #check fp8 overflow
self.calculate_kv_scales = True
if self.kv_cache_dtype in {"fp8_e5m2"} and torch.abs(query).max().item()<0.01 : #check fp8 too small
self.calculate_kv_scales = True
if not self.calculate_kv_scales:
return
self.calc_kv_scales(query, key, value)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment