Commit 6b8cb7db authored by Myle Ott, committed by Facebook GitHub Bot

Merge internal changes

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/682

Differential Revision: D15147735

Pulled By: myleott

fbshipit-source-id: 4a5f12c0b24591f964fe1f465be3775a67578e79
parent f5e52c19
@@ -85,7 +85,7 @@ class IndexedDataset(torch.utils.data.Dataset):
         if not self.data_file:
             self.read_data(self.path)
         self.check_index(i)
-        tensor_size = int(self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]])
+        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
         a = np.empty(tensor_size, dtype=self.dtype)
         self.data_file.seek(self.data_offsets[i] * self.element_size)
         self.data_file.readinto(a)
...
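For context on the hunk above: IndexedDataset keeps every item's shape in a flat `sizes` array, with `dim_offsets[i]:dim_offsets[i + 1]` delimiting item `i`'s dimensions and `data_offsets[i]` giving its element offset into the data file. The snippet below is only an illustration of that index layout with made-up values (it is not part of the commit), showing why the slice can be passed straight to `np.empty` as a shape.

import numpy as np

# Hypothetical index for three stored tensors of shapes (2, 3), (4,), (5, 1).
# In IndexedDataset these arrays come from the on-disk index; here they are
# hard-coded purely to illustrate how the slices in __getitem__ are used.
sizes = np.array([2, 3, 4, 5, 1], dtype=np.int64)        # all dims, flattened
dim_offsets = np.array([0, 2, 3, 5], dtype=np.int64)     # item i spans sizes[dim_offsets[i]:dim_offsets[i + 1]]
data_offsets = np.array([0, 6, 10, 15], dtype=np.int64)  # cumulative element counts into the data file

i = 0
tensor_size = sizes[dim_offsets[i]:dim_offsets[i + 1]]   # array([2, 3]); np.empty accepts this as a shape
a = np.empty(tensor_size, dtype=np.int64)                # buffer that data_file.readinto(a) would fill
num_elements = data_offsets[i + 1] - data_offsets[i]     # 6 elements, matching 2 * 3
assert a.size == num_elements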
@@ -19,20 +19,31 @@ class MultiheadAttention(nn.Module):
     See "Attention Is All You Need" for more details.
     """

-    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
+    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
         super().__init__()
         self.embed_dim = embed_dim
+        self.kdim = kdim if kdim is not None else embed_dim
+        self.vdim = vdim if vdim is not None else embed_dim
+        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
         self.num_heads = num_heads
         self.dropout = dropout
         self.head_dim = embed_dim // num_heads
         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
         self.scaling = self.head_dim ** -0.5
-        self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
+
+        if self.qkv_same_dim:
+            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
+        else:
+            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+
         if bias:
             self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
         else:
             self.register_parameter('in_proj_bias', None)

         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

         if add_bias_kv:
@@ -51,7 +62,13 @@ class MultiheadAttention(nn.Module):
         self.onnx_trace = True

     def reset_parameters(self):
-        nn.init.xavier_uniform_(self.in_proj_weight)
+        if self.qkv_same_dim:
+            nn.init.xavier_uniform_(self.in_proj_weight)
+        else:
+            nn.init.xavier_uniform_(self.k_proj_weight)
+            nn.init.xavier_uniform_(self.v_proj_weight)
+            nn.init.xavier_uniform_(self.q_proj_weight)
+
         nn.init.xavier_uniform_(self.out_proj.weight)
         if self.in_proj_bias is not None:
             nn.init.constant_(self.in_proj_bias, 0.)
@@ -78,7 +95,6 @@ class MultiheadAttention(nn.Module):
         tgt_len, bsz, embed_dim = query.size()
         assert embed_dim == self.embed_dim
         assert list(query.size()) == [tgt_len, bsz, embed_dim]
-        assert key.size() == value.size()

         if incremental_state is not None:
             saved_state = self._get_input_buffer(incremental_state)
@@ -101,7 +117,9 @@ class MultiheadAttention(nn.Module):
                 assert value is None
                 k = v = None
             else:
-                k, v = self.in_proj_kv(key)
+                k = self.in_proj_k(key)
+                v = self.in_proj_v(key)
+
         else:
             q = self.in_proj_q(query)
             k = self.in_proj_k(key)
@@ -216,17 +234,34 @@ class MultiheadAttention(nn.Module):
     def in_proj_qkv(self, query):
         return self._in_proj(query).chunk(3, dim=-1)

-    def in_proj_kv(self, key):
-        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
-
     def in_proj_q(self, query):
-        return self._in_proj(query, end=self.embed_dim)
+        if self.qkv_same_dim:
+            return self._in_proj(query, end=self.embed_dim)
+        else:
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[:self.embed_dim]
+            return F.linear(query, self.q_proj_weight, bias)

     def in_proj_k(self, key):
-        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+        if self.qkv_same_dim:
+            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+        else:
+            weight = self.k_proj_weight
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[self.embed_dim:2 * self.embed_dim]
+            return F.linear(key, weight, bias)

     def in_proj_v(self, value):
-        return self._in_proj(value, start=2 * self.embed_dim)
+        if self.qkv_same_dim:
+            return self._in_proj(value, start=2 * self.embed_dim)
+        else:
+            weight = self.v_proj_weight
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[2 * self.embed_dim:]
+            return F.linear(value, weight, bias)

     def _in_proj(self, input, start=0, end=None):
         weight = self.in_proj_weight
...
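The hunks above let keys and values enter the attention module with their own dimensionalities: when `kdim` or `vdim` differs from `embed_dim`, the fused `in_proj_weight` is replaced by separate `q_proj_weight`, `k_proj_weight`, and `v_proj_weight`, while the single `in_proj_bias` of length `3 * embed_dim` is still shared and sliced into thirds. The snippet below is a standalone sketch of that projection logic with made-up dimensions; it is not the module itself.

import torch
import torch.nn.functional as F

embed_dim, kdim, vdim = 8, 4, 6          # hypothetical sizes: key/value dims differ from embed_dim
tgt_len, src_len, bsz = 3, 5, 2

# Separate projection weights, mirroring q_proj_weight / k_proj_weight / v_proj_weight above.
q_w = torch.randn(embed_dim, embed_dim)
k_w = torch.randn(embed_dim, kdim)
v_w = torch.randn(embed_dim, vdim)
# One shared bias of length 3 * embed_dim, sliced per projection as in in_proj_q / in_proj_k / in_proj_v.
in_proj_bias = torch.randn(3 * embed_dim)

query = torch.randn(tgt_len, bsz, embed_dim)
key = torch.randn(src_len, bsz, kdim)
value = torch.randn(src_len, bsz, vdim)

q = F.linear(query, q_w, in_proj_bias[:embed_dim])
k = F.linear(key, k_w, in_proj_bias[embed_dim:2 * embed_dim])
v = F.linear(value, v_w, in_proj_bias[2 * embed_dim:])

# All three now live in the shared embed_dim space, so the usual attention math applies.
assert q.shape == (tgt_len, bsz, embed_dim)
assert k.shape == v.shape == (src_len, bsz, embed_dim)

With the module itself, the same layout is selected by passing the new constructor arguments shown in the diff, e.g. `MultiheadAttention(embed_dim=8, num_heads=2, kdim=4, vdim=6)`.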
@@ -5,7 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-from collections import defaultdict, OrderedDict
+from collections import defaultdict
 from typing import Callable
 import copy
 import importlib.util
...