Unverified Commit 4a6b72c2 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[BugFix] Fix triton compile error in `kernel_unified_attention_2/3d` caused by...


[BugFix] Fix triton compile error in `kernel_unified_attention_2/3d` caused by attention sinks (#22368)
Signed-off-by: default avatarLucasWilkinson <lwilkinson@neuralmagic.com>
parent b4b9813b
...@@ -75,6 +75,7 @@ def kernel_unified_attention_2d( ...@@ -75,6 +75,7 @@ def kernel_unified_attention_2d(
USE_ALIBI_SLOPES: tl.constexpr, # bool USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_QQ_BIAS: tl.constexpr, # bool USE_QQ_BIAS: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool USE_SOFTCAP: tl.constexpr, # bool
USE_SINKS: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int stride_k_cache_1: tl.int64, # int
...@@ -132,7 +133,7 @@ def kernel_unified_attention_2d( ...@@ -132,7 +133,7 @@ def kernel_unified_attention_2d(
block_table_offset = seq_idx * block_table_stride block_table_offset = seq_idx * block_table_stride
if sink_ptr is None: if not USE_SINKS:
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
else: else:
M = tl.load( M = tl.load(
...@@ -322,6 +323,7 @@ def kernel_unified_attention_3d( ...@@ -322,6 +323,7 @@ def kernel_unified_attention_3d(
USE_ALIBI_SLOPES: tl.constexpr, # bool USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_QQ_BIAS: tl.constexpr, # bool USE_QQ_BIAS: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool USE_SOFTCAP: tl.constexpr, # bool
USE_SINKS: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int stride_k_cache_1: tl.int64, # int
...@@ -393,14 +395,17 @@ def kernel_unified_attention_3d( ...@@ -393,14 +395,17 @@ def kernel_unified_attention_3d(
block_table_offset = seq_idx * block_table_stride block_table_offset = seq_idx * block_table_stride
if sink_ptr is None or segm_idx != 0: if USE_SINKS:
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) if segm_idx == 0:
else:
M = tl.load( M = tl.load(
sink_ptr + query_offset_1, sink_ptr + query_offset_1,
mask=query_mask_1, mask=query_mask_1,
other=float("-inf"), other=float("-inf"),
).to(dtype=tl.float32) ).to(dtype=tl.float32)
else:
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
else:
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
...@@ -716,6 +721,7 @@ def unified_attention( ...@@ -716,6 +721,7 @@ def unified_attention(
USE_ALIBI_SLOPES=use_alibi_slopes, USE_ALIBI_SLOPES=use_alibi_slopes,
USE_QQ_BIAS=use_qq_bias, USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0), USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
SLIDING_WINDOW=(1 + window_size[0]), SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0), stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1), stride_k_cache_1=k.stride(1),
...@@ -787,6 +793,7 @@ def unified_attention( ...@@ -787,6 +793,7 @@ def unified_attention(
USE_ALIBI_SLOPES=use_alibi_slopes, USE_ALIBI_SLOPES=use_alibi_slopes,
USE_QQ_BIAS=use_qq_bias, USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0), USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
SLIDING_WINDOW=(1 + window_size[0]), SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0), stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1), stride_k_cache_1=k.stride(1),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment