Commit 94cc4bd9 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Use scaled_dot_product_attention in Wav2vec2/HuBERT's SelfAttention (#3253)

Summary:
Replace the manual attention computation with `torch.nn.functional.scaled_dot_product_attention` to improve runtime efficiency.
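For reference, a minimal sketch (not from the patch; shapes and tolerances are illustrative) checking that the fused kernel reproduces the manual scaled softmax attention it replaces:

```python
import torch
import torch.nn.functional as F

# Random tensors in the B, nH, L, Hd layout used by SelfAttention.
batch_size, num_heads, length, head_dim = 2, 4, 16, 32
q = torch.randn(batch_size, num_heads, length, head_dim)
k = torch.randn(batch_size, num_heads, length, head_dim)
v = torch.randn(batch_size, num_heads, length, head_dim)

# Manual path: scale q, dot with k, softmax over keys, weight v.
scaling = head_dim ** -0.5
weights = (scaling * q) @ k.transpose(-2, -1)  # B, nH, L, L
weights = F.softmax(weights, dim=-1)
manual = weights @ v  # B, nH, L, Hd

# Fused path: one call; dropout disabled so the comparison is deterministic.
fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

torch.testing.assert_close(manual, fused, rtol=1e-4, atol=1e-4)
```

Beyond matching numerics, the fused call can dispatch to FlashAttention or memory-efficient kernels when the backend supports them, which is where the speedup comes from.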

Pull Request resolved: https://github.com/pytorch/audio/pull/3253

Reviewed By: mthrok

Differential Revision: D44800353

Pulled By: nateanl

fbshipit-source-id: 41550d868c809099aadbe812b0ebe2c38121efb8
parent 5a5b0fc3
@@ -262,7 +262,7 @@ class SelfAttention(Module):
         self.embed_dim = embed_dim
         self.num_heads = num_heads
-        self.dropout = torch.nn.Dropout(dropout)
+        self.dropout = dropout
         self.head_dim = head_dim
         self.scaling = self.head_dim**-0.5
@@ -304,25 +304,14 @@ class SelfAttention(Module):
         shape = (batch_size, length, self.num_heads, self.head_dim)
         q = self.q_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
-        k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1)  # B, nH, Hd, L
+        k = self.k_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
         v = self.v_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
-
-        # scale down q to avoid value overflow.
-        weights = (self.scaling * q) @ k  # B, nH, L, L
-        if attention_mask is not None:
-            weights += attention_mask
-        # subtracting a constant value from the tensor won't change the output of softmax.
-        # apply the subtraction to avoid value overflow in torch.nn.functional.softmax.
-        # for more details, please see Equation 7 in https://arxiv.org/abs/2112.08778
-        weights = weights - weights.max(dim=-1, keepdim=True)[0]
-        weights = torch.nn.functional.softmax(weights, dim=-1)
-        weights = self.dropout(weights)
-        output = weights @ v  # B, nH, L, Hd
-        output = output.transpose(2, 1).reshape(batch_size, length, embed_dim)
-        output = self.out_proj(output)
+        dropout = self.dropout if self.training else 0.0
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False
+        )
+        attn_output = attn_output.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+        output = self.out_proj(attn_output)
         return output, None  # Necessary for compatibility with WavLMSelAttention
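Two behavioral details are worth noting in the new code path. First, a float `attn_mask` is added to the attention scores inside `scaled_dot_product_attention`, matching the old `weights += attention_mask` behavior, and the explicit max-subtraction guard against softmax overflow disappears because numerical stabilization is handled inside the fused kernel. Second, the fused call takes a plain `dropout_p` float rather than a training-aware `torch.nn.Dropout` module, so the eval-time gating is now done by hand. A minimal sketch of that gating pattern (hypothetical module, not part of the patch):

```python
import torch
import torch.nn.functional as F

class TinyAttention(torch.nn.Module):
    """Hypothetical module showing the training-gated dropout pattern."""

    def __init__(self, dropout: float = 0.1):
        super().__init__()
        # Stored as a float: the fused kernel applies dropout internally,
        # so a torch.nn.Dropout submodule is no longer needed.
        self.dropout = dropout

    def forward(self, q, k, v):
        # nn.Dropout turned itself off in eval mode; with dropout_p the
        # caller must zero it out explicitly.
        dropout_p = self.dropout if self.training else 0.0
        return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
```

As before, calling `.eval()` on the model is what makes inference deterministic.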