Unverified Commit cc92a4b4 authored by Jithun Nair, committed by GitHub

Merge pull request #55 from ROCmSoftwarePlatform/IFU-master-2021-10-15

IFU-2021-10-15 (+ remove redundant defines + C10_CUDA_CHECK)
parents 1e0f9bc6 fec3141c
###############################################################################
# Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
###############################################################################
import sys
import torch
import numpy as np
import unittest
import math
import fmhalib as mha
def py_mha(qkv, amask, b, s, h, d):
qkv = qkv.view(b, s, h, 3, d)
q = qkv[:, :, :, 0, :].permute(0,2,1,3)
k = qkv[:, :, :, 1, :].permute(0,2,1,3)
v = qkv[:, :, :, 2, :].permute(0,2,1,3)
p = torch.matmul(q.float(), k.permute(0,1,3,2).float())
p_masked = p / math.sqrt(d) + (1.0 - amask) * -10000.0
s = torch.softmax(p_masked, -1).to(qkv.dtype)
ctx = torch.matmul(s, v)
ctx = ctx.permute(0,2,1,3).contiguous()
ctx.retain_grad()
return ctx
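# A minimal CPU sketch (an illustration added here, not part of the original test):
# qkv is interleaved as (b, s, h, 3, d) and py_mha returns a context of shape (b, s, h, d).
def _py_mha_shape_demo():
    b, s, h, d = 2, 4, 2, 8
    qkv = torch.randn(b, s, h * 3 * d, requires_grad=True)
    amask = torch.ones(b, h, s, s)
    ctx = py_mha(qkv, amask, b, s, h, d)
    assert ctx.shape == (b, s, h, d)
    return ctx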
class TestFMHA(unittest.TestCase):
def run_test(self, s, b):
print(f'Test s={s} b={b}')
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
dtype = torch.float16
device = torch.device('cuda')
h = 16
d = 64
slens = [s] * b
a = torch.tensor(np.array([0] + slens), dtype=torch.int32)
amask = torch.ones(b,h,s,s, dtype=dtype, device=device)
seqlens = torch.tensor(slens, dtype=torch.int32, device=device)
cu_seqlens = torch.cumsum(a, 0).to(dtype=torch.int32, device=device)
total = cu_seqlens[-1].item()
qkv = torch.randn((b,s,h,3,d), device=device, dtype=dtype)
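        # Re-pack (b, s, h, 3, d) -> (b*s, 3, h, d): move the qkv interleave axis
        # ahead of the head axis and flatten the tokens, which is the layout the
        # fused kernel consumes below.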
qkv_vs = qkv.permute(0,1,3,2,4).contiguous().view(b*s, 3, h,d)
qkv.requires_grad = True
if b < 4:
ctx, S_ = mha.fwd_nl(qkv_vs, cu_seqlens, 0.0, s, True, None)
else:
ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, None)
ctx = ctx.view(b,s,h,d)
ctx_ref = py_mha(qkv, amask, b,s,h,d)
self.assertTrue(torch.allclose(ctx_ref.float(), ctx.float(), atol=1e-3))
labels = torch.randn_like(ctx_ref)
diff = ctx_ref - labels
l = (diff * diff).sum() / b
l.backward()
dw = ctx_ref.grad.permute(0,2,1,3)
dw2 = dw.permute(0,2,1,3).clone().detach().contiguous()
if b < 4:
dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
else:
dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
dqkv2 = dqkv2.permute(0,2,1,3).view(b,s, h,3,d)
self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))
def test_128(self):
self.run_test(128, 32)
def test_256(self):
self.run_test(256, 32)
def test_384(self):
self.run_test(384, 32)
def test_512(self):
self.run_test(512, 32)
self.run_test(512, 2)
self.run_test(512, 3)
if __name__ == '__main__':
unittest.main()
import torch
import unittest
import torch.nn.functional as F
from apex import fused_dense
from torch import nn
from apex import amp
class FusedDenseTest(unittest.TestCase):
def setUp(self, seed=0):
torch.manual_seed(seed)
#torch.cuda.manual_seed_all(seed)
self.seq_length = 512
self.sequences = 3
self.hidden_dim = 1024
self.ref_inputs = torch.randn(self.sequences*self.seq_length, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).int().half().requires_grad_(True)
self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True)
self.dense = fused_dense.FusedDense(1024, 3072)
self.dense.half()
self.dense.cuda()
def test_fused_dense(self) :
y_tst = self.dense(self.tst_inputs)
y_ref = torch.matmul(self.ref_inputs,self.dense.weight.t())+self.dense.bias
dy = torch.randn_like(y_tst).half()
y_tst.backward(dy)
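        # Reference gradients for y = x W^T + b: dW = dy^T x, dx = dy W, and
        # db = column-sum of dy; these mirror what autograd computes for the fused path.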
dw_ref = torch.matmul(dy.t(), self.ref_inputs)
dx_ref = torch.matmul(dy, self.dense.weight.clone())
db_ref = dy.sum(0, False)
self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(y_ref, y_tst, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(dw_ref, self.dense.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(dx_ref, self.tst_inputs.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(db_ref, self.dense.bias.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
if __name__ == '__main__':
unittest.main()
import torch
import unittest
from apex.contrib.transducer import TransducerJoint
import transducer_ref
class TransducerJointTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def gen_input(self, for_vector_kernel):
self.B = 4
T_min = 51
T_max = 101
U_min = 12
U_max = 25
if for_vector_kernel:
H = 512
else:
H = 509
dtype = torch.float16
device = "cuda"
self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device)
self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device)
self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device)
self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device)
self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device)
self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max
self.dropout_prob = 0.5
# Make sure gradients from out-of-bound locations are zero. This should be guaranteed by
# the loss function
for b in range(self.B):
self.h_grad[b, self.f_len[b]:, :, :] = 0
self.h_grad[b, :, self.g_len[b]:, :] = 0
self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len)
def _pack(self, x, f_len, g_len):
B = x.size(0)
list_x = []
for b in range(B):
list_x_row = [x[b, t, :g_len[b]] for t in range(f_len[b])]
x_row = torch.cat(list_x_row)
list_x.append(x_row)
x_packed = torch.cat(list_x).data.clone()
x_packed.requires_grad = True
batch_offset = torch.cumsum(f_len * g_len, dim=0)
return x_packed
def _unpack(self, x, f_len, g_len):
batch_offset = torch.cumsum(f_len * g_len, dim=0)
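        # Packed layout: batch b starts at sum_{i<b} f_len[i]*g_len[i], and element
        # (b, t, u) lives at that offset plus t*g_len[b] + u.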
x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8)
B = self.h_grad.size(0)
H = self.h_grad.size(-1)
for b in range(B):
my_batch_offset = 0 if b == 0 else batch_offset[b-1]
my_f_len = f_len[b]
my_g_len = g_len[b]
for t in range(my_f_len):
x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len :
my_batch_offset + t*my_g_len + my_g_len]
return x_unpacked
def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout):
self.gen_input(for_vector_kernel=for_vector_kernel)
# Generate reference
f_ref = self.f_tst.data.clone()
g_ref = self.g_tst.data.clone()
f_ref.requires_grad = True
g_ref.requires_grad = True
my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout,
dropout_prob=self.dropout_prob, probe_mask=True)
if not pack_output:
h_tst = my_joint( f=self.f_tst,
g=self.g_tst,
f_len=self.f_len,
g_len=self.g_len)
h_tst.backward(self.h_grad)
if dropout:
mask = my_joint.mask_probe[0]
else:
batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0)
h_tst = my_joint( f=self.f_tst,
g=self.g_tst,
f_len=self.f_len,
g_len=self.g_len,
batch_offset=batch_offset,
packed_batch=batch_offset[-1])
h_tst.backward(self.h_grad_packed)
if dropout:
mask_packed = my_joint.mask_probe[0]
mask = self._unpack(mask_packed, self.f_len, self.g_len)
# reference
h_ref, f_grad_ref, g_grad_ref \
= transducer_ref.transducer_joint_reference(f=f_ref,
g=g_ref,
h_grad=self.h_grad,
f_len=self.f_len,
g_len=self.g_len,
pack_output=pack_output,
relu=relu,
dropout=dropout,
dropout_prob=self.dropout_prob,
mask=mask if dropout else None)
f_grad_tst = self.f_tst.grad
g_grad_tst = self.g_tst.grad
self.assertTrue(torch.allclose(h_ref, h_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4))
def test_transducer_joint(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)
def test_transducer_joint_vec(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False)
def test_transducer_joint_pack(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False)
def test_transducer_joint_vec_pack(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)
def test_transducer_joint_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)
def test_transducer_joint_vec_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False)
def test_transducer_joint_pack_relu(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False)
def test_transducer_joint_vec_pack_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)
def test_transducer_joint_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
def test_transducer_joint_vec_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True)
def test_transducer_joint_pack_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True)
def test_transducer_joint_vec_pack_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
if __name__ == '__main__':
unittest.main()
import torch
import unittest
from apex.contrib.transducer import TransducerLoss
import transducer_ref
class TransducerLossTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def gen_input(self, scalar_t, for_vector_kernel):
self.B = 5
T_min = 23
T_max = 51
U_min = 12
U_max = 25
V = 16 if for_vector_kernel else 14
self.blank_idx = V - 1
device = "cuda"
self.x_tst = torch.randn((self.B, T_max, U_max, V), dtype=scalar_t, requires_grad=True,
device=device)
self.y = torch.randint(0, self.blank_idx, (self.B, U_max-1), dtype=torch.int, device=device)
self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device)
self.y_len = torch.randint(U_min-1, U_max, (self.B,), dtype=torch.int, device=device)
self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
self.y_len[torch.randint(0, self.B, (1,)).item()] = U_max-1
self.x_tst_packed, self.batch_offset = self._pack(self.x_tst)
# Generate reference
x_ref = self.x_tst.data.clone()
x_ref.requires_grad = True
loss_grad = torch.ones(x_ref.size(0), dtype=x_ref.dtype, device=x_ref.device)/x_ref.size(0)
_, _, self.grad_ref, self.loss_ref \
= transducer_ref.transducer_loss_reference( x=x_ref,
label=self.y,
f_len=self.f_len,
y_len=self.y_len,
blank_idx=self.blank_idx,
loss_grad=loss_grad)
def _pack(self, x):
list_x = []
for b in range(self.B):
list_x_row = [x[b, t, : self.y_len[b]+1] for t in range(self.f_len[b])]
x_row = torch.cat(list_x_row)
list_x.append(x_row)
x_packed = torch.cat(list_x).data.clone()
x_packed.requires_grad = True
batch_offset = torch.cumsum(self.f_len * (self.y_len+1), dim=0)
return x_packed, batch_offset
def _unpack(self, x):
x_unpacked = torch.zeros(self.B, self.f_len.max(), self.y_len.max()+1, x.size(-1),
dtype=x.dtype, device=x.device)
for b in range(self.B):
my_batch_offset = 0 if b == 0 else self.batch_offset[b-1]
my_f_len = self.f_len[b]
my_g_len = self.y_len[b] + 1
for t in range(my_f_len):
for u in range(my_g_len):
x_unpacked[b, t, u] = x[my_batch_offset + t*my_g_len + u]
return x_unpacked
def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input, for_vector_kernel):
self.gen_input(scalar_t, for_vector_kernel)
my_loss = TransducerLoss( fuse_softmax_backward=fuse_softmax_backward,
packed_input=packed_input)
if not packed_input:
loss_tst = my_loss( x=self.x_tst,
label=self.y,
f_len=self.f_len,
y_len=self.y_len,
blank_idx=self.blank_idx)
loss_tst.mean().backward()
grad_tst = self.x_tst.grad
else:
loss_tst = my_loss( x=self.x_tst_packed,
label=self.y,
f_len=self.f_len,
y_len=self.y_len,
blank_idx=self.blank_idx,
batch_offset=self.batch_offset,
max_f_len=max(self.f_len))
loss_tst.mean().backward()
grad_tst_packed = self.x_tst_packed.grad
grad_tst = self._unpack(grad_tst_packed)
return loss_tst, grad_tst
def test_transducer_loss_fp32(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float32,
fuse_softmax_backward=False,
packed_input=False,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-5, rtol=1e-5))
def test_transducer_loss_fp16(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=False,
packed_input=False,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=False,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion_packed(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=True,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion_packed_vec(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=True,
for_vector_kernel=True)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
import torch
import numpy as np
import pdb
def transducer_loss_reference(x, label, f_len, y_len, blank_idx, loss_grad):
def log_sum_exp(a, b):
if (a >= b):
return a + torch.log(1 + torch.exp(b-a))
else:
return b + torch.log(1 + torch.exp(a-b))
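    # log_sum_exp evaluates log(exp(a) + exp(b)) stably by factoring out the larger
    # term: max(a, b) + log(1 + exp(-|a - b|)).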
def forward_alpha(x, label, f_len, y_len, blank_idx):
B, T, U, V = x.size()
acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype
alpha = torch.zeros((B, T, U), dtype=acc_t, device=x.device)
for b in range(B):
alpha[b, 0, 0] = 0
for t in range(1, f_len[b]):
alpha[b, t, 0] = alpha[b, t-1, 0] + x[b, t-1, 0, blank_idx]
for u in range(1, y_len[b]+1):
alpha[b, 0, u] = alpha[b, 0, u-1] + x[b, 0, u-1, label[b, u-1]]
for t in range(1, f_len[b]):
for u in range(1, y_len[b]+1):
curr_ = alpha[b, t-1, u] + x[b, t-1, u, blank_idx]
next_ = alpha[b, t, u-1] + x[b, t, u-1, label[b, u-1]]
alpha[b, t, u] = log_sum_exp(curr_, next_)
return alpha
def forward_beta(x, label, f_len, y_len, blank_idx):
B, T, U, V = x.shape
acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype
beta = torch.zeros((B, T, U), dtype=acc_t, device=x.device)
for b in range(B):
beta[b, f_len[b]-1, y_len[b]] = x[b, f_len[b]-1, y_len[b], blank_idx]
for t in range(f_len[b]-2, -1, -1):
beta[b, t, y_len[b]] = beta[b, t+1, y_len[b]] + x[b, t, y_len[b], blank_idx]
for u in range(y_len[b]-1, -1, -1):
beta[b, f_len[b]-1, u] = beta[b, f_len[b]-1, u+1] + x[b, f_len[b]-1, u, label[b, u]]
for t in range(f_len[b]-2, -1, -1):
for u in range(y_len[b]-1, -1, -1):
curr_ = beta[b, t+1, u] + x[b, t, u, blank_idx]
next_ = beta[b, t, u+1] + x[b, t, u, label[b, u]]
beta[b, t, u] = log_sum_exp(curr_, next_)
return beta
def backward(x, label, f_len, y_len, alpha, beta, loss_grad, blank_idx):
grad = torch.zeros_like(x)
B, T, U, V = x.size()
for b in range(B):
common_factor = torch.log(loss_grad[b]) + alpha - beta[b, 0, 0]
# next
for u in range(y_len[b]):
grad[b, :f_len[b], u, label[b, u]] = -torch.exp(common_factor[b, :f_len[b], u]
+ beta[b, :f_len[b], u+1]
+ x[b, :f_len[b], u, label[b, u]])
# current
grad[b, :f_len[b]-1, :y_len[b]+1, blank_idx] \
= -torch.exp(common_factor[b, :f_len[b]-1, :y_len[b]+1]
+ beta[b, 1:f_len[b], :y_len[b]+1]
+ x[b, :f_len[b]-1, :y_len[b]+1, blank_idx])
grad[b, f_len[b]-1, y_len[b], blank_idx] = -torch.exp(common_factor[b, f_len[b]-1, y_len[b]]
+ x[b, f_len[b]-1, y_len[b], blank_idx])
return grad
x_log = torch.nn.functional.log_softmax(x, dim=-1)
alpha = forward_alpha(x_log, label, f_len, y_len, blank_idx)
beta = forward_beta(x_log, label, f_len, y_len, blank_idx)
grad = backward(x_log, label, f_len, y_len, alpha, beta,
loss_grad, blank_idx)
x_log.backward(grad)
loss = -beta[:, 0, 0]
loss = loss.to(x.dtype)
return alpha, beta, x.grad, loss
def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output, relu, dropout,
dropout_prob=0, mask=None):
    if dropout and mask is None:
        raise NotImplementedError("a mask needs to be supplied to test dropout.")
B, T, H = f.size()
U = g.size(1)
f_expand = f.unsqueeze(dim=2)
g_expand = g.unsqueeze(dim=1)
h = f_expand + g_expand
if relu:
h = torch.nn.functional.relu(h)
if dropout:
h *= mask
scale = 1/(1-dropout_prob)
h *= scale
h.backward(h_grad)
    if not pack_output:
        # intentionally set the don't-care regions to -1 to test that the transducer
        # joint writes these regions as well, avoiding NaN and inf
for b in range(B):
h[b, f_len[b]:] = -1
h[b, :, g_len[b]:] = -1
return h, f.grad, g.grad
# packing
list_to_pack = []
for b in range(B):
list_to_pack.append(h[b, :f_len[b], :g_len[b], :].reshape(-1, H))
h_packed = torch.cat(list_to_pack)
return h_packed, f.grad, g.grad
from .transducer import TransducerJoint
from .transducer import TransducerLoss
import torch
import transducer_loss_cuda
import transducer_joint_cuda
class TransducerJoint(torch.nn.Module):
"""Transducer joint
    Detail of this operation can be found in: Sequence Transduction with Recurrent Neural
    Networks
Arguments:
pack_output (bool, optional): whether to pack the output in a compact form with don't-care
data being removed. (default: False)
relu (bool, optional): apply ReLU to the output of the joint operation. Requires opt=1
(default: False)
dropout (bool, optional): apply dropout to the output of the joint operation. Requires opt=1
(default: False)
opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm.
(default: 1)
fwd_tile_size (int, optional): tile size used in forward operation. This argument will be
ignored if opt != 1. (default: 4)
dropout_prob (float, optional): dropout probability. (default: 0.0)
probe_mask (bool, optional): a flag used to probe the mask generated by ReLU and/or dropout
operation. When this argument is set to True, the mask can be accessed through
            self.mask_probe. (default: False)
"""
def __init__(self, pack_output=False, relu=False, dropout=False, opt=1, fwd_tile_size=4,
dropout_prob=0, probe_mask=False):
super(TransducerJoint, self).__init__()
self.pack_output = pack_output
self.relu = relu
self.dropout = dropout
self.dropout_prob = dropout_prob
self.opt = opt
self.fwd_tile_size = fwd_tile_size
self.dummy_batch_offset = torch.empty(0)
masked = self.relu or self.dropout
self.mask_probe = [] if masked and probe_mask else None
if masked and opt != 1:
raise NotImplementedError("ReLU and dropout fusion is only supported with opt=1")
def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0):
"""Forward operation of transducer joint
Arguments:
f (tensor): transcription vector from encode block of shape (B, T, H).
        g (tensor): prediction vector from the predict block of shape (B, U, H).
f_len (tensor): length of transcription vector for each batch.
g_len (tensor): length of prediction vector minus 1 for each batch.
batch_offset (tensor, optional): tensor containing the offset of each batch
in the results. For example, batch offset can be obtained from:
batch_offset = torch.cumsum(f_len*g_len, dim=0)
This argument is required if pack_output == True, and is ignored if
pack_output == False. (default: None)
packed_batch (int, optional): the batch size after packing. This argument is
ignored if pack_output == False. (default: 0)
"""
my_batch_offset = batch_offset if self.pack_output else self.dummy_batch_offset
if self.pack_output and (batch_offset is None or packed_batch == 0):
raise Exception("Please specify batch_offset and packed_batch when packing is enabled")
dropout = self.dropout and self.training # only dropout for training
return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, self.relu, dropout,
my_batch_offset, packed_batch, self.opt,
self.fwd_tile_size, self.dropout_prob, self.mask_probe)
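# A minimal usage sketch for the unpacked path (an illustration added here, not part
# of the original module); it assumes a CUDA build of the transducer extensions.
def _transducer_joint_demo():
    B, T, U, H = 2, 8, 4, 16
    f = torch.randn(B, T, H, dtype=torch.float16, device="cuda")
    g = torch.randn(B, U, H, dtype=torch.float16, device="cuda")
    f_len = torch.full((B,), T, dtype=torch.int, device="cuda")
    g_len = torch.full((B,), U, dtype=torch.int, device="cuda")
    joint = TransducerJoint()  # pack_output=False by default
    return joint(f=f, g=g, f_len=f_len, g_len=g_len)  # shape (B, T, U, H)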
class TransducerLoss(torch.nn.Module):
"""Transducer loss
Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural
Networks
Arguments:
        fuse_softmax_backward (bool, optional): whether to fuse the backward of transducer loss with
            softmax. (default: True)
opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a more optimized
algorithm. In some cases, opt=1 might fall back to opt=0. (default: 1)
        packed_input (bool, optional): whether the input is packed in a compact form with
            don't-care data removed. (default: False)
"""
def __init__(self, fuse_softmax_backward=True, opt=1, packed_input=False):
super(TransducerLoss, self).__init__()
self.fuse_softmax_backward = fuse_softmax_backward
self.opt = opt
self.packed_input = packed_input
self.dummy_batch_offset = torch.empty(0)
def forward(self, x, label, f_len, y_len, blank_idx, batch_offset=None, max_f_len=None,
debug_list=None):
"""Forward operation of transducer joint
Arguments:
x (tensor): input tensor to the loss function with a shape of (B, T, U, H).
label (tensor): labels for the input data.
f_len (tensor): lengths of the inputs in the time dimension for each batch.
y_len (tensor): lengths of the labels for each batch.
blank_idx (int): index for the null symbol.
batch_offset (tensor, optional): tensor containing the offset of each batch
in the input. For example, batch offset can be obtained from:
batch_offset = torch.cumsum(f_len*(y_len+1), dim=0)
This argument is required if packed_input == True, and is ignored if
packed_input == False. (default: None)
max_f_len (int, optional): maximum length of the input in the time dimension.
For example, it can be obtained as
max_f_len = max(f_len)
This argument is required if packed_input == True, and is ignored if
packed_input == False. (default: None)
debug_list (list, optional): when an empty list is supplied, Alpha and Beta generated
in the forward operation will be attached to this list for debug purpose.
(default: None)
"""
if self.packed_input:
if batch_offset is None or max_f_len is None:
raise Exception("Please specify batch_offset and max_f_len when packing is \
enabled")
my_batch_offset = batch_offset
my_max_f_len = max_f_len
else:
my_batch_offset = self.dummy_batch_offset
my_max_f_len = x.size(1)
return TransducerLossFunc.apply(x, label, f_len, y_len, my_batch_offset, my_max_f_len,
blank_idx, self.fuse_softmax_backward, debug_list,
self.opt, self.packed_input)
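# A minimal usage sketch for the unpacked loss path (an illustration added here, not
# part of the original module); it assumes a CUDA build of the transducer extensions.
def _transducer_loss_demo():
    B, T, U, V = 2, 8, 4, 16
    blank_idx = V - 1
    x = torch.randn(B, T, U, V, device="cuda", requires_grad=True)
    label = torch.randint(0, blank_idx, (B, U - 1), dtype=torch.int, device="cuda")
    f_len = torch.full((B,), T, dtype=torch.int, device="cuda")
    y_len = torch.full((B,), U - 1, dtype=torch.int, device="cuda")
    loss = TransducerLoss()(x, label, f_len, y_len, blank_idx)
    loss.mean().backward()
    return loss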
class TransducerLossFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, x, label, f_len, y_len, batch_offset, max_f_len, blank_idx,
fuse_softmax_backward, debug_list, opt, packed_input):
        if not fuse_softmax_backward:
with torch.enable_grad():
x = torch.nn.functional.log_softmax(x, dim=-1)
else:
x = torch.nn.functional.log_softmax(x, dim=-1)
alpha, beta, loss = transducer_loss_cuda.forward( x, label, f_len, y_len, batch_offset,
max_f_len, blank_idx, opt, packed_input)
if debug_list == []:
debug_list += [alpha, beta]
ctx.save_for_backward(x, alpha, beta, f_len, y_len, label, batch_offset)
ctx.blank_idx = blank_idx
ctx.fuse_softmax_backward = fuse_softmax_backward
ctx.opt = opt
ctx.packed_input = packed_input
ctx.max_f_len = max_f_len
return loss
@staticmethod
def backward(ctx, loss_grad):
x, alpha, beta, f_len, y_len, label, batch_offset = ctx.saved_tensors
x_grad = transducer_loss_cuda.backward( x, loss_grad, alpha, beta, f_len, y_len, label,
batch_offset, ctx.max_f_len, ctx.blank_idx, ctx.opt,
ctx.fuse_softmax_backward, ctx.packed_input)
        if not ctx.fuse_softmax_backward:
x_grad = x.backward(x_grad)
return x_grad, None, None, None, None, None, None, None, None, None, None
class TransducerJointFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, f, g, f_len, g_len, pack_output, relu, dropout, batch_offset, packed_batch,
opt, fwd_tile_size, dropout_prob, mask_probe):
h = transducer_joint_cuda.forward(f, g, f_len, g_len, batch_offset, packed_batch, opt,
pack_output, relu, dropout, dropout_prob, fwd_tile_size)
masked = relu or dropout
if masked:
ctx.save_for_backward(h[1], f_len, g_len, batch_offset)
if mask_probe is not None:
mask_probe.append(h[1])
else:
ctx.save_for_backward(f_len, g_len, batch_offset)
ctx.pack_output = pack_output
ctx.masked = relu or dropout
ctx.max_f_len = f.size(1)
ctx.max_g_len = g.size(1)
ctx.scale = 1 / (1-dropout_prob) if dropout and dropout_prob != 1 else 1
return h[0]
@staticmethod
def backward(ctx, loss_grad):
if ctx.masked:
mask, f_len, g_len, batch_offset = ctx.saved_tensors
inp = [loss_grad, mask]
else:
f_len, g_len, batch_offset = ctx.saved_tensors
inp = [loss_grad]
f_grad, g_grad = transducer_joint_cuda.backward( inp, f_len, g_len, batch_offset,
ctx.max_f_len, ctx.max_g_len,
ctx.pack_output, ctx.scale)
return f_grad, g_grad, None, None, None, None, None, None, None, None, None, None, None, \
None, None, None
from .fused_dense import *
import torch
from torch import nn
import fused_dense_cuda
from .. import amp
# implements fused GEMM+bias in the forward pass using fused_dense_cuda from apex
class FusedDenseFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias):
ctx.save_for_backward(input, weight)
output = fused_dense_cuda.linear_bias_forward(input, weight, bias)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_backward(input, weight, grad_output)
return grad_input, grad_weight, grad_bias
class DenseNoBiasFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight):
ctx.save_for_backward(input, weight)
output = torch.matmul(input, weight.t())
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input = grad_output.mm(weight)
grad_weight = grad_output.t().mm(input)
return grad_input, grad_weight
class FusedDenseGeluDenseFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight1, bias1, weight2, bias2):
output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(input, weight1, bias1, weight2, bias2)
ctx.save_for_backward(input, weight1, weight2, gelu_in, output1)
return output2
@staticmethod
def backward(ctx, grad_output):
input, weight1, weight2, gelu_in, output1 = ctx.saved_tensors
grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_gelu_linear_backward(input, gelu_in, output1, weight1, weight2, grad_output)
return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2
fused_dense_function = amp.half_function(FusedDenseFunc.apply)
dense_no_bias_function = amp.half_function(DenseNoBiasFunc.apply)
fused_dense_gelu_dense_function = amp.half_function(FusedDenseGeluDenseFunc.apply)
class FusedDense(nn.Module):
def __init__(self, in_features, out_features, bias=True):
super(FusedDense, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_features))
else:
#assert False, "no-bias option not added yet"
self.register_parameter('bias', None)
def forward(self, input):
if self.bias is not None:
return fused_dense_function(input, self.weight, self.bias)
else:
return dense_no_bias_function(input, self.weight)
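# A minimal usage sketch (an illustration added here, not part of the original module).
# The parameters are allocated with torch.Tensor and left uninitialized, so callers
# should initialize or load them first; a CUDA build of fused_dense_cuda is assumed.
def _fused_dense_demo():
    layer = FusedDense(1024, 3072).half().cuda()
    nn.init.xavier_uniform_(layer.weight)
    nn.init.zeros_(layer.bias)
    x = torch.randn(16, 1024, dtype=torch.float16, device="cuda", requires_grad=True)
    return layer(x)  # shape (16, 3072)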
class FusedDenseGeluDense(nn.Module):
def __init__(self, in_features, intermediate_features, out_features, bias=True):
super(FusedDenseGeluDense, self).__init__()
        assert bias, "DenseGeluDense module without bias is currently not supported"
self.in_features = in_features
self.intermediate_features = intermediate_features
self.out_features = out_features
self.weight1 = nn.Parameter(torch.Tensor(intermediate_features, in_features))
self.bias1 = nn.Parameter(torch.Tensor(intermediate_features))
self.weight2 = nn.Parameter(torch.Tensor(out_features, intermediate_features))
self.bias2 = nn.Parameter(torch.Tensor(out_features))
def forward(self, input):
return fused_dense_gelu_dense_function(input, self.weight1, self.bias1, self.weight2, self.bias2)
from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
import importlib
import numbers
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn import functional as F
from apex._autocast_utils import _cast_if_autocast_enabled
global fused_layer_norm_cuda
fused_layer_norm_cuda = None
class FusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_layer_norm_cuda.forward_affine(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps
)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps
)
return grad_input, grad_weight, grad_bias, None, None
class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction):
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_layer_norm_cuda.forward_affine_mixed_dtypes(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps
)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
class FusedLayerNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
output, mean, invvar = fused_layer_norm_cuda.forward(input_, ctx.normalized_shape, ctx.eps)
ctx.save_for_backward(input_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, mean, invvar = ctx.saved_tensors
grad_input = None
grad_input = fused_layer_norm_cuda.backward(
grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, ctx.eps
)
return grad_input, None, None
def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
return FusedLayerNormAffineFunction.apply(*args)
def fused_layer_norm(input, normalized_shape, eps=1e-6):
    args = _cast_if_autocast_enabled(input, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
return FusedLayerNormFunction.apply(*args)
def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
return FusedLayerNormAffineMixedDtypesFunction.apply(*args)
class FusedLayerNorm(torch.nn.Module):
    r"""Applies Layer Normalization over a mini-batch of inputs as described in
...@@ -126,8 +158,9 @@ class FusedLayerNorm(torch.nn.Module):
    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
    """
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        global fused_layer_norm_cuda
        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
...@@ -141,8 +174,8 @@ class FusedLayerNorm(torch.nn.Module):
            self.weight = Parameter(torch.Tensor(*normalized_shape))
            self.bias = Parameter(torch.Tensor(*normalized_shape))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        self.reset_parameters()
    def reset_parameters(self):
...@@ -152,14 +185,34 @@ class FusedLayerNorm(torch.nn.Module):
    def forward(self, input):
        if not input.is_cuda:
            return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.elementwise_affine:
            return fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps)
        else:
            return fused_layer_norm(input, self.normalized_shape, self.eps)
    def extra_repr(self):
        return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__)
# NOTE (mkozuki): Why "mixed"?
# MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype
# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype.
# See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp"
class MixedFusedLayerNorm(FusedLayerNorm):
def __init__(self, normalized_shape, eps=1e-5, **kwargs):
if "elementwise_affine" in kwargs:
import warnings
warnings.warn("MixedFusedLayerNorm does not support `elementwise_affine` argument")
elementwise_affine = kwargs.pop("elementwise_affine")
if not elementwise_affine:
raise RuntimeError("MixedFusedLayerNorm does not support `elementwise_affine = False`")
super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True)
def forward(self, input: torch.Tensor):
# NOTE (mkozuki): CPU path is here mainly for unittest sake.
if not input.is_cuda:
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
return mixed_dtype_fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps)
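# A minimal usage sketch (an illustration added here, not part of the original module).
# Constructing the module imports the fused_layer_norm_cuda extension, so an apex CUDA
# build is assumed even for the CPU fallback exercised below.
def _mixed_fused_layer_norm_demo():
    ln = MixedFusedLayerNorm(normalized_shape=32)
    x = torch.randn(4, 32)  # a CPU input takes the F.layer_norm fallback path
    return ln(x)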
...@@ -13,7 +13,8 @@ class FusedAdam(torch.optim.Optimizer):
    * Fusion of the Adam update's elementwise operations
    * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
    :class:`apex.optimizers.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
    or ``torch.optim.Adam`` with ``adam_w_mode=False``::
        opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)
        ...
...@@ -79,7 +79,9 @@ class FusedNovoGrad(torch.optim.Optimizer):
        if multi_tensor_applier.available:
            import amp_C
            # Skip buffer
            # Creating the overflow buffer on the same device as the params tensors.
            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
            self.multi_tensor_novograd = amp_C.multi_tensor_novograd
        else:
            raise RuntimeError('apex.optimizers.FusedNovoGrad requires cuda extensions')
...@@ -158,8 +160,9 @@ class FusedNovoGrad(torch.optim.Optimizer):
            if 'exp_avg_sq' not in group:
                group['exp_avg_sq'] = [None, None]
                if group['init_zero']:
                    # Creating the following parameters on the same device as the params tensors.
                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(len(g_16), device=self.param_groups[0]["params"][0].device).contiguous().fill_(0)
                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(len(g_32), device=self.param_groups[0]["params"][0].device).contiguous().fill_(0)
                else: # init with first step norm, so first blend have no effect
                    if group['norm_type'] == 0:
                        v_16 = [torch.max(torch.abs(g.to(torch.float32))).item() for g in g_16]
...@@ -169,8 +172,9 @@ class FusedNovoGrad(torch.optim.Optimizer):
                        v_32 = [torch.sum(torch.pow(g, 2)).sqrt().item() for g in g_32]
                    else:
                        raise RuntimeError('FusedNovoGrad only support l2/inf norm now.')
                    # Creating the following parameters on the same device as the params tensors.
                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(v_16, device=self.param_groups[0]["params"][0].device)
                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(v_32, device=self.param_groups[0]["params"][0].device)
            else:
                assert(len(g_16) == group['exp_avg_sq'][0].numel())
                assert(len(g_32) == group['exp_avg_sq'][1].numel())
# apex.transformer
`apex.transformer` is a module that enables efficient training of large Transformer models at scale.
`apex.transformer.tensor_parallel` is based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s `megatron.mpu` module.
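
A minimal single-process initialization sketch (an illustration, not part of the upstream README; the backend, address, and port below are placeholders):

```python
import torch.distributed as dist
from apex.transformer import parallel_state

dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500",
                        world_size=1, rank=0)
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=1,
                                         pipeline_model_parallel_size_=1)
assert parallel_state.model_parallel_is_initialized()
```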
from . import tensor_parallel
from . import functional
from .enums import LayerType
from .enums import AttnType
from .enums import AttnMaskType
from .parallel_state import (
is_unitialized,
destroy_model_parallel,
get_data_parallel_group,
get_data_parallel_rank,
get_data_parallel_world_size,
get_embedding_group,
get_model_parallel_group,
get_tensor_model_parallel_group,
get_pipeline_model_parallel_group,
get_tensor_model_parallel_rank,
set_tensor_model_parallel_rank,
get_pipeline_model_parallel_rank,
set_pipeline_model_parallel_rank,
is_pipeline_first_stage,
is_pipeline_last_stage,
get_tensor_model_parallel_src_rank,
get_pipeline_model_parallel_first_rank,
get_pipeline_model_parallel_last_rank,
get_pipeline_model_parallel_next_rank,
get_pipeline_model_parallel_prev_rank,
get_tensor_model_parallel_world_size,
set_tensor_model_parallel_world_size,
get_pipeline_model_parallel_world_size,
set_pipeline_model_parallel_world_size,
get_virtual_pipeline_model_parallel_rank,
set_virtual_pipeline_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
class LayerType(enum.Enum):
encoder = 1
decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
from .fused_softmax import FusedScaleMaskSoftmax
__all__ = [
"FusedScaleMaskSoftmax",
]
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex._autocast_utils import _cast_if_autocast_enabled
from ..enums import AttnMaskType
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply the upper triangular mask (typically used in GPT models).
    3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None
def scaled_upper_triang_masked_softmax(inputs, _, scale):
b, np, sq, sk = inputs.size()
assert sq == sk, "causal mask is only for self attention"
# Reshaping input to 3D tensor (attn_batches, sq, sk)
inputs = inputs.view(-1, sq, sk)
args = _cast_if_autocast_enabled(inputs, scale)
with torch.cuda.amp.autocast(enabled=False):
probs = ScaledUpperTriangMaskedSoftmax.apply(*args)
return probs.view(b, np, sq, sk)
# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`.
# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context.
# So I needed to manually write two `torch.autograd.Function` inheritances.
# Fused operation which performs following three operations in sequence
# 1. Scale the tensor.
# 2. Apply the mask.
# 3. Perform softmax.
class ScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, mask, scale):
import scaled_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None
def scaled_masked_softmax(inputs, mask, scale):
# input is 4D tensor (b, np, sq, sk)
args = _cast_if_autocast_enabled(inputs, mask, scale)
with torch.cuda.amp.autocast(enabled=False):
return ScaledMaskedSoftmax.apply(*args)
class FusedScaleMaskSoftmax(torch.nn.Module):
"""
fused operation: scaling + mask + softmax
Arguments:
        input_in_fp16: flag to indicate if input is in fp16 data format.
        input_in_bf16: flag to indicate if input is in bf16 data format.
        attn_mask_type: attention mask type (pad or causal)
        scaled_masked_softmax_fusion: flag to indicate whether the user wants softmax fusion
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed in fp32 precision.
        scale: scaling factor used in input tensor scaling.
"""
def __init__(
self,
input_in_fp16,
input_in_bf16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super().__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
if self.input_in_fp16 and self.input_in_bf16:
raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.")
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
if not (self.scale is None or softmax_in_fp32):
raise RuntimeError("softmax should be in fp32 when scaled")
if self.scaled_masked_softmax_fusion:
if self.attn_mask_type == AttnMaskType.causal:
self.fused_softmax_func = scaled_upper_triang_masked_softmax
elif self.attn_mask_type == AttnMaskType.padding:
self.fused_softmax_func = scaled_masked_softmax
else:
raise ValueError("Invalid attn_mask_type.")
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
return self.forward_torch_softmax(input, mask)
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16 or bf16
            and mask is not None  # mask tensor must not be None
            and 16 < sk <= 2048  # sk must be in (16, 2048]
            and sq % 4 == 0  # sq must be divisible by 4
            and attn_batches % 4 == 0  # np * b must be divisible by 4
):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type == AttnMaskType.causal:
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
def forward_fused_softmax(self, input, mask):
# input.shape = [b, np, sq, sk]
scale = self.scale if self.scale is not None else 1.0
return self.fused_softmax_func(input, mask, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
@staticmethod
def get_batch_per_block(sq, sk, b, np):
import scaled_masked_softmax_cuda
return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
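# A minimal sketch (an illustration added here, not part of the original module) that
# exercises the unfused fallback: with scaled_masked_softmax_fusion=False,
# is_kernel_available() returns False and forward_torch_softmax() runs, even on CPU.
def _fused_scale_mask_softmax_demo():
    m = FusedScaleMaskSoftmax(
        input_in_fp16=True,
        input_in_bf16=False,
        attn_mask_type=AttnMaskType.padding,
        scaled_masked_softmax_fusion=False,
        mask_func=lambda x, mask: x.masked_fill(mask, -10000.0),
        softmax_in_fp32=True,
        scale=None,
    )
    x = torch.randn(2, 4, 8, 8, dtype=torch.float16)
    mask = torch.zeros(2, 1, 8, 8, dtype=torch.bool)
    return m(x, mask)  # fp16 probabilities, shape (2, 4, 8, 8)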
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model and data parallel groups."""
import torch
# TODO (mkozuki): Consider dissecting utils as this utils import is here
# only for ensure_divisibility
from .tensor_parallel import utils
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage
_PIPELINE_GLOBAL_RANKS = None
def is_unitialized():
"""Useful for code segments that may be accessed with or without mpu initialization"""
return _DATA_PARALLEL_GROUP is None
def initialize_model_parallel(
tensor_model_parallel_size_=1, pipeline_model_parallel_size_=1, virtual_pipeline_model_parallel_size_=None
):
"""
Initialize model data parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
and 8 data-parallel groups as:
8 data_parallel groups:
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
8 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
4 pipeline model-parallel groups:
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
if torch.distributed.get_rank() == 0:
print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size_))
print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size_))
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
# TODO (mkozuki): Consider moving `ensure_divisibility` to this file.
utils.ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
num_data_parallel_groups = world_size // data_parallel_size
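    # E.g. world_size=16, tensor=2, pipeline=4 (the docstring example above):
    # data_parallel_size = 16 // (2*4) = 2, giving 16//2 = 8 tensor groups,
    # 16//4 = 4 pipeline groups, and 16//2 = 8 data-parallel groups.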
if virtual_pipeline_model_parallel_size_ is not None:
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
    rank = torch.distributed.get_rank()

    # Build the data-parallel groups.
    global _DATA_PARALLEL_GROUP
    assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
    all_data_parallel_group_ranks = []
    for i in range(pipeline_model_parallel_size):
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks)
            if rank in ranks:
                _DATA_PARALLEL_GROUP = group

    # Build the model-parallel groups.
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
    for i in range(data_parallel_size):
        ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group

    # Build the tensor model-parallel groups.
    global _TENSOR_MODEL_PARALLEL_GROUP
    assert _TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized"
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized"
    global _EMBEDDING_GROUP
    assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks
        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
        else:
            embedding_ranks = ranks
        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
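
# The helper below is an illustrative sketch, not part of the upstream API: it
# recomputes the rank groupings built above in pure Python so the layout can be
# inspected without launching torch.distributed. Its name and defaults are
# hypothetical; the defaults reproduce the 16-GPU example from the docstring of
# initialize_model_parallel.
def _sketch_parallel_group_ranks(world_size=16, tensor_model_parallel_size=2, pipeline_model_parallel_size=4):
    num_tensor_groups = world_size // tensor_model_parallel_size
    num_pipeline_groups = world_size // pipeline_model_parallel_size
    # Data-parallel groups: ranks that hold identical model shards.
    data_parallel = []
    for i in range(pipeline_model_parallel_size):
        start, end = i * num_pipeline_groups, (i + 1) * num_pipeline_groups
        for j in range(tensor_model_parallel_size):
            data_parallel.append(list(range(start + j, end, tensor_model_parallel_size)))
    # Tensor model-parallel groups: consecutive ranks that split each layer.
    tensor_parallel = [
        list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size))
        for i in range(num_tensor_groups)
    ]
    # Pipeline model-parallel groups: strided ranks that hold successive stages.
    pipeline_parallel = [list(range(i, world_size, num_pipeline_groups)) for i in range(num_pipeline_groups)]
    return data_parallel, tensor_parallel, pipeline_parallel
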
def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
    if _TENSOR_MODEL_PARALLEL_GROUP is None or _PIPELINE_MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
        return False
    return True
def get_model_parallel_group():
    """Get the model parallel group the caller rank belongs to."""
    assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
    return _MODEL_PARALLEL_GROUP


def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to."""
    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "tensor model parallel group is not initialized"
    return _TENSOR_MODEL_PARALLEL_GROUP


def get_pipeline_model_parallel_group():
    """Get the pipeline model parallel group the caller rank belongs to."""
    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, "pipeline model parallel group is not initialized"
    return _PIPELINE_MODEL_PARALLEL_GROUP


def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
    return _DATA_PARALLEL_GROUP


def get_embedding_group():
    """Get the embedding group the caller rank belongs to."""
    assert _EMBEDDING_GROUP is not None, "embedding group is not initialized"
    return _EMBEDDING_GROUP
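
# A minimal usage sketch, assuming torch.distributed and the groups above are
# already initialized (the tensor `t` is hypothetical): sum a value across the
# caller's tensor-model-parallel peers by passing the group handle explicitly.
#
#   t = torch.ones(1, device="cuda")
#   torch.distributed.all_reduce(t, group=get_tensor_model_parallel_group())
#   # t now holds the tensor model parallel world size on every rank in the group.
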
def set_tensor_model_parallel_world_size(world_size):
    """Set the tensor model parallel size."""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size


def set_pipeline_model_parallel_world_size(world_size):
    """Set the pipeline model parallel size."""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
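
# Note: the set_*_world_size functions above install overrides; once set, the
# matching getters return the override instead of querying torch.distributed.
# This is what makes it possible to change the apparent mpu sizes on the fly,
# e.g. in tests that run without initialized process groups.
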
def set_tensor_model_parallel_rank(rank):
    """Set tensor model parallel rank."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank


def set_pipeline_model_parallel_rank(rank):
    """Set pipeline model parallel rank."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_RANK
    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_RANK
    return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        if (
            get_virtual_pipeline_model_parallel_world_size() is not None
            and get_virtual_pipeline_model_parallel_rank() != 0
        ):
            return False
    return get_pipeline_model_parallel_rank() == 0


def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        virtual_pipeline_model_parallel_world_size = get_virtual_pipeline_model_parallel_world_size()
        if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != (
            virtual_pipeline_model_parallel_world_size - 1
        ):
            return False
    return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1)
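
# Worked example of the virtual-pipeline logic above: with a pipeline world
# size of 4 and a virtual pipeline size of 2, each pipeline rank holds two
# model chunks. Rank 0 counts as the first stage only while its virtual rank
# is 0, and rank 3 counts as the last stage only while its virtual rank is
# 1 (= 2 - 1); ignore_virtual=True skips the chunk check and compares the
# pipeline ranks alone.
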
def get_virtual_pipeline_model_parallel_rank():
    """Return the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK


def set_virtual_pipeline_model_parallel_rank(rank):
    """Set the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_virtual_pipeline_model_parallel_world_size():
    """Return the virtual pipeline-parallel world size."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    global_rank = torch.distributed.get_rank()
    local_world_size = get_tensor_model_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size
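
# Example: with a tensor model parallel world size of 2, global ranks 4 and 5
# form one tensor-parallel group, and both map to source rank (5 // 2) * 2 == 4.
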
def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first stage in the caller's pipeline group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    return _PIPELINE_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last stage in the caller's pipeline group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    last_rank_local = get_pipeline_model_parallel_world_size() - 1
    return _PIPELINE_GLOBAL_RANKS[last_rank_local]


def get_pipeline_model_parallel_next_rank():
    """Return the global rank of the next pipeline stage, wrapping around."""
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]


def get_pipeline_model_parallel_prev_rank():
    """Return the global rank of the previous pipeline stage, wrapping around."""
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
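
# Example: for the pipeline group [g1, g5, g9, g13], the last stage
# (rank_in_pipeline == 3) wraps around to g1 as its "next" rank
# ((3 + 1) % 4 == 0) and to g9 as its "previous" rank, so the stages form a ring.
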
def get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    return torch.distributed.get_world_size(group=get_data_parallel_group())


def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    return torch.distributed.get_rank(group=get_data_parallel_group())
def destroy_model_parallel():
    """Set the groups to none."""
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None
    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None
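
# A minimal teardown/re-initialization sketch for tests, assuming a live
# default process group (the sizes below are hypothetical). Note that only the
# group handles are reset; the _MPU_* overrides and virtual-pipeline state are
# left untouched by destroy_model_parallel().
#
#   destroy_model_parallel()
#   initialize_model_parallel(tensor_model_parallel_size_=2, pipeline_model_parallel_size_=2)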
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model parallel utility interface."""
from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data
from .layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    VocabParallelEmbedding,
    set_tensor_model_parallel_attributes,
    set_defaults_if_not_set_tensor_model_parallel_attributes,
    copy_tensor_model_parallel_attributes,
)
from .mappings import (
    copy_to_tensor_model_parallel_region,
    gather_from_tensor_model_parallel_region,
    reduce_from_tensor_model_parallel_region,
    scatter_to_tensor_model_parallel_region,
)
from .random import (
    checkpoint,
    get_cuda_rng_tracker,
    init_checkpointed_activations_memory_buffer,
    model_parallel_cuda_manual_seed,
    reset_checkpointed_activations_memory_buffer,
    gather_split_1d_tensor,
    split_tensor_into_1d_equal_chunks,
)
from .utils import divide, split_tensor_along_last_dim
__all__ = [
    # cross_entropy.py
    "vocab_parallel_cross_entropy",
    # data.py
    "broadcast_data",
    # layers.py
    "ColumnParallelLinear",
    "RowParallelLinear",
    "VocabParallelEmbedding",
    "set_tensor_model_parallel_attributes",
    "set_defaults_if_not_set_tensor_model_parallel_attributes",
    "copy_tensor_model_parallel_attributes",
    # mappings.py
    "copy_to_tensor_model_parallel_region",
    "gather_from_tensor_model_parallel_region",
    "reduce_from_tensor_model_parallel_region",
    "scatter_to_tensor_model_parallel_region",
    # random.py
    "checkpoint",
    "get_cuda_rng_tracker",
    "init_checkpointed_activations_memory_buffer",
    "model_parallel_cuda_manual_seed",
    "reset_checkpointed_activations_memory_buffer",
    "gather_split_1d_tensor",
    "split_tensor_into_1d_equal_chunks",
    # utils.py
    "divide",
    "split_tensor_along_last_dim",
]
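
# A hedged usage sketch of this package's exports (the top-level package name
# `mpu` is an assumption about how this module is installed):
#
#   from mpu import divide, split_tensor_along_last_dim
#   assert divide(16, 4) == 4  # divide asserts divisibility, then floor-divides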