"vscode:/vscode.git/clone" did not exist on "b2c8ce57c68db0764a49d66f048b8a7a5cef9d13"
Commit 4a0dfa15 authored by zhangshao's avatar zhangshao
Browse files

解决TORCHDYNAMO跟fp8 scale冲突的问题

parent bf93e83b
......@@ -290,7 +290,8 @@ class Attention(nn.Module, AttentionLayerBase):
# if attn_metadata.enable_kv_scales_calculation:
# self.calc_kv_scales(query, key, value)
torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
self.layer_name)
self.layer_name)
self.check_fp8_overflow = False
output_dtype = query.dtype
if self.query_quant is not None:
......@@ -358,12 +359,6 @@ class Attention(nn.Module, AttentionLayerBase):
query, key, value, self.layer_name)
def calc_kv_scales(self, query, key, value):
self.check_fp8_overflow = False
if self.calculate_kv_scales == False :
if self.kv_cache_dtype in {"fp8", "fp8_e4m3"} and torch.abs(query).max().item()<=200 : #check fp8 overflow
return
if self.kv_cache_dtype in {"fp8_e5m2"} and torch.abs(query).max().item()>=0.01 : #check fp8 too small
return
bias=0.0 # add bias to avoid q values are too small(or zeros) and scales are not correct
if torch.abs(query).max().item() < 0.01:
if self.kv_cache_dtype in {"fp8_e5m2"}:
......@@ -601,7 +596,13 @@ def maybe_calc_kv_scales(
# Only calculate if the layer's calculate_kv_scales flag is True
# This flag gets set to False after the first forward pass
if self.check_fp8_overflow :
if self.kv_cache_dtype in {"fp8", "fp8_e4m3"} and torch.abs(query).max().item()>200 : #check fp8 overflow
self.calculate_kv_scales = True
if self.kv_cache_dtype in {"fp8_e5m2"} and torch.abs(query).max().item()<0.01 : #check fp8 too small
self.calculate_kv_scales = True
if not self.calculate_kv_scales:
return
self.calc_kv_scales(query, key, value)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment