"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "fc37c192ae9d7244090107b0d5a6442440d194db"
Unverified Commit ad935933 authored by Birch-san, committed by GitHub

perf: prefer batched matmuls for attention (#1203)

perf: prefer batched matmuls for attention. Added a fast path to the Decoder when num_heads=1.
parent 78a6eed2
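
As context for the change (a minimal sketch with assumed shapes, not code from this diff): torch.baddbmm computes beta * input + alpha * (batch1 @ batch2), so the attention scale can be fused into the score matmul as alpha instead of being applied in a separate elementwise multiply.

import torch

# Assumed shapes for illustration only: (batch, seq_len, dim)
q = torch.randn(2, 16, 64)
k = torch.randn(2, 16, 64)
scale = 1 / 64 ** 0.5

# Unfused: matmul, then a separate elementwise scale.
scores_matmul = torch.matmul(q, k.transpose(-1, -2)) * scale

# Fused: baddbmm computes beta * input + alpha * (q @ k^T) in one call.
# With beta=0 the values of the uninitialized `input` tensor are ignored.
scores_baddbmm = torch.baddbmm(
    torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
    q,
    k.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

assert torch.allclose(scores_matmul, scores_baddbmm, atol=1e-5)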
@@ -284,22 +284,52 @@ class AttentionBlock(nn.Module):
         key_proj = self.key(hidden_states)
         value_proj = self.value(hidden_states)
 
-        # transpose
-        query_states = self.transpose_for_scores(query_proj)
-        key_states = self.transpose_for_scores(key_proj)
-        value_states = self.transpose_for_scores(value_proj)
+        scale = 1 / math.sqrt(self.channels / self.num_heads)
 
         # get scores
-        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
-        attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)  # TODO: use baddmm
+        if self.num_heads > 1:
+            query_states = self.transpose_for_scores(query_proj)
+            key_states = self.transpose_for_scores(key_proj)
+            value_states = self.transpose_for_scores(value_proj)
+
+            # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
+            #       or reformulate this into a 3D problem?
+            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
+            #       as some matmuls can be 1.94x slower than an equivalent einsum on MPS
+            #       https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
+            attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
+        else:
+            query_states, key_states, value_states = query_proj, key_proj, value_proj
+
+            attention_scores = torch.baddbmm(
+                torch.empty(
+                    query_states.shape[0],
+                    query_states.shape[1],
+                    key_states.shape[1],
+                    dtype=query_states.dtype,
+                    device=query_states.device,
+                ),
+                query_states,
+                key_states.transpose(-1, -2),
+                beta=0,
+                alpha=scale,
+            )
+
         attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
 
         # compute attention output
-        hidden_states = torch.matmul(attention_probs, value_states)
-
-        hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
-        new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
-        hidden_states = hidden_states.view(new_hidden_states_shape)
+        if self.num_heads > 1:
+            # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
+            #       or reformulate this into a 3D problem?
+            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
+            #       as some matmuls can be 1.94x slower than an equivalent einsum on MPS
+            #       https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
+            hidden_states = torch.matmul(attention_probs, value_states)
+            hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+            new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+            hidden_states = hidden_states.view(new_hidden_states_shape)
+        else:
+            hidden_states = torch.bmm(attention_probs, value_states)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
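
The TODO comments above ask whether the num_heads > 1 case could be reformulated as a 3D problem so that baddbmm applies; one possible reformulation (a sketch with assumed shapes, not part of this commit) is to fold the batch and head dimensions into a single batch dimension:

import torch

# Hypothetical 4D attention shapes: (batch, heads, seq_len, head_dim)
b, h, t, d = 2, 8, 16, 64
query_states = torch.randn(b, h, t, d)
key_states = torch.randn(b, h, t, d)
scale = 1 / d ** 0.5

# 4D path (what the num_heads > 1 branch does): plain matmul, then scale.
scores_4d = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale

# 3D reformulation: merge (batch, heads) so the fused baddbmm kernel applies.
q3 = query_states.reshape(b * h, t, d)
k3 = key_states.reshape(b * h, t, d)
scores_3d = torch.baddbmm(
    torch.empty(b * h, t, t, dtype=q3.dtype, device=q3.device),
    q3,
    k3.transpose(-1, -2),
    beta=0,
    alpha=scale,
).view(b, h, t, t)

assert torch.allclose(scores_4d, scores_3d, atol=1e-5)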
@@ -507,19 +537,17 @@ class CrossAttention(nn.Module):
         return hidden_states
 
     def _attention(self, query, key, value):
-        # TODO: use baddbmm for better performance
-        if query.device.type == "mps":
-            # Better performance on mps (~20-25%)
-            attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
-        else:
-            attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
         attention_probs = attention_scores.softmax(dim=-1)
         # compute attention output
 
-        if query.device.type == "mps":
-            hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
-        else:
-            hidden_states = torch.matmul(attention_probs, value)
+        hidden_states = torch.bmm(attention_probs, value)
 
         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
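
A quick self-contained check (assumed shapes, not part of the diff) that the new baddbmm/bmm formulation of _attention matches the einsum path it replaces:

import torch

# Assumed cross-attention shapes: (batch*heads, q_len, head_dim), (batch*heads, kv_len, head_dim)
query = torch.randn(2, 16, 40)
key = torch.randn(2, 77, 40)
value = torch.randn(2, 77, 40)
scale = 1 / 40 ** 0.5

# Old MPS branch: einsum for the scores and for the probs @ value product.
probs_old = (torch.einsum("b i d, b j d -> b i j", query, key) * scale).softmax(dim=-1)
out_old = torch.einsum("b i j, b j d -> b i d", probs_old, value)

# New branch: fused baddbmm for the scaled scores, bmm for the output.
scores_new = torch.baddbmm(
    torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
)
out_new = torch.bmm(scores_new.softmax(dim=-1), value)

assert torch.allclose(out_old, out_new, atol=1e-5)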
@@ -534,21 +562,15 @@ class CrossAttention(nn.Module):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            if query.device.type == "mps":
-                # Better performance on mps (~20-25%)
-                attn_slice = (
-                    torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
-                    * self.scale
-                )
-            else:
-                attn_slice = (
-                    torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
-                )  # TODO: use baddbmm for better performance
+            attn_slice = torch.baddbmm(
+                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+                query[start_idx:end_idx],
+                key[start_idx:end_idx].transpose(-1, -2),
+                beta=0,
+                alpha=self.scale,
+            )
             attn_slice = attn_slice.softmax(dim=-1)
-            if query.device.type == "mps":
-                attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
-            else:
-                attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
+            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
 
             hidden_states[start_idx:end_idx] = attn_slice
...
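
For reference, a self-contained sketch (assumed shapes, outside this diff) showing that the sliced baddbmm/bmm loop reproduces the unsliced attention output while only materializing slice_size score matrices at a time:

import torch

# Assumed shapes: (batch*heads, q_len, head_dim), (batch*heads, kv_len, head_dim)
query = torch.randn(8, 16, 40)
key = torch.randn(8, 77, 40)
value = torch.randn(8, 77, 40)
scale = 1 / 40 ** 0.5
slice_size = 2  # number of batch entries whose score matrices exist at once

hidden_states = torch.empty(8, 16, 40, dtype=query.dtype, device=query.device)
for i in range(query.shape[0] // slice_size):
    start_idx, end_idx = i * slice_size, (i + 1) * slice_size
    attn_slice = torch.baddbmm(
        torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
        query[start_idx:end_idx],
        key[start_idx:end_idx].transpose(-1, -2),
        beta=0,
        alpha=scale,
    ).softmax(dim=-1)
    hidden_states[start_idx:end_idx] = torch.bmm(attn_slice, value[start_idx:end_idx])

# Matches the unsliced computation, which holds the full score tensor in memory.
full = torch.bmm((torch.matmul(query, key.transpose(-1, -2)) * scale).softmax(dim=-1), value)
assert torch.allclose(hidden_states, full, atol=1e-5)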