flashpy_xformers-0.0.23.rocm.patch
--- flash_ori.py	2023-12-13 05:43:31.530752623 +0000
+++ flash_patch.py	2023-12-13 06:00:45.962403104 +0000
@@ -36,44 +36,44 @@
 
 FLASH_VERSION = "0.0.0"
 try:
-    try:
-        from ... import _C_flashattention  # type: ignore[attr-defined]
-        from ..._cpp_lib import _build_metadata
-
-        if _build_metadata is not None:
-            FLASH_VERSION = _build_metadata.flash_version
-    except ImportError:
-        import flash_attn
-        from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
-
-        FLASH_VERSION = flash_attn.__version__
-        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
-        if (
-            flash_ver_parsed != (2, 3, 6)
-            and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
-        ):
-            raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
+    #try:
+    #    from ... import _C_flashattention  # type: ignore[attr-defined]
+    #    from ..._cpp_lib import _build_metadata
+
+    #    if _build_metadata is not None:
+    #        FLASH_VERSION = _build_metadata.flash_version
+    #except ImportError:
+    import flash_attn
+    from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+
+    FLASH_VERSION = flash_attn.__version__
+    #    flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
+    #    if (
+    #        flash_ver_parsed != (2, 3, 6)
+    #        and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
+    #    ):
+    #        raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
 
     # create library so that flash-attn goes through the PyTorch Dispatcher
-    _flash_lib = torch.library.Library("xformers_flash", "DEF")
-
-    _flash_lib.define(
-        "flash_fwd(Tensor query, Tensor key, Tensor value, "
-        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
-        "int max_seqlen_q, int max_seqlen_k, "
-        "float p, float softmax_scale, "
-        "bool is_causal, int window_left, "
-        "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
-    )
+    #_flash_lib = torch.library.Library("xformers_flash", "DEF")
 
-    _flash_lib.define(
-        "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
-        "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
-        "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
-        "int max_seqlen_q, int max_seqlen_k, "
-        "float p, float softmax_scale, bool is_causal, "
-        "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
-    )
+    #_flash_lib.define(
+    #    "flash_fwd(Tensor query, Tensor key, Tensor value, "
+    #    "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
+    #    "int max_seqlen_q, int max_seqlen_k, "
+    #    "float p, float softmax_scale, "
+    #    "bool is_causal, int window_left, "
+    #    "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+    #)
+
+    #_flash_lib.define(
+    #    "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
+    #    "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+    #    "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+    #    "int max_seqlen_q, int max_seqlen_k, "
+    #    "float p, float softmax_scale, bool is_causal, "
+    #    "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+    #)
 
     def _flash_fwd(
         query,
@@ -111,8 +111,8 @@
                 p,
                 softmax_scale,
                 is_causal,
-                window_left,  # window_size_left
-                window_right,  # window_size_right
+        #        window_left,  # window_size_left
+        #        window_right,  # window_size_right
                 return_softmax,
                 None,  # rng
             )
@@ -134,15 +134,15 @@
                 out,
                 cu_seq_lens_q,
                 cu_seq_lens_k,
-                seqused_k,
+         #       seqused_k,
                 max_seq_len_q,
                 max_seq_len_k,
                 p,
                 softmax_scale,
                 False,
                 is_causal,
-                window_left,
-                window_right,
+         #       window_left,
+         #       window_right,
                 return_softmax,
                 None,
             )
@@ -184,8 +184,8 @@
                 p,
                 softmax_scale,
                 is_causal,
-                window_left,
-                window_right,
+        #        window_left,
+        #        window_right,
                 None,
                 rng_state,
             )
@@ -208,15 +208,15 @@
                 softmax_scale,
                 False,  # zero_tensors
                 is_causal,
-                window_left,
-                window_right,
+        #        window_left,
+        #        window_right,
                 None,
                 rng_state,
             )
         return dq, dk, dv
 
-    _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
-    _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
+    #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
+    #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
 except ImportError:
     pass
 
@@ -400,7 +400,7 @@
         implementation.
     """
 
-    OPERATOR = get_operator("xformers_flash", "flash_fwd")
+    OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd")
     SUPPORTED_DEVICES: Set[str] = {"cuda"}
     CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}