Commit c3b8a0ae authored by zhuwenwen
Browse files

remove redundant maybe_calc_kv_scales

parent ad141d07
...@@ -432,43 +432,6 @@ def maybe_calc_kv_scales_fake( query: torch.Tensor, ...@@ -432,43 +432,6 @@ def maybe_calc_kv_scales_fake( query: torch.Tensor,
return return
# Register the custom op named "maybe_calc_kv_scales": `maybe_calc_kv_scales`
# is passed as the op function and `maybe_calc_kv_scales_fake` as its fake
# implementation. The op is declared with an empty `mutates_args` list,
# dispatched under the current platform's dispatch key, and tagged with
# `tag_cudagraph_unsafe`.
# NOTE(review): presumably the fake implementation is what runs during
# tracing/compilation instead of the real function; confirm against the
# definition of `direct_register_custom_op`.
direct_register_custom_op(
op_name="maybe_calc_kv_scales",
op_func=maybe_calc_kv_scales,
mutates_args=[],
fake_impl=maybe_calc_kv_scales_fake,
dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe,)
def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Run KV-scale calculation on the layer registered as ``layer_name``.

    The layer object is looked up in the current forward context's
    ``no_compile_layers`` mapping and its ``calc_kv_scales`` method is
    called with the given tensors.

    Args:
        query: Query tensor forwarded to ``calc_kv_scales``.
        key: Key tensor forwarded to ``calc_kv_scales``.
        value: Value tensor forwarded to ``calc_kv_scales``.
        layer_name: Key used to index the forward context's per-layer
            mappings.
    """
    ctx: ForwardContext = get_forward_context()

    # When the context stores metadata as a dict, narrow it to this layer's
    # entry.
    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        metadata = metadata[layer_name]

    # NOTE: the early return on a missing metadata object or a falsy
    # `enable_kv_scales_calculation` attribute is disabled here, so
    # `metadata` is not consulted and the calculation below always runs.

    layer = ctx.no_compile_layers[layer_name]
    layer.calc_kv_scales(query, key, value)
def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """No-op stand-in with the same signature as ``maybe_calc_kv_scales``.

    All arguments are accepted and ignored; nothing is computed or mutated.
    """
    return None
direct_register_custom_op( direct_register_custom_op(
op_name="maybe_calc_kv_scales", op_name="maybe_calc_kv_scales",
op_func=maybe_calc_kv_scales, op_func=maybe_calc_kv_scales,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment