Commit 72501097 authored by xuxzh1

update rocm.py

parent d4bccff3
@@ -80,7 +80,6 @@ def paged_attention(
         _PARTITION_SIZE = _PARTITION_SIZE_V1V2
     else:
         _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
-    _PARTITION_SIZE = 512
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
     input_lengths = seqlen.input_lengths + seqlen.cache_lengths
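Note on this hunk: the deleted line pinned `_PARTITION_SIZE` to a hard-coded 512 immediately after the branch above had selected either `_PARTITION_SIZE_V1V2` or `_PARTITION_SIZE_CUSTOM`, so the selected value never took effect. Below is a minimal sketch of the partition-count arithmetic once the override is gone; the flag name `use_custom` and all numeric values are illustrative assumptions, not taken from this file.

# Sketch of how max_num_partitions is derived; values below are assumed for illustration.
_PARTITION_SIZE_V1V2 = 512    # assumed stand-in value
_PARTITION_SIZE_CUSTOM = 256  # assumed stand-in value
use_custom = True             # assumed: result of the kernel check above this hunk
max_s = 1000                  # assumed: longest sequence length in the batch

if not use_custom:
    _PARTITION_SIZE = _PARTITION_SIZE_V1V2
else:
    _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM

# Ceiling division: number of KV-cache partitions each sequence is split into.
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
print(max_num_partitions)  # 4 -> a 1000-token sequence spans four 256-token partitions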
@@ -234,7 +233,7 @@ def attention(
             softcap = 0.0
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             query,
             key,
             value,
@@ -257,7 +256,7 @@ def attention(
             False,
             None,
         )[0]
     elif ENGINE == "triton":
         from .flash_attn_triton import triton_attention
...