Commit 6b8cb7db authored by Myle Ott, committed by Facebook Github Bot

Merge internal changes

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/682

Differential Revision: D15147735

Pulled By: myleott

fbshipit-source-id: 4a5f12c0b24591f964fe1f465be3775a67578e79
parent f5e52c19
@@ -85,7 +85,7 @@ class IndexedDataset(torch.utils.data.Dataset):
         if not self.data_file:
             self.read_data(self.path)
         self.check_index(i)
-        tensor_size = int(self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]])
+        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
         a = np.empty(tensor_size, dtype=self.dtype)
         self.data_file.seek(self.data_offsets[i] * self.element_size)
         self.data_file.readinto(a)
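The dropped int(...) cast matters for multi-dimensional items: the slice of self.sizes between two dim_offsets can hold more than one entry, and int() only accepts a single element, whereas np.empty takes the whole slice as a shape. A minimal standalone sketch of the behaviour (plain NumPy with made-up sizes, not fairseq code):

    import numpy as np

    sizes = np.array([7, 3, 4])    # flattened per-dimension sizes of all items
    dim_offsets = [0, 1, 3]        # item 0 is 1-D, item 1 is 2-D

    shape = sizes[dim_offsets[1]:dim_offsets[2]]   # array([3, 4])
    # int(shape) would raise "only size-1 arrays can be converted to Python scalars"
    a = np.empty(shape, dtype=np.int64)            # allocates a 3x4 buffer
    print(a.shape)                                 # (3, 4)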
@@ -19,20 +19,31 @@ class MultiheadAttention(nn.Module):
     See "Attention Is All You Need" for more details.
     """
 
-    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
+    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
         super().__init__()
         self.embed_dim = embed_dim
+        self.kdim = kdim if kdim is not None else embed_dim
+        self.vdim = vdim if vdim is not None else embed_dim
+        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
         self.num_heads = num_heads
         self.dropout = dropout
         self.head_dim = embed_dim // num_heads
         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
         self.scaling = self.head_dim ** -0.5
-        self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
+
+        if self.qkv_same_dim:
+            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
+        else:
+            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+
         if bias:
             self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
         else:
             self.register_parameter('in_proj_bias', None)
 
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
         if add_bias_kv:
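The new kdim and vdim arguments let keys and values live in a different dimensionality than the queries; when they match embed_dim the module keeps the single fused in_proj_weight, otherwise it allocates one projection matrix per input. A hedged usage sketch (illustrative sizes, assuming the usual fairseq.modules import path):

    import torch
    from fairseq.modules import MultiheadAttention

    # 8-dim queries attending over 16-dim encoder memory (sizes are made up).
    attn = MultiheadAttention(embed_dim=8, num_heads=2, kdim=16, vdim=16)

    query = torch.randn(5, 3, 8)     # (tgt_len, batch, embed_dim)
    memory = torch.randn(7, 3, 16)   # (src_len, batch, kdim == vdim)
    out, attn_weights = attn(query, memory, memory)
    print(out.shape)                 # torch.Size([5, 3, 8])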
@@ -51,7 +62,13 @@ class MultiheadAttention(nn.Module):
         self.onnx_trace = True
 
     def reset_parameters(self):
-        nn.init.xavier_uniform_(self.in_proj_weight)
+        if self.qkv_same_dim:
+            nn.init.xavier_uniform_(self.in_proj_weight)
+        else:
+            nn.init.xavier_uniform_(self.k_proj_weight)
+            nn.init.xavier_uniform_(self.v_proj_weight)
+            nn.init.xavier_uniform_(self.q_proj_weight)
+
         nn.init.xavier_uniform_(self.out_proj.weight)
         if self.in_proj_bias is not None:
             nn.init.constant_(self.in_proj_bias, 0.)
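One consequence worth noting (not something the diff spells out): nn.init.xavier_uniform_ scales its bound by fan_in + fan_out, so the fused 3*embed_dim x embed_dim matrix is initialized with a smaller bound than each separately initialized matrix. A quick standalone check of the two bounds for an illustrative embed_dim:

    import math
    import torch
    import torch.nn as nn

    embed_dim = 8

    # Fused weight: fan_in = embed_dim, fan_out = 3 * embed_dim.
    fused = nn.init.xavier_uniform_(torch.empty(3 * embed_dim, embed_dim))
    # Separate weight: fan_in = fan_out = embed_dim, so the bound is larger.
    separate = nn.init.xavier_uniform_(torch.empty(embed_dim, embed_dim))

    print(math.sqrt(6 / (embed_dim + 3 * embed_dim)))  # ~0.43, fused bound
    print(math.sqrt(6 / (embed_dim + embed_dim)))      # ~0.61, per-matrix bound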
@@ -78,7 +95,6 @@ class MultiheadAttention(nn.Module):
         tgt_len, bsz, embed_dim = query.size()
         assert embed_dim == self.embed_dim
         assert list(query.size()) == [tgt_len, bsz, embed_dim]
-        assert key.size() == value.size()
 
         if incremental_state is not None:
             saved_state = self._get_input_buffer(incremental_state)
@@ -101,7 +117,9 @@ class MultiheadAttention(nn.Module):
                 assert value is None
                 k = v = None
             else:
-                k, v = self.in_proj_kv(key)
+                k = self.in_proj_k(key)
+                v = self.in_proj_v(key)
         else:
             q = self.in_proj_q(query)
             k = self.in_proj_k(key)
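With per-input projection weights, the old in_proj_kv shortcut (one fused matmul whose output is chunked into k and v) only works when both halves share the same input size, so the encoder-decoder branch now projects keys and values separately. A standalone sketch (plain torch, random weights) showing that the two formulations agree in the same-dimension case:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    embed_dim = 8
    key = torch.randn(7, 3, embed_dim)

    # Fused path (old in_proj_kv): k and v rows stacked in one weight.
    kv_weight = torch.randn(2 * embed_dim, embed_dim)
    k_fused, v_fused = F.linear(key, kv_weight).chunk(2, dim=-1)

    # Separate path (new in_proj_k / in_proj_v): slice the same rows apart.
    k_sep = F.linear(key, kv_weight[:embed_dim])
    v_sep = F.linear(key, kv_weight[embed_dim:])

    assert torch.allclose(k_fused, k_sep) and torch.allclose(v_fused, v_sep)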
@@ -216,17 +234,34 @@ class MultiheadAttention(nn.Module):
     def in_proj_qkv(self, query):
         return self._in_proj(query).chunk(3, dim=-1)
 
-    def in_proj_kv(self, key):
-        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
-
     def in_proj_q(self, query):
-        return self._in_proj(query, end=self.embed_dim)
+        if self.qkv_same_dim:
+            return self._in_proj(query, end=self.embed_dim)
+        else:
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[:self.embed_dim]
+            return F.linear(query, self.q_proj_weight, bias)
 
     def in_proj_k(self, key):
-        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+        if self.qkv_same_dim:
+            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+        else:
+            weight = self.k_proj_weight
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[self.embed_dim:2 * self.embed_dim]
+            return F.linear(key, weight, bias)
 
     def in_proj_v(self, value):
-        return self._in_proj(value, start=2 * self.embed_dim)
+        if self.qkv_same_dim:
+            return self._in_proj(value, start=2 * self.embed_dim)
+        else:
+            weight = self.v_proj_weight
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[2 * self.embed_dim:]
+            return F.linear(value, weight, bias)
 
     def _in_proj(self, input, start=0, end=None):
         weight = self.in_proj_weight
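In the non-fused branch each projection pairs its own weight with the matching third of the shared in_proj_bias, which keeps its 3 * embed_dim length. A standalone shape check with made-up dimensions, mirroring the slicing in in_proj_q / in_proj_k / in_proj_v above:

    import torch
    import torch.nn.functional as F

    embed_dim, kdim, vdim = 8, 16, 12   # illustrative sizes only

    q_w = torch.randn(embed_dim, embed_dim)
    k_w = torch.randn(embed_dim, kdim)
    v_w = torch.randn(embed_dim, vdim)

    # Each projection takes its own third of the shared bias vector.
    in_proj_bias = torch.randn(3 * embed_dim)
    q_b, k_b, v_b = in_proj_bias.chunk(3)

    query = torch.randn(5, 3, embed_dim)
    key = torch.randn(7, 3, kdim)
    value = torch.randn(7, 3, vdim)
    print(F.linear(query, q_w, q_b).shape)   # torch.Size([5, 3, 8])
    print(F.linear(key, k_w, k_b).shape)     # torch.Size([7, 3, 8])
    print(F.linear(value, v_w, v_b).shape)   # torch.Size([7, 3, 8])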
@@ -5,7 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
-from collections import defaultdict, OrderedDict
+from collections import defaultdict
 from typing import Callable
 import copy
 import importlib.util