Commit 535ca991 authored by Myle Ott

Merge internal changes

parent 28069cf4
@@ -9,8 +9,7 @@ from collections import namedtuple
 import pickle
 
 import torch
-from torch import distributed, nn
-from torch.distributed import group
+from torch import nn
 
 from fairseq import utils
@@ -33,6 +32,16 @@ else:
     c10d_status = C10dStatus(has_c10d=False, is_default=False)
 
+if c10d_status.is_default:
+    import torch.distributed as dist_c10d
+    import torch.distributed.deprecated as dist_no_c10d
+elif c10d_status.has_c10d:
+    import torch.distributed.c10d as dist_c10d
+    import torch.distributed as dist_no_c10d
+else:
+    import torch.distributed as dist_no_c10d
+
 
 def distributed_init(args):
     if args.distributed_world_size == 1:
         raise ValueError('Cannot initialize distributed with distributed_world_size=1')
@@ -44,15 +53,9 @@ def distributed_init(args):
             args.distributed_rank, args.distributed_init_method), flush=True)
 
     if _use_c10d[0]:
-        if c10d_status.is_default:
-            init_fn = distributed.init_process_group
-        else:
-            init_fn = distributed.c10d.init_process_group
+        init_fn = dist_c10d.init_process_group
     else:
-        if c10d_status.is_default:
-            init_fn = distributed.deprecated.init_process_group
-        else:
-            init_fn = distributed.init_process_group
+        init_fn = dist_no_c10d.init_process_group
 
     init_fn(
         backend=args.distributed_backend,
@@ -83,32 +86,32 @@ def suppress_output():
 
 def get_rank():
     if _use_c10d[0]:
-        return distributed.c10d.get_rank()
+        return dist_c10d.get_rank()
     else:
-        return distributed.get_rank()
+        return dist_no_c10d.get_rank()
 
 
 def get_world_size():
     if _use_c10d[0]:
-        return distributed.c10d.get_world_size()
+        return dist_c10d.get_world_size()
     else:
-        return distributed.get_world_size()
+        return dist_no_c10d.get_world_size()
 
 
 def get_default_group():
     if _use_c10d[0]:
-        return distributed.c10d.group.WORLD
+        return dist_c10d.group.WORLD
     else:
-        return distributed.group.WORLD
+        return dist_no_c10d.group.WORLD
 
 
 def all_reduce(tensor, group=None):
     if group is None:
         group = get_default_group()
     if _use_c10d[0]:
-        return distributed.c10d.all_reduce(tensor, group=group)
+        return dist_c10d.all_reduce(tensor, group=group)
     else:
-        return distributed.all_reduce(tensor, group=group)
+        return dist_no_c10d.all_reduce(tensor, group=group)
 
 
 def all_gather_list(data, group=None, max_size=16384):
...
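As a side note on the new module-level aliases: they let the rest of the file call one API regardless of which distributed backend the installed PyTorch exposes. A minimal sketch of the pattern follows; the try/except probe and the `all_reduce_sum` helper are illustrative only, not part of this commit, and assume a PyTorch build from the c10d transition era.

```python
# Illustrative sketch of version-dependent backend aliasing (not fairseq code).
import torch.distributed as dist

try:
    # Some builds of this era exposed the new backend under torch.distributed.c10d.
    import torch.distributed.c10d as dist_c10d
    dist_no_c10d = dist
except ImportError:
    # Where c10d is already the default, torch.distributed *is* the c10d API and the
    # legacy one (if present at all) lives under torch.distributed.deprecated.
    dist_c10d = dist
    dist_no_c10d = getattr(dist, "deprecated", dist)


def all_reduce_sum(tensor, use_c10d=True):
    # Dispatch through whichever alias was selected, mirroring the _use_c10d flag above.
    mod = dist_c10d if use_c10d else dist_no_c10d
    return mod.all_reduce(tensor)
```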
@@ -627,7 +627,13 @@ class TransformerDecoderLayer(nn.Module):
         self.final_layer_norm = LayerNorm(self.embed_dim)
         self.need_attn = True
+        self.onnx_trace = False
+
+    def prepare_for_onnx_export_(self):
+        self.onnx_trace = True
 
-    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, self_attn_mask=None,
+    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state,
+                prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None,
                 self_attn_padding_mask=None):
         """
         Args:
@@ -640,6 +646,12 @@ class TransformerDecoderLayer(nn.Module):
         """
         residual = x
         x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
+        if prev_self_attn_state is not None:
+            if incremental_state is None:
+                incremental_state = {}
+            prev_key, prev_value = prev_self_attn_state
+            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
+            self.self_attn._set_input_buffer(incremental_state, saved_state)
         x, _ = self.self_attn(
             query=x,
             key=x,
@@ -657,6 +669,12 @@ class TransformerDecoderLayer(nn.Module):
         if self.encoder_attn is not None:
             residual = x
             x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
+            if prev_attn_state is not None:
+                if incremental_state is None:
+                    incremental_state = {}
+                prev_key, prev_value = prev_attn_state
+                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
+                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
             x, attn = self.encoder_attn(
                 query=x,
                 key=encoder_out,
@@ -678,6 +696,10 @@ class TransformerDecoderLayer(nn.Module):
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
+        if self.onnx_trace:
+            saved_state = self.self_attn._get_input_buffer(incremental_state)
+            self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
+            return x, attn, self_attn_state
         return x, attn
 
     def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
...
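For readers unfamiliar with the ONNX path: the new `prev_self_attn_state` / `prev_attn_state` arguments and the extra `self_attn_state` return value let an exported single-step decoder carry its attention cache as plain tensors instead of the Python `incremental_state` dict, which cannot be traced. A toy round-trip of that idea (illustrative only, not fairseq's actual layer) looks like this:

```python
# Toy illustration of passing an attention cache in and out as explicit tensors.
import torch
import torch.nn as nn


class ToyCachedAttention(nn.Module):
    """Stand-in for an attention layer whose key/value cache is handled by the caller."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, prev_state=None):
        # x: (1, bsz, dim) -- a single decoding step
        k = v = self.proj(x)
        if prev_state is not None:
            prev_key, prev_value = prev_state
            k = torch.cat((prev_key, k), dim=0)
            v = torch.cat((prev_value, v), dim=0)
        attn = torch.softmax((x * k).sum(-1, keepdim=True), dim=0)
        out = (attn * v).sum(dim=0, keepdim=True)
        # Hand the updated cache back to the caller, as the onnx_trace path above does.
        return out, (k, v)


layer = ToyCachedAttention(dim=4)
state = None
for step in range(3):
    x = torch.randn(1, 2, 4)  # one token for a batch of 2
    y, state = layer(x, prev_state=state)
    print(step, y.shape, state[0].shape)  # the cached keys grow by one step per call
```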
@@ -45,6 +45,11 @@ class MultiheadAttention(nn.Module):
 
         self.reset_parameters()
 
+        self.onnx_trace = False
+
+    def prepare_for_onnx_export_(self):
+        self.onnx_trace = True
+
     def reset_parameters(self):
         nn.init.xavier_uniform_(self.in_proj_weight)
         nn.init.xavier_uniform_(self.out_proj.weight)
@@ -94,9 +99,7 @@ class MultiheadAttention(nn.Module):
             q = self.in_proj_q(query)
            if key is None:
                 assert value is None
-                # this will allow us to concat it with previous value and get
-                # just get the previous value
-                k = v = q.new(0)
+                k = v = None
             else:
                 k, v = self.in_proj_kv(key)
         else:
@@ -106,12 +109,20 @@ class MultiheadAttention(nn.Module):
         q *= self.scaling
 
         if saved_state is not None:
             if 'prev_key' in saved_state:
-                k = torch.cat((saved_state['prev_key'], k), dim=0)
+                if static_kv:
+                    k = saved_state['prev_key']
+                else:
+                    k = torch.cat((saved_state['prev_key'], k), dim=0)
             if 'prev_value' in saved_state:
-                v = torch.cat((saved_state['prev_value'], v), dim=0)
+                if static_kv:
+                    v = saved_state['prev_value']
+                else:
+                    v = torch.cat((saved_state['prev_value'], v), dim=0)
             saved_state['prev_key'] = k
             saved_state['prev_value'] = v
             self._set_input_buffer(incremental_state, saved_state)
 
         if self.bias_k is not None:
...
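The `static_kv` branches above encode a simple rule: self-attention caches grow by one key/value per generated token, while encoder-decoder attention can reuse its cached source-side keys/values untouched. A standalone sketch of that rule (not fairseq's `MultiheadAttention` itself):

```python
# Standalone sketch of the cache-update rule selected by static_kv.
import torch


def update_cache(saved_state, k, v, static_kv):
    if 'prev_key' in saved_state:
        if static_kv:
            k = saved_state['prev_key']
        else:
            k = torch.cat((saved_state['prev_key'], k), dim=0)
    if 'prev_value' in saved_state:
        if static_kv:
            v = saved_state['prev_value']
        else:
            v = torch.cat((saved_state['prev_value'], v), dim=0)
    saved_state['prev_key'], saved_state['prev_value'] = k, v
    return k, v


cache = {'prev_key': torch.zeros(5, 2, 8), 'prev_value': torch.zeros(5, 2, 8)}
k_new = v_new = torch.ones(1, 2, 8)

k, _ = update_cache(dict(cache), k_new, v_new, static_kv=False)
print(k.shape)  # torch.Size([6, 2, 8]) -- self-attention: cache grows each step

k, _ = update_cache(dict(cache), k_new, v_new, static_kv=True)
print(k.shape)  # torch.Size([5, 2, 8]) -- encoder attention: cache reused unchanged
```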
@@ -9,6 +9,7 @@ import math
 
 import torch
 import torch.nn as nn
+import torch.onnx.operators
 
 from fairseq import utils
@@ -55,12 +56,12 @@ class SinusoidalPositionalEmbedding(nn.Module):
             emb[padding_idx, :] = 0
         return emb
 
-    def forward(self, input, incremental_state=None):
+    def forward(self, input, incremental_state=None, timestep=None):
         """Input is expected to be of size [bsz x seqlen]."""
-        # recompute/expand embeddings if needed
-        bsz, seq_len = input.size()
+        bsz, seq_len = torch.onnx.operators.shape_as_tensor(input)
         max_pos = self.padding_idx + 1 + seq_len
         if self.weights is None or max_pos > self.weights.size(0):
+            # recompute/expand embeddings if needed
             self.weights = SinusoidalPositionalEmbedding.get_embedding(
                 max_pos,
                 self.embedding_dim,
@@ -70,12 +71,13 @@ class SinusoidalPositionalEmbedding(nn.Module):
 
         if incremental_state is not None:
             # positions is the same for every token when decoding a single step
-            return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)
+            pos = (timestep.int() + 1).long() if timestep is not None else seq_len
+            if self.onnx_trace:
+                return self.weights[self.padding_idx + pos, :].unsqueeze(1).repeat(bsz, 1, 1)
+            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
 
         positions = utils.make_positions(input, self.padding_idx, self.left_pad, self.onnx_trace)
         if self.onnx_trace:
-            bsz = torch.onnx.operators.shape_as_tensor(input)[0]
-            seq_len = torch.onnx.operators.shape_as_tensor(input)[1]
             flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
             embedding_shape = torch.cat((bsz.view(1), seq_len.view(1), torch.LongTensor([-1])))
             embeddings = torch.onnx.operators.reshape_from_tensor_shape(flat_embeddings, embedding_shape)
...
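The new `timestep` argument exists because a single-step decoder always sees an input that is one token wide, so the true position must be supplied explicitly (for example by an exported ONNX decoder) rather than read off the sequence length. A standalone sketch of the arithmetic; the `sinusoidal_table` helper below only approximates `get_embedding` and is not copied from this commit:

```python
# Standalone sketch: an explicit timestep selects the same positional row as the
# sequence length would during full-sequence decoding.
import math
import torch


def sinusoidal_table(num_embeddings, dim, padding_idx=1):
    """Rough re-creation of a sinusoidal position table (sin/cos interleaving omitted)."""
    half = dim // 2
    freq = torch.exp(torch.arange(half, dtype=torch.float) * -(math.log(10000) / (half - 1)))
    pos = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * freq.unsqueeze(0)
    table = torch.cat([torch.sin(pos), torch.cos(pos)], dim=1)
    table[padding_idx, :] = 0  # the padding position embeds to zeros
    return table


padding_idx = 1
weights = sinusoidal_table(64, 16, padding_idx)

# Full-sequence path: position index comes from the sequence length.
seq_len = 5
full = weights[padding_idx + seq_len, :]

# Incremental path with an explicit timestep tensor (0-based index of the last token).
timestep = torch.LongTensor([4])
pos = (timestep.int() + 1).long()
step = weights[padding_idx + pos, :]

print(torch.equal(full, step.squeeze(0)))  # True -- same row either way
```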
@@ -11,7 +11,7 @@ import os, re
 
 import torch
 
 from multiprocessing import Pool
 
-SPACE_NORMALIZER = re.compile("\s+")
+SPACE_NORMALIZER = re.compile(r"\s+")
 
 
 def tokenize_line(line):
...
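Minor but worth noting: `\s` is not a recognized Python string escape, so `"\s+"` and `r"\s+"` compile to the same pattern; the raw string only silences the invalid-escape `DeprecationWarning` that newer Python versions emit. A quick check, with the body of `tokenize_line` assumed from the usual normalize-then-split pattern rather than copied from this diff:

```python
# Behavior check for the raw-string regex; tokenize_line body is assumed, not from the diff.
import re

SPACE_NORMALIZER = re.compile(r"\s+")


def tokenize_line(line):
    # Collapse runs of whitespace, trim, and split into tokens.
    line = SPACE_NORMALIZER.sub(" ", line)
    return line.strip().split()


print(tokenize_line("  hello\t world \n"))  # ['hello', 'world']
```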