"tests/git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "d4d609bd8f513377bd9745974d350960ec78b087"
Commit 125b03ee authored by comfyanonymous

Fix some OOM issues with split attention.

parent 41b07ff8
@@ -229,7 +229,7 @@ def attention_split(q, k, v, heads, mask=None):
     gb = 1024 ** 3
     tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
-    modifier = 3 if element_size == 2 else 2.5
+    modifier = 3
     mem_required = tensor_size * modifier
     steps = 1
@@ -257,10 +257,10 @@ def attention_split(q, k, v, heads, mask=None):
                 s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
             else:
                 s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
-            first_op_done = True
             s2 = s1.softmax(dim=-1).to(v.dtype)
             del s1
+            first_op_done = True
             r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
             del s2
...
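For context, here is a minimal sketch of the chunked ("split") attention pattern this diff adjusts. The function name split_attention_sketch, the mem_free_total argument, and the loop structure are illustrative assumptions rather than the actual ComfyUI implementation; the sketch only shows why the two changes matter: the memory estimate uses a single modifier of 3, and first_op_done is set only after the first chunk's softmax has completed, so an out-of-memory error during that softmax can still be caught by a caller and retried with more steps.

import math
import torch
from torch import einsum

def split_attention_sketch(q, k, v, mem_free_total):
    # q, k, v: (batch * heads, tokens, dim_head) tensors; mem_free_total is the
    # free device memory in bytes (a hypothetical argument for this sketch).
    scale = q.shape[-1] ** -0.5
    element_size = q.element_size()

    # Estimate the size of the full (b, q_tokens, k_tokens) score tensor and
    # apply a safety modifier (the diff above fixes this at 3 for all dtypes).
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
    mem_required = tensor_size * 3

    # Split the query dimension into enough power-of-two chunks to fit.
    steps = 1
    if mem_required > mem_free_total:
        steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total))

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]

    first_op_done = False
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
        s2 = s1.softmax(dim=-1).to(v.dtype)
        del s1
        # Set the flag only after the softmax has succeeded, mirroring the
        # second hunk: an OOM in the softmax above leaves first_op_done False,
        # so a wrapper could still fall back or retry with more steps.
        first_op_done = True
        r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
        del s2
    return r1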