"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "fbe29c62984c33c6cf9cf7ad120a992fe6d20854"
Unverified Commit 215ed3ad authored by drbh, committed by GitHub

fix: attempt forward on flash attn2 to check hardware support (#2335)

* fix: attempt forward on flash attn2 to check hardware support

* fix: warn window_size_left when using flash attn 1

* fix: prefer version check over test op and avoid window_size_left if not flash attn2

* fix: improve conditional and error message

* fix: update sliding window conditional

* fix: simplify changes and revert model changes

* fix: avoid changing conditional

* fix: typo tweak
parent 47447ef0
@@ -172,6 +172,10 @@ def paged_attention(
try:
    is_ampere_or_newer = major >= 8 and minor >= 0
    if not is_ampere_or_newer:
        raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
    import flash_attn_2_cuda

    V2 = True
...
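For context, here is a minimal, self-contained sketch of the guard this hunk adds. It assumes `major` and `minor` come from `torch.cuda.get_device_capability()` (not shown in the hunk), which is the usual way to obtain compute capability around code like this; the surrounding structure is an assumption, not the exact upstream file.

```python
# Hypothetical standalone sketch of the hardware guard added above.
# Assumption: major/minor come from torch.cuda.get_device_capability().
import torch

V2 = False
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    try:
        # FlashAttention v2 requires compute capability 8.0 (Ampere) or newer;
        # raising here makes the failure explicit instead of attempting a
        # forward pass just to probe hardware support.
        is_ampere_or_newer = major >= 8 and minor >= 0
        if not is_ampere_or_newer:
            raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
        import flash_attn_2_cuda  # noqa: F401

        V2 = True
    except ImportError:
        # Callers fall back to FlashAttention v1 (or none) when V2 stays False.
        V2 = False
```

This matches the commit message: a plain version/capability check is preferred over running a test op to detect support.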
@@ -484,6 +484,9 @@ def get_model(
    )
    sliding_window = config_dict.get("sliding_window", -1)
    if max_input_tokens is not None and max_input_tokens <= sliding_window:
        sliding_window = -1
    if (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
...
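A hedged sketch of the sliding-window handling around this hunk follows. The `SUPPORTS_WINDOWING` flag is visible in the diff, but the function boundary and the error message are assumptions based on context, not the exact upstream code.

```python
# Hypothetical sketch of the sliding-window check in get_model; the function
# name, signature, and error message are illustrative assumptions.
from typing import Optional

SUPPORTS_WINDOWING = False  # e.g. False when only FlashAttention v1 is available


def resolve_sliding_window(config_dict: dict, max_input_tokens: Optional[int]) -> int:
    sliding_window = config_dict.get("sliding_window", -1)

    # New in this commit: if no prompt can ever exceed the window, windowing
    # never takes effect, so disable it rather than rejecting backends that
    # cannot window.
    if max_input_tokens is not None and max_input_tokens <= sliding_window:
        sliding_window = -1

    if (sliding_window is not None and sliding_window != -1) and not SUPPORTS_WINDOWING:
        raise ValueError(
            "Sliding window attention is not supported on this backend; "
            "reduce max_input_tokens to at most the sliding window size."
        )
    return sliding_window
```

The design point is that models with a sliding window larger than the configured maximum input length can still run on backends without windowing support, since the window can never be exceeded in practice.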