Commit 535ca991 authored by Myle Ott

Merge internal changes

parent 28069cf4
......@@ -9,8 +9,7 @@ from collections import namedtuple
import pickle
import torch
from torch import distributed, nn
from torch.distributed import group
from torch import nn
from fairseq import utils
......@@ -33,6 +32,16 @@ else:
c10d_status = C10dStatus(has_c10d=False, is_default=False)
if c10d_status.is_default:
import torch.distributed as dist_c10d
import torch.distributed.deprecated as dist_no_c10d
elif c10d_status.has_c10d:
import torch.distributed.c10d as dist_c10d
import torch.distributed as dist_no_c10d
else:
import torch.distributed as dist_no_c10d
def distributed_init(args):
if args.distributed_world_size == 1:
raise ValueError('Cannot initialize distributed with distributed_world_size=1')
......@@ -44,15 +53,9 @@ def distributed_init(args):
args.distributed_rank, args.distributed_init_method), flush=True)
if _use_c10d[0]:
if c10d_status.is_default:
init_fn = distributed.init_process_group
else:
init_fn = distributed.c10d.init_process_group
init_fn = dist_c10d.init_process_group
else:
if c10d_status.is_default:
init_fn = distributed.deprecated.init_process_group
else:
init_fn = distributed.init_process_group
init_fn = dist_no_c10d.init_process_group
init_fn(
backend=args.distributed_backend,
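
Aside: the hunk above replaces the per-call checks of c10d_status with the module aliases imported earlier, so distributed_init only has to pick between dist_c10d and dist_no_c10d. A standalone sketch of the same call against modern torch (where the c10d implementation is the default and no aliasing is needed), with illustrative defaults, might look like this:

import torch.distributed as dist

def distributed_init_sketch(backend='nccl', init_method='tcp://localhost:23456',
                            world_size=2, rank=0):
    # Mirrors the guard at the top of distributed_init: single-process runs
    # should not try to build a process group at all.
    if world_size == 1:
        raise ValueError('Cannot initialize distributed with world_size=1')
    dist.init_process_group(backend=backend, init_method=init_method,
                            world_size=world_size, rank=rank)
    return dist.get_rank()
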
......@@ -83,32 +86,32 @@ def suppress_output():
def get_rank():
if _use_c10d[0]:
return distributed.c10d.get_rank()
return dist_c10d.get_rank()
else:
return distributed.get_rank()
return dist_no_c10d.get_rank()
def get_world_size():
if _use_c10d[0]:
return distributed.c10d.get_world_size()
return dist_c10d.get_world_size()
else:
return distributed.get_world_size()
return dist_no_c10d.get_world_size()
def get_default_group():
if _use_c10d[0]:
return distributed.c10d.group.WORLD
return dist_c10d.group.WORLD
else:
return distributed.group.WORLD
return dist_no_c10d.group.WORLD
def all_reduce(tensor, group=None):
if group is None:
group = get_default_group()
if _use_c10d[0]:
return distributed.c10d.all_reduce(tensor, group=group)
return dist_c10d.all_reduce(tensor, group=group)
else:
return distributed.all_reduce(tensor, group=group)
return dist_no_c10d.all_reduce(tensor, group=group)
def all_gather_list(data, group=None, max_size=16384):
......
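
For context on the wrappers above: _use_c10d[0] is a mutable module-level flag, and each helper simply forwards to dist_c10d or dist_no_c10d. A hedged example of how such an all_reduce wrapper is typically used, averaging a value across workers; plain torch.distributed stands in for the fairseq wrapper here:

import torch
import torch.distributed as dist

def average_across_workers(value):
    # No-op when not running distributed, so the same code path works everywhere.
    if not (dist.is_available() and dist.is_initialized()):
        return value
    t = torch.tensor([value], dtype=torch.float64)
    dist.all_reduce(t, group=dist.group.WORLD)   # in-place sum across ranks
    return (t / dist.get_world_size()).item()
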
......@@ -627,7 +627,13 @@ class TransformerDecoderLayer(nn.Module):
self.final_layer_norm = LayerNorm(self.embed_dim)
self.need_attn = True
def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, self_attn_mask=None,
self.onnx_trace = False
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def forward(self, x, encoder_out, encoder_padding_mask, incremental_state,
prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None,
self_attn_padding_mask=None):
"""
Args:
......@@ -640,6 +646,12 @@ class TransformerDecoderLayer(nn.Module):
"""
residual = x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
if prev_self_attn_state is not None:
if incremental_state is None:
incremental_state = {}
prev_key, prev_value = prev_self_attn_state
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
self.self_attn._set_input_buffer(incremental_state, saved_state)
x, _ = self.self_attn(
query=x,
key=x,
......@@ -657,6 +669,12 @@ class TransformerDecoderLayer(nn.Module):
if self.encoder_attn is not None:
residual = x
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
if prev_attn_state is not None:
if incremental_state is None:
incremental_state = {}
prev_key, prev_value = prev_attn_state
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
......@@ -678,6 +696,10 @@ class TransformerDecoderLayer(nn.Module):
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
if self.onnx_trace:
saved_state = self.self_attn._get_input_buffer(incremental_state)
self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
return x, attn, self_attn_state
return x, attn
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
......
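
The new prev_self_attn_state / prev_attn_state arguments and the extra self_attn_state return value under onnx_trace let a caller carry the decoder layer's key/value cache explicitly, instead of inside the incremental_state dict, which an ONNX trace cannot represent. A toy sketch of that explicit-cache round trip (not the fairseq classes; shapes and projections here are purely illustrative):

import torch

def toy_decoder_step(x_t, w_k, w_v, prev_key=None, prev_value=None):
    # x_t: (bsz, dim) input for the current step; w_k, w_v: (dim, dim) projections.
    # Returns the updated key/value cache, which the caller feeds back in on the
    # next step, analogous to returning (x, attn, self_attn_state) above.
    k_t = (x_t @ w_k).unsqueeze(1)                      # (bsz, 1, dim)
    v_t = (x_t @ w_v).unsqueeze(1)
    key = k_t if prev_key is None else torch.cat([prev_key, k_t], dim=1)
    value = v_t if prev_value is None else torch.cat([prev_value, v_t], dim=1)
    return key, value

# Hypothetical two-step usage:
bsz, dim = 2, 4
w_k, w_v = torch.randn(dim, dim), torch.randn(dim, dim)
k, v = toy_decoder_step(torch.randn(bsz, dim), w_k, w_v)
k, v = toy_decoder_step(torch.randn(bsz, dim), w_k, w_v, prev_key=k, prev_value=v)
assert k.shape == (bsz, 2, dim)
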
......@@ -45,6 +45,11 @@ class MultiheadAttention(nn.Module):
self.reset_parameters()
self.onnx_trace = False
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def reset_parameters(self):
nn.init.xavier_uniform_(self.in_proj_weight)
nn.init.xavier_uniform_(self.out_proj.weight)
......@@ -94,9 +99,7 @@ class MultiheadAttention(nn.Module):
q = self.in_proj_q(query)
if key is None:
assert value is None
# this will allow us to concat it with previous value and get
# just get the previous value
k = v = q.new(0)
k = v = None
else:
k, v = self.in_proj_kv(key)
else:
......@@ -106,12 +109,20 @@ class MultiheadAttention(nn.Module):
q *= self.scaling
if saved_state is not None:
if 'prev_key' in saved_state:
k = torch.cat((saved_state['prev_key'], k), dim=0)
if static_kv:
k = saved_state['prev_key']
else:
k = torch.cat((saved_state['prev_key'], k), dim=0)
if 'prev_value' in saved_state:
v = torch.cat((saved_state['prev_value'], v), dim=0)
if static_kv:
v = saved_state['prev_value']
else:
v = torch.cat((saved_state['prev_value'], v), dim=0)
saved_state['prev_key'] = k
saved_state['prev_value'] = v
self._set_input_buffer(incremental_state, saved_state)
if self.bias_k is not None:
......
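
The static_kv branches above cover encoder-decoder attention, where the encoder output does not change between decoding steps: the cached keys/values are reused as-is rather than grown. A condensed restating of that saved-state update (dict keys match the diff; the exact tensor layout is omitted):

import torch

def update_kv_cache(saved_state, k, v, static_kv):
    # static_kv: keys/values come from a fixed source (the encoder), so reuse the
    # cache; otherwise append this step's keys/values along the time dimension.
    if 'prev_key' in saved_state:
        k = saved_state['prev_key'] if static_kv else torch.cat((saved_state['prev_key'], k), dim=0)
    if 'prev_value' in saved_state:
        v = saved_state['prev_value'] if static_kv else torch.cat((saved_state['prev_value'], v), dim=0)
    saved_state['prev_key'] = k
    saved_state['prev_value'] = v
    return k, v
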
......@@ -9,6 +9,7 @@ import math
import torch
import torch.nn as nn
import torch.onnx.operators
from fairseq import utils
......@@ -55,12 +56,12 @@ class SinusoidalPositionalEmbedding(nn.Module):
emb[padding_idx, :] = 0
return emb
def forward(self, input, incremental_state=None):
def forward(self, input, incremental_state=None, timestep=None):
"""Input is expected to be of size [bsz x seqlen]."""
# recompute/expand embeddings if needed
bsz, seq_len = input.size()
bsz, seq_len = torch.onnx.operators.shape_as_tensor(input)
max_pos = self.padding_idx + 1 + seq_len
if self.weights is None or max_pos > self.weights.size(0):
# recompute/expand embeddings if needed
self.weights = SinusoidalPositionalEmbedding.get_embedding(
max_pos,
self.embedding_dim,
......@@ -70,12 +71,13 @@ class SinusoidalPositionalEmbedding(nn.Module):
if incremental_state is not None:
# positions is the same for every token when decoding a single step
return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)
pos = (timestep.int() + 1).long() if timestep is not None else seq_len
if self.onnx_trace:
return self.weights[self.padding_idx + pos, :].unsqueeze(1).repeat(bsz, 1, 1)
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
positions = utils.make_positions(input, self.padding_idx, self.left_pad, self.onnx_trace)
if self.onnx_trace:
bsz = torch.onnx.operators.shape_as_tensor(input)[0]
seq_len = torch.onnx.operators.shape_as_tensor(input)[1]
flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
embedding_shape = torch.cat((bsz.view(1), seq_len.view(1), torch.LongTensor([-1])))
embeddings = torch.onnx.operators.reshape_from_tensor_shape(flat_embeddings, embedding_shape)
......
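
In the positional-embedding hunk, shape_as_tensor / reshape_from_tensor_shape keep every size symbolic so the ONNX trace does not bake in a fixed batch size, and timestep supplies the decoding position as a tensor. A rough standalone sketch of the single-step lookup (toy table; the real module builds it with get_embedding):

import torch

def single_step_position_embedding(weights, padding_idx, bsz, seq_len, timestep=None):
    # weights: (num_positions, dim) precomputed table; when a timestep tensor is
    # given, the lookup index is timestep + 1, mirroring the `pos` expression above.
    pos = int(timestep) + 1 if timestep is not None else seq_len
    return weights[padding_idx + pos, :].expand(bsz, 1, -1)

# Hypothetical usage: third decoding step for a batch of two sentences.
table = torch.randn(16, 8)
emb = single_step_position_embedding(table, padding_idx=1, bsz=2, seq_len=3,
                                     timestep=torch.tensor(2))
assert emb.shape == (2, 1, 8)
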
......@@ -11,7 +11,7 @@ import os, re
import torch
from multiprocessing import Pool
SPACE_NORMALIZER = re.compile("\s+")
SPACE_NORMALIZER = re.compile(r"\s+")
def tokenize_line(line):
......
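
The final hunk only adds an r-prefix: "\s" is not a recognized string escape, so without the raw string newer Python versions emit a DeprecationWarning (and eventually a SyntaxError) even though re interprets the pattern as intended. A quick check of the tokenizer behaviour, with tokenize_line filled in to match the surrounding code's intent:

import re

SPACE_NORMALIZER = re.compile(r"\s+")

def tokenize_line(line):
    # Collapse any run of whitespace to a single space, trim, and split.
    line = SPACE_NORMALIZER.sub(' ', line)
    line = line.strip()
    return line.split()

assert tokenize_line('  hello\tworld \n') == ['hello', 'world']
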