flashpy_xformers-0.0.22.post7.rocm.patch
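
This patch adapts the flash-attention integration in xformers 0.0.22.post7 to a
ROCm environment. It bypasses the bundled _C_flashattention extension and the
flash-attn >= 2.3 version check, importing flash_attn_cuda directly from the
separately installed flash_attn package; it disables the torch.library
registration of the "xformers_flash" flash_fwd/flash_bwd ops; it comments out
the window_size_left/window_size_right (sliding-window) arguments, which the
older ROCm flash-attn kernels do not accept; and it points OPERATOR at the
Python _flash_fwd wrapper instead of the dispatcher-registered op. It can be
applied with the standard patch utility against the installed
.../xformers/ops/fmha/flash.py shown in the header below.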
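A quick sanity check that the patched module imports and runs -- a minimal
sketch, assuming a ROCm build of PyTorch (which exposes GPUs under the "cuda"
device string) with the flash_attn package installed; note that
memory_efficient_attention may still dispatch to a non-flash backend depending
on the inputs:

    # Hypothetical post-patch smoke test; not part of the patch itself.
    import torch
    import xformers.ops as xops
    from xformers.ops.fmha import flash

    # After the patch, FLASH_VERSION is flash_attn.__version__ rather than
    # the version baked into xformers' own build metadata.
    print(flash.FLASH_VERSION)

    # xformers expects [batch, seqlen, heads, head_dim] inputs.
    q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
    out = xops.memory_efficient_attention(q, k, v)
    print(out.shape)  # torch.Size([2, 128, 8, 64])
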
--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py	2023-11-29 03:17:03.930103539 +0000
+++ flash.py	2023-11-28 16:14:25.206128903 +0000
@@ -31,39 +31,39 @@
 
 FLASH_VERSION = "0.0.0"
 try:
-    try:
-        from ... import _C_flashattention  # type: ignore[attr-defined]
-        from ..._cpp_lib import _build_metadata
-
-        if _build_metadata is not None:
-            FLASH_VERSION = _build_metadata.flash_version
-    except ImportError:
-        import flash_attn
-        from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
-
-        FLASH_VERSION = flash_attn.__version__
-        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
-        if flash_ver_parsed < (2, 3):
-            raise ImportError("Requires 2.3 for sliding window support")
+    #try:
+    #    from ... import _C_flashattention  # type: ignore[attr-defined]
+    #    from ..._cpp_lib import _build_metadata
+
+    #    if _build_metadata is not None:
+    #        FLASH_VERSION = _build_metadata.flash_version
+    #except ImportError:
+    import flash_attn
+    from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+
+    FLASH_VERSION = flash_attn.__version__
+    #    flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
+    #    if flash_ver_parsed < (2, 3):
+    #        raise ImportError("Requires 2.3 for sliding window support")
 
     # create library so that flash-attn goes through the PyTorch Dispatcher
-    _flash_lib = torch.library.Library("xformers_flash", "DEF")
+    #_flash_lib = torch.library.Library("xformers_flash", "DEF")
 
-    _flash_lib.define(
-        "flash_fwd(Tensor query, Tensor key, Tensor value, "
-        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
-        "int max_seqlen_q, int max_seqlen_k, "
-        "float p, float softmax_scale, "
-        "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
-    )
-
-    _flash_lib.define(
-        "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
-        "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
-        "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
-        "int max_seqlen_q, int max_seqlen_k, "
-        "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
-    )
+    #_flash_lib.define(
+    #    "flash_fwd(Tensor query, Tensor key, Tensor value, "
+    #    "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
+    #    "int max_seqlen_q, int max_seqlen_k, "
+    #    "float p, float softmax_scale, "
+    #    "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+    #)
+
+    #_flash_lib.define(
+    #    "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
+    #    "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+    #    "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+    #    "int max_seqlen_q, int max_seqlen_k, "
+    #    "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+    #)
 
     def _flash_fwd(
         query,
@@ -98,8 +98,8 @@
                 p,
                 softmax_scale,
                 is_causal,
-                window_size - 1,  # window_size_left
-                -1,  # window_size_right
+        #        window_size - 1,  # window_size_left
+        #        -1,  # window_size_right
                 return_softmax,
                 None,  # rng
             )
@@ -127,8 +127,8 @@
                 softmax_scale,
                 False,
                 is_causal,
-                window_size - 1,  # window_size_left
-                -1,  # window_size_right
+         #       window_size - 1,  # window_size_left
+         #       -1,  # window_size_right
                 return_softmax,
                 None,
             )
@@ -169,8 +169,8 @@
                 p,
                 softmax_scale,
                 is_causal,
-                window_size - 1,  # window_size_left
-                -1,  # window_size_right
+        #        window_size - 1,  # window_size_left
+        #        -1,  # window_size_right
                 None,
                 rng_state,
             )
@@ -193,15 +193,15 @@
                 softmax_scale,
                 False,  # zero_tensors
                 is_causal,
-                window_size - 1,  # window_size_left
-                -1,  # window_size_right
+        #        window_size - 1,  # window_size_left
+        #        -1,  # window_size_right
                 None,
                 rng_state,
             )
         return dq, dk, dv
 
-    _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
-    _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
+    #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
+    #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
 except ImportError:
     pass
 
@@ -348,7 +348,7 @@
         implementation.
     """
 
-    OPERATOR = get_operator("xformers_flash", "flash_fwd")
+    OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd")
     SUPPORTED_DEVICES: Set[str] = {"cuda"}
     CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}