Commit 72501097 authored by xuxzh1

update rocm.py

parent d4bccff3
@@ -80,7 +80,6 @@ def paged_attention(
        _PARTITION_SIZE = _PARTITION_SIZE_V1V2
    else:
        _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
-       _PARTITION_SIZE = 512
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
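As an aside on the hunk above: the deleted line hard-coded `_PARTITION_SIZE` to 512, which silently overrode whichever value the preceding if/else had just selected (`_PARTITION_SIZE_V1V2` or `_PARTITION_SIZE_CUSTOM`). The `max_num_partitions` line that follows is a plain ceiling division. A minimal, self-contained sketch of that idiom (illustrative only; the helper name and example values are not from rocm.py):

```python
# Illustrative sketch only: the ceiling-division idiom used above to compute
# how many fixed-size partitions are needed to cover max_s tokens.
def num_partitions(max_s: int, partition_size: int) -> int:
    # Same result as math.ceil(max_s / partition_size), without going through floats.
    return (max_s + partition_size - 1) // partition_size

# Made-up example values: 1000 tokens split into 256-token partitions -> 4 partitions.
assert num_partitions(1000, 256) == 4
assert num_partitions(1024, 512) == 2
```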
@@ -234,7 +233,7 @@ def attention(
            softcap = 0.0
        # We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
-       return flash_attn_2_cuda.varlen_fwd(
+       return flash_attn_2_cuda.varlen_fwd(
            query,
            key,
            value,
@@ -257,7 +256,7 @@ def attention(
            False,
            None,
        )[0]
elif ENGINE == "triton":
    from .flash_attn_triton import triton_attention
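One detail worth calling out from the second and third hunks: the ck path returns `flash_attn_2_cuda.varlen_fwd(...)[0]`, i.e. the fused kernel returns a tuple and only its first element (the attention output) is propagated, while the triton path imports `triton_attention` instead. A minimal sketch of that wrapper pattern, using a made-up stand-in kernel rather than the real `flash_attn_2_cuda.varlen_fwd` signature:

```python
import torch

# Illustrative sketch only: `fake_varlen_fwd` is a made-up stand-in, not the
# real flash_attn_2_cuda.varlen_fwd. It mimics a fused kernel that returns a
# tuple, of which only the first element (the attention output) is wanted.
def fake_varlen_fwd(query, key, value):
    out = torch.zeros_like(query)      # stand-in for the attention output
    aux = torch.zeros(query.shape[0])  # stand-in for auxiliary kernel outputs
    return out, aux

def attention(query, key, value):
    # Mirrors the `...)[0]` in the diff: keep the output, drop the extras.
    return fake_varlen_fwd(query, key, value)[0]

q = k = v = torch.randn(4, 8)
assert attention(q, k, v).shape == q.shape
```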