Unverified Commit 5990743f authored by Ali Hassani, committed by GitHub

Correct NATTEN function signatures and force new version (#22298)

parent d35f7296
@@ -129,7 +129,7 @@ _deps = [
     "keras-nlp>=0.3.1",
     "librosa",
     "nltk",
-    "natten>=0.14.5",
+    "natten>=0.14.6",
     "numpy>=1.17",
     "onnxconverter-common",
     "onnxruntime-tools>=1.4.2",
@@ -35,7 +35,7 @@ deps = {
     "keras-nlp": "keras-nlp>=0.3.1",
     "librosa": "librosa",
     "nltk": "nltk",
-    "natten": "natten>=0.14.5",
+    "natten": "natten>=0.14.6",
     "numpy": "numpy>=1.17",
     "onnxconverter-common": "onnxconverter-common",
     "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
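Both dependency tables raise the floor to natten>=0.14.6 so the corrected function signatures are guaranteed to be present at runtime. A quick way to confirm an existing environment meets the new floor (a minimal sketch; importlib.metadata is standard library, and the packaging helper is assumed to be installed, as it usually is alongside transformers):

import importlib.metadata

from packaging import version

# Compare the installed NATTEN version against the new minimum required by this commit.
installed = version.parse(importlib.metadata.version("natten"))
assert installed >= version.parse("0.14.6"), f"natten {installed} predates the corrected signatures"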
@@ -356,7 +356,7 @@ class NeighborhoodAttention(nn.Module):
         # seem a bit unusual, but is taken from the original Transformer paper.
         attention_probs = self.dropout(attention_probs)
 
-        context_layer = natten2dav(attention_probs, value_layer, self.dilation)
+        context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, self.dilation)
         context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(new_context_layer_shape)
@@ -348,7 +348,7 @@ class NeighborhoodAttention(nn.Module):
         # seem a bit unusual, but is taken from the original Transformer paper.
         attention_probs = self.dropout(attention_probs)
 
-        context_layer = natten2dav(attention_probs, value_layer, 1)
+        context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, 1)
         context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(new_context_layer_shape)
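For reference, both changed call sites reflect the same signature correction: starting with natten 0.14.6, natten2dav takes the kernel size as an explicit argument ahead of the dilation, which is why the new lines pass self.kernel_size. Below is a minimal standalone sketch of the updated call, assuming natten>=0.14.6 is installed and its kernels are usable on the target machine (some builds may require CUDA); the shapes and variable names are illustrative, not taken from this commit beyond the changed lines.

import torch
from natten.functional import natten2dav

batch, heads, height, width, head_dim = 2, 4, 14, 14, 32
kernel_size, dilation = 7, 1

# Neighborhood attention weights: one score per query position for each of the
# kernel_size * kernel_size positions in its (dilated) neighborhood.
attention_probs = torch.softmax(
    torch.randn(batch, heads, height, width, kernel_size * kernel_size), dim=-1
)
value_layer = torch.randn(batch, heads, height, width, head_dim)

# natten>=0.14.6 expects the kernel size explicitly, before the dilation.
context_layer = natten2dav(attention_probs, value_layer, kernel_size, dilation)

# The result keeps the per-head layout, matching the permute/view that follows
# in the modeling code above: (batch, heads, height, width, head_dim).
print(context_layer.shape)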