"vscode:/vscode.git/clone" did not exist on "185347e411247ae9c6d8ace910dc3f876958bee1"
Unverified Commit 215ed3ad authored by drbh's avatar drbh Committed by GitHub
Browse files

fix: attempt forward on flash attn2 to check hardware support (#2335)

* fix: attempt forward on flash attn2 to check hardware support

* fix: warn window_size_left when using flash attn 1

* fix: prefer version check over test op and avoid window_size_left if not flash attn2

* fix: improve conditional and error message

* fix: update sliding window conditional

* fix: simplify changes and revert model changes

* fix: avoid changing conditional

* fix: typo tweak
parent 47447ef0
......@@ -172,6 +172,10 @@ def paged_attention(
try:
is_ampere_or_newer = major >= 8 and minor >= 0
if not is_ampere_or_newer:
raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
import flash_attn_2_cuda
V2 = True
......
......@@ -484,6 +484,9 @@ def get_model(
)
sliding_window = config_dict.get("sliding_window", -1)
if max_input_tokens is not None and max_input_tokens <= sliding_window:
sliding_window = -1
if (
(sliding_window is not None and sliding_window != -1)
and not SUPPORTS_WINDOWING
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment