Commit c462f3a0 authored by wanghl6's avatar wanghl6
Browse files

[FIX]减少mqa_logits显存占用

parent 3c0e74be
...@@ -51,13 +51,7 @@ def sparse_attn_indexer( ...@@ -51,13 +51,7 @@ def sparse_attn_indexer(
# careful! this will be None in dummy run # careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata attn_metadata = get_forward_context().attn_metadata
fp8_dtype = current_platform.fp8_dtype() fp8_dtype = current_platform.fp8_dtype()
if q_fp8.dtype == fp8_dtype: MAX_ELEMENTS = 16384 * 16384
MAX_ELEMENTS = 65536 * 65536
elif q_fp8.dtype in (torch.bfloat16, torch.float16):
MAX_ELEMENTS = 16384 * 32768
else:
MAX_ELEMENTS = 16384 * 32768
device = q_fp8.device device = q_fp8.device
if device not in _GLOBAL_LOGITS_BUFFERS or _GLOBAL_LOGITS_BUFFERS[device].numel() < MAX_ELEMENTS: if device not in _GLOBAL_LOGITS_BUFFERS or _GLOBAL_LOGITS_BUFFERS[device].numel() < MAX_ELEMENTS:
_GLOBAL_LOGITS_BUFFERS[device] = torch.empty( _GLOBAL_LOGITS_BUFFERS[device] = torch.empty(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment