"llama/llama.cpp/src/llama-model-loader.cpp" did not exist on "57f038ec7b4b71113924b7513319a1005c476f2d"
Unverified commit cc92a4b4 authored by Jithun Nair, committed by GitHub

Merge pull request #55 from ROCmSoftwarePlatform/IFU-master-2021-10-15

IFU-2021-10-15 (+ remove redundant defines + C10_CUDA_CHECK)
parents 1e0f9bc6 fec3141c
###############################################################################
# Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
###############################################################################
import sys
import torch
import numpy as np
import unittest
import math
import fmhalib as mha
def py_mha(qkv, amask, b, s, h, d):
qkv = qkv.view(b, s, h, 3, d)
q = qkv[:, :, :, 0, :].permute(0,2,1,3)
k = qkv[:, :, :, 1, :].permute(0,2,1,3)
v = qkv[:, :, :, 2, :].permute(0,2,1,3)
p = torch.matmul(q.float(), k.permute(0,1,3,2).float())
p_masked = p / math.sqrt(d) + (1.0 - amask) * -10000.0
s = torch.softmax(p_masked, -1).to(qkv.dtype)
ctx = torch.matmul(s, v)
ctx = ctx.permute(0,2,1,3).contiguous()
ctx.retain_grad()
return ctx
class TestFMHA(unittest.TestCase):
def run_test(self, s, b):
print(f'Test s={s} b={b}')
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
dtype = torch.float16
device = torch.device('cuda')
h = 16
d = 64
slens = [s] * b
a = torch.tensor(np.array([0] + slens), dtype=torch.int32)
amask = torch.ones(b,h,s,s, dtype=dtype, device=device)
seqlens = torch.tensor(slens, dtype=torch.int32, device=device)
cu_seqlens = torch.cumsum(a, 0).to(dtype=torch.int32, device=device)
total = cu_seqlens[-1].item()
qkv = torch.randn((b,s,h,3,d), device=device, dtype=dtype)
qkv_vs = qkv.permute(0,1,3,2,4).contiguous().view(b*s, 3, h,d)
qkv.requires_grad = True
if b < 4:
ctx, S_ = mha.fwd_nl(qkv_vs, cu_seqlens, 0.0, s, True, None)
else:
ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, None)
ctx = ctx.view(b,s,h,d)
ctx_ref = py_mha(qkv, amask, b,s,h,d)
self.assertTrue(torch.allclose(ctx_ref.float(), ctx.float(), atol=1e-3))
labels = torch.randn_like(ctx_ref)
diff = ctx_ref - labels
l = (diff * diff).sum() / b
l.backward()
dw = ctx_ref.grad.permute(0,2,1,3)
dw2 = dw.permute(0,2,1,3).clone().detach().contiguous()
if b < 4:
dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
else:
dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
dqkv2 = dqkv2.permute(0,2,1,3).view(b,s, h,3,d)
self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))
def test_128(self):
self.run_test(128, 32)
def test_256(self):
self.run_test(256, 32)
def test_384(self):
self.run_test(384, 32)
def test_512(self):
self.run_test(512, 32)
self.run_test(512, 2)
self.run_test(512, 3)
if __name__ == '__main__':
unittest.main()
import torch
import unittest
import torch.nn.functional as F
from apex import fused_dense
from torch import nn
from apex import amp
class FusedDenseTest(unittest.TestCase):
def setUp(self, seed=0):
torch.manual_seed(seed)
#torch.cuda.manual_seed_all(seed)
self.seq_length = 512
self.sequences = 3
self.hidden_dim = 1024
self.ref_inputs = torch.randn(self.sequences*self.seq_length, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).int().half().requires_grad_(True)
self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True)
self.dense = fused_dense.FusedDense(1024, 3072)
self.dense.half()
self.dense.cuda()
def test_fused_dense(self) :
y_tst = self.dense(self.tst_inputs)
y_ref = torch.matmul(self.ref_inputs,self.dense.weight.t())+self.dense.bias
dy = torch.randn_like(y_tst).half()
y_tst.backward(dy)
dw_ref = torch.matmul(dy.t(), self.ref_inputs)
dx_ref = torch.matmul(dy, self.dense.weight.clone())
db_ref = dy.sum(0, False)
self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(y_ref, y_tst, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(dw_ref, self.dense.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(dx_ref, self.tst_inputs.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(db_ref, self.dense.bias.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
if __name__ == '__main__':
unittest.main()
import torch
import unittest
from apex.contrib.transducer import TransducerJoint
import transducer_ref
class TransducerJointTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def gen_input(self, for_vector_kernel):
self.B = 4
T_min = 51
T_max = 101
U_min = 12
U_max = 25
if for_vector_kernel:
H = 512
else:
H = 509
dtype = torch.float16
device = "cuda"
self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device)
self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device)
self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device)
self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device)
self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device)
self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max
self.dropout_prob = 0.5
# Make sure gradients from out-of-bound locations are zero. This should be guaranteed by
# the loss function
for b in range(self.B):
self.h_grad[b, self.f_len[b]:, :, :] = 0
self.h_grad[b, :, self.g_len[b]:, :] = 0
self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len)
def _pack(self, x, f_len, g_len):
B = x.size(0)
list_x = []
for b in range(B):
list_x_row = [x[b, t, :g_len[b]] for t in range(f_len[b])]
x_row = torch.cat(list_x_row)
list_x.append(x_row)
x_packed = torch.cat(list_x).data.clone()
x_packed.requires_grad = True
batch_offset = torch.cumsum(f_len * g_len, dim=0)
return x_packed
def _unpack(self, x, f_len, g_len):
batch_offset = torch.cumsum(f_len * g_len, dim=0)
x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8)
B = self.h_grad.size(0)
H = self.h_grad.size(-1)
for b in range(B):
my_batch_offset = 0 if b == 0 else batch_offset[b-1]
my_f_len = f_len[b]
my_g_len = g_len[b]
for t in range(my_f_len):
x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len :
my_batch_offset + t*my_g_len + my_g_len]
return x_unpacked
def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout):
self.gen_input(for_vector_kernel=for_vector_kernel)
# Generate reference
f_ref = self.f_tst.data.clone()
g_ref = self.g_tst.data.clone()
f_ref.requires_grad = True
g_ref.requires_grad = True
my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout,
dropout_prob=self.dropout_prob, probe_mask=True)
if not pack_output:
h_tst = my_joint( f=self.f_tst,
g=self.g_tst,
f_len=self.f_len,
g_len=self.g_len)
h_tst.backward(self.h_grad)
if dropout:
mask = my_joint.mask_probe[0]
else:
batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0)
h_tst = my_joint( f=self.f_tst,
g=self.g_tst,
f_len=self.f_len,
g_len=self.g_len,
batch_offset=batch_offset,
packed_batch=batch_offset[-1])
h_tst.backward(self.h_grad_packed)
if dropout:
mask_packed = my_joint.mask_probe[0]
mask = self._unpack(mask_packed, self.f_len, self.g_len)
# reference
h_ref, f_grad_ref, g_grad_ref \
= transducer_ref.transducer_joint_reference(f=f_ref,
g=g_ref,
h_grad=self.h_grad,
f_len=self.f_len,
g_len=self.g_len,
pack_output=pack_output,
relu=relu,
dropout=dropout,
dropout_prob=self.dropout_prob,
mask=mask if dropout else None)
f_grad_tst = self.f_tst.grad
g_grad_tst = self.g_tst.grad
self.assertTrue(torch.allclose(h_ref, h_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4))
def test_transducer_joint(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)
def test_transducer_joint_vec(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False)
def test_transducer_joint_pack(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False)
def test_transducer_joint_vec_pack(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)
def test_transducer_joint_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)
def test_transducer_joint_vec_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False)
def test_transducer_joint_pack_relu(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False)
def test_transducer_joint_vec_pack_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)
def test_transducer_joint_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
def test_transducer_joint_vec_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True)
def test_transducer_joint_pack_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True)
def test_transducer_joint_vec_pack_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
if __name__ == '__main__':
unittest.main()
import torch
import unittest
from apex.contrib.transducer import TransducerLoss
import transducer_ref
class TransducerLossTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def gen_input(self, scalar_t, for_vector_kernel):
self.B = 5
T_min = 23
T_max = 51
U_min = 12
U_max = 25
V = 16 if for_vector_kernel else 14
self.blank_idx = V - 1
device = "cuda"
self.x_tst = torch.randn((self.B, T_max, U_max, V), dtype=scalar_t, requires_grad=True,
device=device)
self.y = torch.randint(0, self.blank_idx, (self.B, U_max-1), dtype=torch.int, device=device)
self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device)
self.y_len = torch.randint(U_min-1, U_max, (self.B,), dtype=torch.int, device=device)
self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
self.y_len[torch.randint(0, self.B, (1,)).item()] = U_max-1
self.x_tst_packed, self.batch_offset = self._pack(self.x_tst)
# Generate reference
x_ref = self.x_tst.data.clone()
x_ref.requires_grad = True
loss_grad = torch.ones(x_ref.size(0), dtype=x_ref.dtype, device=x_ref.device)/x_ref.size(0)
_, _, self.grad_ref, self.loss_ref \
= transducer_ref.transducer_loss_reference( x=x_ref,
label=self.y,
f_len=self.f_len,
y_len=self.y_len,
blank_idx=self.blank_idx,
loss_grad=loss_grad)
def _pack(self, x):
list_x = []
for b in range(self.B):
list_x_row = [x[b, t, : self.y_len[b]+1] for t in range(self.f_len[b])]
x_row = torch.cat(list_x_row)
list_x.append(x_row)
x_packed = torch.cat(list_x).data.clone()
x_packed.requires_grad = True
batch_offset = torch.cumsum(self.f_len * (self.y_len+1), dim=0)
return x_packed, batch_offset
def _unpack(self, x):
x_unpacked = torch.zeros(self.B, self.f_len.max(), self.y_len.max()+1, x.size(-1),
dtype=x.dtype, device=x.device)
for b in range(self.B):
my_batch_offset = 0 if b == 0 else self.batch_offset[b-1]
my_f_len = self.f_len[b]
my_g_len = self.y_len[b] + 1
for t in range(my_f_len):
for u in range(my_g_len):
x_unpacked[b, t, u] = x[my_batch_offset + t*my_g_len + u]
return x_unpacked
def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input, for_vector_kernel):
self.gen_input(scalar_t, for_vector_kernel)
my_loss = TransducerLoss( fuse_softmax_backward=fuse_softmax_backward,
packed_input=packed_input)
if not packed_input:
loss_tst = my_loss( x=self.x_tst,
label=self.y,
f_len=self.f_len,
y_len=self.y_len,
blank_idx=self.blank_idx)
loss_tst.mean().backward()
grad_tst = self.x_tst.grad
else:
loss_tst = my_loss( x=self.x_tst_packed,
label=self.y,
f_len=self.f_len,
y_len=self.y_len,
blank_idx=self.blank_idx,
batch_offset=self.batch_offset,
max_f_len=max(self.f_len))
loss_tst.mean().backward()
grad_tst_packed = self.x_tst_packed.grad
grad_tst = self._unpack(grad_tst_packed)
return loss_tst, grad_tst
def test_transducer_loss_fp32(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float32,
fuse_softmax_backward=False,
packed_input=False,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-5, rtol=1e-5))
def test_transducer_loss_fp16(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=False,
packed_input=False,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=False,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion_packed(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=True,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion_packed_vec(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=True,
for_vector_kernel=True)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
import torch
import numpy as np
import pdb
def transducer_loss_reference(x, label, f_len, y_len, blank_idx, loss_grad):
def log_sum_exp(a, b):
if (a >= b):
return a + torch.log(1 + torch.exp(b-a))
else:
return b + torch.log(1 + torch.exp(a-b))
def forward_alpha(x, label, f_len, y_len, blank_idx):
B, T, U, V = x.size()
acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype
alpha = torch.zeros((B, T, U), dtype=acc_t, device=x.device)
for b in range(B):
alpha[b, 0, 0] = 0
for t in range(1, f_len[b]):
alpha[b, t, 0] = alpha[b, t-1, 0] + x[b, t-1, 0, blank_idx]
for u in range(1, y_len[b]+1):
alpha[b, 0, u] = alpha[b, 0, u-1] + x[b, 0, u-1, label[b, u-1]]
for t in range(1, f_len[b]):
for u in range(1, y_len[b]+1):
curr_ = alpha[b, t-1, u] + x[b, t-1, u, blank_idx]
next_ = alpha[b, t, u-1] + x[b, t, u-1, label[b, u-1]]
alpha[b, t, u] = log_sum_exp(curr_, next_)
return alpha
def forward_beta(x, label, f_len, y_len, blank_idx):
B, T, U, V = x.shape
acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype
beta = torch.zeros((B, T, U), dtype=acc_t, device=x.device)
for b in range(B):
beta[b, f_len[b]-1, y_len[b]] = x[b, f_len[b]-1, y_len[b], blank_idx]
for t in range(f_len[b]-2, -1, -1):
beta[b, t, y_len[b]] = beta[b, t+1, y_len[b]] + x[b, t, y_len[b], blank_idx]
for u in range(y_len[b]-1, -1, -1):
beta[b, f_len[b]-1, u] = beta[b, f_len[b]-1, u+1] + x[b, f_len[b]-1, u, label[b, u]]
for t in range(f_len[b]-2, -1, -1):
for u in range(y_len[b]-1, -1, -1):
curr_ = beta[b, t+1, u] + x[b, t, u, blank_idx]
next_ = beta[b, t, u+1] + x[b, t, u, label[b, u]]
beta[b, t, u] = log_sum_exp(curr_, next_)
return beta
def backward(x, label, f_len, y_len, alpha, beta, loss_grad, blank_idx):
grad = torch.zeros_like(x)
B, T, U, V = x.size()
for b in range(B):
common_factor = torch.log(loss_grad[b]) + alpha - beta[b, 0, 0]
# next
for u in range(y_len[b]):
grad[b, :f_len[b], u, label[b, u]] = -torch.exp(common_factor[b, :f_len[b], u]
+ beta[b, :f_len[b], u+1]
+ x[b, :f_len[b], u, label[b, u]])
# current
grad[b, :f_len[b]-1, :y_len[b]+1, blank_idx] \
= -torch.exp(common_factor[b, :f_len[b]-1, :y_len[b]+1]
+ beta[b, 1:f_len[b], :y_len[b]+1]
+ x[b, :f_len[b]-1, :y_len[b]+1, blank_idx])
grad[b, f_len[b]-1, y_len[b], blank_idx] = -torch.exp(common_factor[b, f_len[b]-1, y_len[b]]
+ x[b, f_len[b]-1, y_len[b], blank_idx])
return grad
x_log = torch.nn.functional.log_softmax(x, dim=-1)
alpha = forward_alpha(x_log, label, f_len, y_len, blank_idx)
beta = forward_beta(x_log, label, f_len, y_len, blank_idx)
grad = backward(x_log, label, f_len, y_len, alpha, beta,
loss_grad, blank_idx)
x_log.backward(grad)
loss = -beta[:, 0, 0]
loss = loss.to(x.dtype)
return alpha, beta, x.grad, loss
def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output, relu, dropout,
dropout_prob=0, mask=None):
if dropout and mask is None:
raise NotImplementedError("mask needs to be supplied to test dropout.")
B, T, H = f.size()
U = g.size(1)
f_expand = f.unsqueeze(dim=2)
g_expand = g.unsqueeze(dim=1)
h = f_expand + g_expand
if relu:
h = torch.nn.functional.relu(h)
if dropout:
h *= mask
scale = 1/(1-dropout_prob)
h *= scale
h.backward(h_grad)
if not pack_output:
# intentionally set the don't-care regions to -1 to test whether the transducer joint
# writes these regions to avoid NaN and inf
for b in range(B):
h[b, f_len[b]:] = -1
h[b, :, g_len[b]:] = -1
return h, f.grad, g.grad
# packing
list_to_pack = []
for b in range(B):
list_to_pack.append(h[b, :f_len[b], :g_len[b], :].reshape(-1, H))
h_packed = torch.cat(list_to_pack)
return h_packed, f.grad, g.grad
from .transducer import TransducerJoint
from .transducer import TransducerLoss
import torch
import transducer_loss_cuda
import transducer_joint_cuda
class TransducerJoint(torch.nn.Module):
"""Transducer joint
Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural
Networks
Arguments:
pack_output (bool, optional): whether to pack the output in a compact form with don't-care
data being removed. (default: False)
relu (bool, optional): apply ReLU to the output of the joint operation. Requires opt=1
(default: False)
dropout (bool, optional): apply dropout to the output of the joint operation. Requires opt=1
(default: False)
opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm.
(default: 1)
fwd_tile_size (int, optional): tile size used in forward operation. This argument will be
ignored if opt != 1. (default: 4)
dropout_prob (float, optional): dropout probability. (default: 0.0)
probe_mask (bool, optional): a flag used to probe the mask generated by ReLU and/or dropout
operation. When this argument is set to True, the mask can be accessed through
self.mask_probe. (default: false)
"""
def __init__(self, pack_output=False, relu=False, dropout=False, opt=1, fwd_tile_size=4,
dropout_prob=0, probe_mask=False):
super(TransducerJoint, self).__init__()
self.pack_output = pack_output
self.relu = relu
self.dropout = dropout
self.dropout_prob = dropout_prob
self.opt = opt
self.fwd_tile_size = fwd_tile_size
self.dummy_batch_offset = torch.empty(0)
masked = self.relu or self.dropout
self.mask_probe = [] if masked and probe_mask else None
if masked and opt != 1:
raise NotImplementedError("ReLU and dropout fusion is only supported with opt=1")
def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0):
"""Forward operation of transducer joint
Arguments:
f (tensor): transcription vector from encode block of shape (B, T, H).
g (tensor): prediction vector from the predict block of shape (B, U, H).
f_len (tensor): length of transcription vector for each batch.
g_len (tensor): length of prediction vector minus 1 for each batch.
batch_offset (tensor, optional): tensor containing the offset of each batch
in the results. For example, batch offset can be obtained from:
batch_offset = torch.cumsum(f_len*g_len, dim=0)
This argument is required if pack_output == True, and is ignored if
pack_output == False. (default: None)
packed_batch (int, optional): the batch size after packing. This argument is
ignored if pack_output == False. (default: 0)
"""
my_batch_offset = batch_offset if self.pack_output else self.dummy_batch_offset
if self.pack_output and (batch_offset is None or packed_batch == 0):
raise Exception("Please specify batch_offset and packed_batch when packing is enabled")
dropout = self.dropout and self.training # only dropout for training
return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, self.relu, dropout,
my_batch_offset, packed_batch, self.opt,
self.fwd_tile_size, self.dropout_prob, self.mask_probe)
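# A minimal usage sketch for TransducerJoint, mirroring the unit test above (tensor names and
# shapes are illustrative assumptions, not part of the API): f is (B, T, H) from the encoder,
# g is (B, U, H) from the predictor.
#
#   joint = TransducerJoint(pack_output=True, relu=True, dropout=False)
#   batch_offset = torch.cumsum(f_len * g_len, dim=0)
#   h_packed = joint(f=f, g=g, f_len=f_len, g_len=g_len,
#                    batch_offset=batch_offset, packed_batch=batch_offset[-1])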
class TransducerLoss(torch.nn.Module):
"""Transducer loss
Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural
Networks
Arguments:
fuse_softmax_backward (bool, optional): whether to fuse the backward of transducer loss with
softmax. (default: True)
opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a more optimized
algorithm. In some cases, opt=1 might fall back to opt=0. (default: 1)
packed_input (bool, optional): whether the input is packed in a compact form with don't-care
data removed. (default: False)
"""
def __init__(self, fuse_softmax_backward=True, opt=1, packed_input=False):
super(TransducerLoss, self).__init__()
self.fuse_softmax_backward = fuse_softmax_backward
self.opt = opt
self.packed_input = packed_input
self.dummy_batch_offset = torch.empty(0)
def forward(self, x, label, f_len, y_len, blank_idx, batch_offset=None, max_f_len=None,
debug_list=None):
"""Forward operation of transducer joint
Arguments:
x (tensor): input tensor to the loss function with a shape of (B, T, U, H).
label (tensor): labels for the input data.
f_len (tensor): lengths of the inputs in the time dimension for each batch.
y_len (tensor): lengths of the labels for each batch.
blank_idx (int): index for the null symbol.
batch_offset (tensor, optional): tensor containing the offset of each batch
in the input. For example, batch offset can be obtained from:
batch_offset = torch.cumsum(f_len*(y_len+1), dim=0)
This argument is required if packed_input == True, and is ignored if
packed_input == False. (default: None)
max_f_len (int, optional): maximum length of the input in the time dimension.
For example, it can be obtained as
max_f_len = max(f_len)
This argument is required if packed_input == True, and is ignored if
packed_input == False. (default: None)
debug_list (list, optional): when an empty list is supplied, Alpha and Beta generated
in the forward operation will be attached to this list for debug purpose.
(default: None)
"""
if self.packed_input:
if batch_offset is None or max_f_len is None:
raise Exception("Please specify batch_offset and max_f_len when packing is \
enabled")
my_batch_offset = batch_offset
my_max_f_len = max_f_len
else:
my_batch_offset = self.dummy_batch_offset
my_max_f_len = x.size(1)
return TransducerLossFunc.apply(x, label, f_len, y_len, my_batch_offset, my_max_f_len,
blank_idx, self.fuse_softmax_backward, debug_list,
self.opt, self.packed_input)
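# A minimal usage sketch for TransducerLoss, mirroring the unit test above (tensor names and
# shapes are illustrative assumptions): x is (B, T, U, V) when packed_input=False.
#
#   loss_fn = TransducerLoss(fuse_softmax_backward=True, packed_input=False)
#   loss = loss_fn(x=x, label=label, f_len=f_len, y_len=y_len, blank_idx=blank_idx)
#   loss.mean().backward()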
class TransducerLossFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, x, label, f_len, y_len, batch_offset, max_f_len, blank_idx,
fuse_softmax_backward, debug_list, opt, packed_input):
if fuse_softmax_backward == False:
with torch.enable_grad():
x = torch.nn.functional.log_softmax(x, dim=-1)
else:
x = torch.nn.functional.log_softmax(x, dim=-1)
alpha, beta, loss = transducer_loss_cuda.forward( x, label, f_len, y_len, batch_offset,
max_f_len, blank_idx, opt, packed_input)
if debug_list == []:
debug_list += [alpha, beta]
ctx.save_for_backward(x, alpha, beta, f_len, y_len, label, batch_offset)
ctx.blank_idx = blank_idx
ctx.fuse_softmax_backward = fuse_softmax_backward
ctx.opt = opt
ctx.packed_input = packed_input
ctx.max_f_len = max_f_len
return loss
@staticmethod
def backward(ctx, loss_grad):
x, alpha, beta, f_len, y_len, label, batch_offset = ctx.saved_tensors
x_grad = transducer_loss_cuda.backward( x, loss_grad, alpha, beta, f_len, y_len, label,
batch_offset, ctx.max_f_len, ctx.blank_idx, ctx.opt,
ctx.fuse_softmax_backward, ctx.packed_input)
if ctx.fuse_softmax_backward == False:
x_grad = x.backward(x_grad)
return x_grad, None, None, None, None, None, None, None, None, None, None
class TransducerJointFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, f, g, f_len, g_len, pack_output, relu, dropout, batch_offset, packed_batch,
opt, fwd_tile_size, dropout_prob, mask_probe):
h = transducer_joint_cuda.forward(f, g, f_len, g_len, batch_offset, packed_batch, opt,
pack_output, relu, dropout, dropout_prob, fwd_tile_size)
masked = relu or dropout
if masked:
ctx.save_for_backward(h[1], f_len, g_len, batch_offset)
if mask_probe is not None:
mask_probe.append(h[1])
else:
ctx.save_for_backward(f_len, g_len, batch_offset)
ctx.pack_output = pack_output
ctx.masked = relu or dropout
ctx.max_f_len = f.size(1)
ctx.max_g_len = g.size(1)
ctx.scale = 1 / (1-dropout_prob) if dropout and dropout_prob != 1 else 1
return h[0]
@staticmethod
def backward(ctx, loss_grad):
if ctx.masked:
mask, f_len, g_len, batch_offset = ctx.saved_tensors
inp = [loss_grad, mask]
else:
f_len, g_len, batch_offset = ctx.saved_tensors
inp = [loss_grad]
f_grad, g_grad = transducer_joint_cuda.backward( inp, f_len, g_len, batch_offset,
ctx.max_f_len, ctx.max_g_len,
ctx.pack_output, ctx.scale)
return f_grad, g_grad, None, None, None, None, None, None, None, None, None, None, None, \
None, None, None
from .fused_dense import *
import torch
from torch import nn
import fused_dense_cuda
from .. import amp
#implements fused GEMM+bias in forward pass using mlp_cuda from apex
class FusedDenseFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias):
ctx.save_for_backward(input, weight)
output = fused_dense_cuda.linear_bias_forward(input, weight, bias)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_backward(input, weight, grad_output)
return grad_input, grad_weight, grad_bias
class DenseNoBiasFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight):
ctx.save_for_backward(input, weight)
output = torch.matmul(input, weight.t())
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input = grad_output.mm(weight)
grad_weight = grad_output.t().mm(input)
return grad_input, grad_weight
class FusedDenseGeluDenseFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight1, bias1, weight2, bias2):
output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(input, weight1, bias1, weight2, bias2)
ctx.save_for_backward(input, weight1, weight2, gelu_in, output1)
return output2
@staticmethod
def backward(ctx, grad_output):
input, weight1, weight2, gelu_in, output1 = ctx.saved_tensors
grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_gelu_linear_backward(input, gelu_in, output1, weight1, weight2, grad_output)
return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2
fused_dense_function = amp.half_function(FusedDenseFunc.apply)
dense_no_bias_function = amp.half_function(DenseNoBiasFunc.apply)
fused_dense_gelu_dense_function = amp.half_function(FusedDenseGeluDenseFunc.apply)
class FusedDense(nn.Module):
def __init__(self, in_features, out_features, bias=True):
super(FusedDense, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_features))
else:
#assert False, "no-bias option not added yet"
self.register_parameter('bias', None)
def forward(self, input):
if self.bias is not None:
return fused_dense_function(input, self.weight, self.bias)
else:
return dense_no_bias_function(input, self.weight)
class FusedDenseGeluDense(nn.Module):
def __init__(self, in_features, intermediate_features, out_features, bias=True):
super(FusedDenseGeluDense, self).__init__()
assert bias == True, "DenseGeluDense module without bias is currently not supported"
self.in_features = in_features
self.intermediate_features = intermediate_features
self.out_features = out_features
self.weight1 = nn.Parameter(torch.Tensor(intermediate_features, in_features))
self.bias1 = nn.Parameter(torch.Tensor(intermediate_features))
self.weight2 = nn.Parameter(torch.Tensor(out_features, intermediate_features))
self.bias2 = nn.Parameter(torch.Tensor(out_features))
def forward(self, input):
return fused_dense_gelu_dense_function(input, self.weight1, self.bias1, self.weight2, self.bias2)
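# A minimal usage sketch for FusedDense, following the FusedDenseTest above (shapes are
# illustrative assumptions; note the layer does not initialize its parameter values here):
#
#   layer = FusedDense(1024, 3072).half().cuda()
#   x = torch.randn(512, 1024, dtype=torch.float16, device="cuda", requires_grad=True)
#   y = layer(x)  # fused GEMM + bias via fused_dense_cuda
#   y.backward(torch.randn_like(y))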
from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
import math
import torch
import importlib
import numbers
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn import functional as F
from apex._autocast_utils import _cast_if_autocast_enabled
global fused_layer_norm_cuda
fused_layer_norm_cuda = None
class FusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_layer_norm_cuda.forward_affine(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps
)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps
)
return grad_input, grad_weight, grad_bias, None, None
class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction):
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_layer_norm_cuda.forward_affine_mixed_dtypes(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps
)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
class FusedLayerNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
output, mean, invvar = fused_layer_norm_cuda.forward(input_, ctx.normalized_shape, ctx.eps)
ctx.save_for_backward(input_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, mean, invvar = ctx.saved_tensors
grad_input = None
grad_input = fused_layer_norm_cuda.backward(
grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, ctx.eps
)
return grad_input, None, None
def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
return FusedLayerNormAffineFunction.apply(*args)
def fused_layer_norm(input, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
return FusedLayerNormFunction.apply(*args)
def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
return FusedLayerNormAffineMixedDtypesFunction.apply(*args)
class FusedLayerNorm(torch.nn.Module):
r"""Applies Layer Normalization over a mini-batch of inputs as described in
@@ -126,8 +158,9 @@ class FusedLayerNorm(torch.nn.Module):
.. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
"""
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super(FusedLayerNorm, self).__init__()
super().__init__()
global fused_layer_norm_cuda
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
@@ -141,8 +174,8 @@ class FusedLayerNorm(torch.nn.Module):
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self):
@@ -152,14 +185,34 @@ class FusedLayerNorm(torch.nn.Module):
def forward(self, input):
if not input.is_cuda:
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
if self.elementwise_affine:
return fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps)
else:
return fused_layer_norm(input, self.normalized_shape, self.eps)
def extra_repr(self):
return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__)
# NOTE (mkozuki): Why "mixed"?
# MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype
# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype.
# See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp"
class MixedFusedLayerNorm(FusedLayerNorm):
def __init__(self, normalized_shape, eps=1e-5, **kwargs):
if "elementwise_affine" in kwargs:
import warnings
warnings.warn("MixedFusedLayerNorm does not support `elementwise_affine` argument")
elementwise_affine = kwargs.pop("elementwise_affine")
if not elementwise_affine:
raise RuntimeError("MixedFusedLayerNorm does not support `elementwise_affine = False`")
super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True)
def forward(self, input: torch.Tensor):
# NOTE (mkozuki): CPU path is here mainly for unittest sake.
if not input.is_cuda:
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
return mixed_dtype_fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps)
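# A minimal usage sketch (illustrative shapes; CUDA is required for the fused kernels,
# otherwise forward() falls back to F.layer_norm):
#
#   ln = FusedLayerNorm(normalized_shape=1024).cuda().half()
#   y = ln(torch.randn(8, 1024, dtype=torch.float16, device="cuda"))
#   # MixedFusedLayerNorm differs only in that the output dtype follows the parameter dtype.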
@@ -13,7 +13,8 @@ class FusedAdam(torch.optim.Optimizer):
* Fusion of the Adam update's elementwise operations
* A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
:class:`apex.optimizers.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
or ``torch.optim.Adam`` with ``adam_w_mode=False``::
opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)
...
@@ -79,7 +79,9 @@ class FusedNovoGrad(torch.optim.Optimizer):
if multi_tensor_applier.available:
import amp_C
# Skip buffer
# Creating the overflow buffer on the same device as the params tensors.
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
self.multi_tensor_novograd = amp_C.multi_tensor_novograd
else:
raise RuntimeError('apex.optimizers.FusedNovoGrad requires cuda extensions')
@@ -158,8 +160,9 @@ class FusedNovoGrad(torch.optim.Optimizer):
if 'exp_avg_sq' not in group:
group['exp_avg_sq'] = [None, None]
if group['init_zero']:
# Creating the following parameters on the same device as the params tensors.
group['exp_avg_sq'][0] = torch.cuda.FloatTensor(len(g_16), device=self.param_groups[0]["params"][0].device).contiguous().fill_(0)
group['exp_avg_sq'][1] = torch.cuda.FloatTensor(len(g_32), device=self.param_groups[0]["params"][0].device).contiguous().fill_(0)
else: # init with first step norm, so first blend have no effect
if group['norm_type'] == 0:
v_16 = [torch.max(torch.abs(g.to(torch.float32))).item() for g in g_16]
@@ -169,8 +172,9 @@ class FusedNovoGrad(torch.optim.Optimizer):
v_32 = [torch.sum(torch.pow(g, 2)).sqrt().item() for g in g_32]
else:
raise RuntimeError('FusedNovoGrad only support l2/inf norm now.')
# Creating the following parameters on the same device as the params tensors.
group['exp_avg_sq'][0] = torch.cuda.FloatTensor(v_16, device=self.param_groups[0]["params"][0].device)
group['exp_avg_sq'][1] = torch.cuda.FloatTensor(v_32, device=self.param_groups[0]["params"][0].device)
else:
assert(len(g_16) == group['exp_avg_sq'][0].numel())
assert(len(g_32) == group['exp_avg_sq'][1].numel())
# apex.transformer
`apex.transformer` is a module which enables efficient large Transformer models at scale.
`apex.transformer.tensor_parallel` is based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s `megatron.mpu` module.
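A minimal sketch of how the parallel-state helpers exported by this module are typically used, assuming `torch.distributed` has already been initialized for the full world size:

```python
from apex.transformer import parallel_state

# torch.distributed.init_process_group(...) must already have been called.
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=2,
                                         pipeline_model_parallel_size_=4)
tp_rank = parallel_state.get_tensor_model_parallel_rank()
tp_size = parallel_state.get_tensor_model_parallel_world_size()
```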
from . import tensor_parallel
from . import functional
from .enums import LayerType
from .enums import AttnType
from .enums import AttnMaskType
from .parallel_state import (
is_unitialized,
destroy_model_parallel,
get_data_parallel_group,
get_data_parallel_rank,
get_data_parallel_world_size,
get_embedding_group,
get_model_parallel_group,
get_tensor_model_parallel_group,
get_pipeline_model_parallel_group,
get_tensor_model_parallel_rank,
set_tensor_model_parallel_rank,
get_pipeline_model_parallel_rank,
set_pipeline_model_parallel_rank,
is_pipeline_first_stage,
is_pipeline_last_stage,
get_tensor_model_parallel_src_rank,
get_pipeline_model_parallel_first_rank,
get_pipeline_model_parallel_last_rank,
get_pipeline_model_parallel_next_rank,
get_pipeline_model_parallel_prev_rank,
get_tensor_model_parallel_world_size,
set_tensor_model_parallel_world_size,
get_pipeline_model_parallel_world_size,
set_pipeline_model_parallel_world_size,
get_virtual_pipeline_model_parallel_rank,
set_virtual_pipeline_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
class LayerType(enum.Enum):
encoder = 1
decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
from .fused_softmax import FusedScaleMaskSoftmax
__all__ = [
"FusedScaleMaskSoftmax",
]
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex._autocast_utils import _cast_if_autocast_enabled
from ..enums import AttnMaskType
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs the following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None
def scaled_upper_triang_masked_softmax(inputs, _, scale):
b, np, sq, sk = inputs.size()
assert sq == sk, "causal mask is only for self attention"
# Reshaping input to 3D tensor (attn_batches, sq, sk)
inputs = inputs.view(-1, sq, sk)
args = _cast_if_autocast_enabled(inputs, scale)
with torch.cuda.amp.autocast(enabled=False):
probs = ScaledUpperTriangMaskedSoftmax.apply(*args)
return probs.view(b, np, sq, sk)
# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`.
# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context.
# So I needed to manually write two `torch.autograd.Function` inheritances.
# Fused operation which performs the following three operations in sequence
# 1. Scale the tensor.
# 2. Apply the mask.
# 3. Perform softmax.
class ScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, mask, scale):
import scaled_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None
def scaled_masked_softmax(inputs, mask, scale):
# input is 4D tensor (b, np, sq, sk)
args = _cast_if_autocast_enabled(inputs, mask, scale)
with torch.cuda.amp.autocast(enabled=False):
return ScaledMaskedSoftmax.apply(*args)
class FusedScaleMaskSoftmax(torch.nn.Module):
"""
fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: flag to indicate if input in fp16 data format.
input_in_bf16: flag to indicate if input in bf16 data format.
attn_mask_type: attention mask type (pad or causal)
scaled_masked_softmax_fusion: flag to indicate whether the user wants to use softmax fusion
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax is performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def __init__(
self,
input_in_fp16,
input_in_bf16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super().__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
if self.input_in_fp16 and self.input_in_bf16:
raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.")
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
if not (self.scale is None or softmax_in_fp32):
raise RuntimeError("softmax should be in fp32 when scaled")
if self.scaled_masked_softmax_fusion:
if self.attn_mask_type == AttnMaskType.causal:
self.fused_softmax_func = scaled_upper_triang_masked_softmax
elif self.attn_mask_type == AttnMaskType.padding:
self.fused_softmax_func = scaled_masked_softmax
else:
raise ValueError("Invalid attn_mask_type.")
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
return self.forward_torch_softmax(input, mask)
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (
self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16
and mask is not None # mask tensor must not be None
and 16 < sk <= 2048 # sk must be 16 ~ 2048
and sq % 4 == 0 # sq must be divisible by 4
and attn_batches % 4 == 0 # np * b must be divisible by 4
):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type == AttnMaskType.causal:
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
def forward_fused_softmax(self, input, mask):
# input.shape = [b, np, sq, sk]
scale = self.scale if self.scale is not None else 1.0
return self.fused_softmax_func(input, mask, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
@staticmethod
def get_batch_per_block(sq, sk, b, np):
import scaled_masked_softmax_cuda
return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
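# A minimal usage sketch for FusedScaleMaskSoftmax (the mask_func below is an illustrative
# assumption; attention scores are laid out as [b, np, sq, sk]):
#
#   softmax = FusedScaleMaskSoftmax(
#       input_in_fp16=True, input_in_bf16=False,
#       attn_mask_type=AttnMaskType.padding,
#       scaled_masked_softmax_fusion=True,
#       mask_func=lambda scores, mask: scores.masked_fill(mask, -10000.0),
#       softmax_in_fp32=True, scale=None)
#   probs = softmax(attention_scores, attention_mask)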
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model and data parallel groups."""
import torch
# TODO (mkozuki): Consider dissecting utils as this utils import is here
# only for ensure_divisibility
from .tensor_parallel import utils
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage
_PIPELINE_GLOBAL_RANKS = None
def is_unitialized():
"""Useful for code segments that may be accessed with or without mpu initialization"""
return _DATA_PARALLEL_GROUP is None
def initialize_model_parallel(
tensor_model_parallel_size_=1, pipeline_model_parallel_size_=1, virtual_pipeline_model_parallel_size_=None
):
"""
Initialize model data parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
and 8 data-parallel groups as:
8 data_parallel groups:
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
8 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
4 pipeline model-parallel groups:
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
if torch.distributed.get_rank() == 0:
print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size_))
print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size_))
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
# TODO (mkozuki): Consider moving `ensure_divisibility` to this file.
utils.ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
num_data_parallel_groups = world_size // data_parallel_size
if virtual_pipeline_model_parallel_size_ is not None:
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
rank = torch.distributed.get_rank()
# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
all_data_parallel_group_ranks = []
for i in range(pipeline_model_parallel_size):
start_rank = i * num_pipeline_model_parallel_groups
end_rank = (i + 1) * num_pipeline_model_parallel_groups
for j in range(tensor_model_parallel_size):
ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
all_data_parallel_group_ranks.append(list(ranks))
group = torch.distributed.new_group(ranks)
if rank in ranks:
_DATA_PARALLEL_GROUP = group
# Build the model-parallel groups.
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
for i in range(data_parallel_size):
ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]
group = torch.distributed.new_group(ranks)
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
# Build the tensor model-parallel groups.
global _TENSOR_MODEL_PARALLEL_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized"
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
# Build the pipeline model-parallel groups and embedding groups
# (first and last rank in each pipeline model-parallel group).
global _PIPELINE_MODEL_PARALLEL_GROUP
global _PIPELINE_GLOBAL_RANKS
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized"
global _EMBEDDING_GROUP
assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks
# Setup embedding group (to exchange gradients between
# first and last stages).
if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]]
else:
embedding_ranks = ranks
group = torch.distributed.new_group(embedding_ranks)
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or _PIPELINE_MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
return False
return True
def get_model_parallel_group():
"""Get the model parallel group the caller rank belongs to."""
assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
return _MODEL_PARALLEL_GROUP
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "tensor model parallel group is not initialized"
return _TENSOR_MODEL_PARALLEL_GROUP
def get_pipeline_model_parallel_group():
"""Get the pipeline model parallel group the caller rank belongs to."""
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, "pipeline_model parallel group is not initialized"
return _PIPELINE_MODEL_PARALLEL_GROUP
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
return _DATA_PARALLEL_GROUP
def get_embedding_group():
"""Get the embedding group the caller rank belongs to."""
assert _EMBEDDING_GROUP is not None, "embedding group is not initialized"
return _EMBEDDING_GROUP
def set_tensor_model_parallel_world_size(world_size):
"""Set the tensor model parallel size"""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
def set_pipeline_model_parallel_world_size(world_size):
"""Set the pipeline model parallel size"""
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
def get_pipeline_model_parallel_world_size():
"""Return world size for the pipeline model parallel group."""
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
def set_tensor_model_parallel_rank(rank):
"""Set tensor model parallel rank."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
_MPU_TENSOR_MODEL_PARALLEL_RANK = rank
def set_pipeline_model_parallel_rank(rank):
"""Set pipeline model parallel rank."""
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
return _MPU_TENSOR_MODEL_PARALLEL_RANK
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
def get_pipeline_model_parallel_rank():
"""Return my rank for the pipeline model parallel group."""
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
return _MPU_PIPELINE_MODEL_PARALLEL_RANK
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
if (
get_virtual_pipeline_model_parallel_world_size() is not None
and get_virtual_pipeline_model_parallel_rank() != 0
):
return False
return get_pipeline_model_parallel_rank() == 0
def is_pipeline_last_stage(ignore_virtual=False):
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
virtual_pipeline_model_parallel_world_size = get_virtual_pipeline_model_parallel_world_size()
if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != (
virtual_pipeline_model_parallel_world_size - 1
):
return False
return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1)
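# Note on the virtual (interleaved) pipeline schedule, added for clarity: with a
# virtual world size of 2 and a pipeline world size of 4, a rank on pipeline stage 0
# is treated as the "first stage" only for its first model chunk, and a rank on
# stage 3 as the "last stage" only for its second (final) chunk.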
def get_virtual_pipeline_model_parallel_rank():
"""Return the virtual pipeline-parallel rank."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
def set_virtual_pipeline_model_parallel_rank(rank):
"""Set the virtual pipeline-parallel rank."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
def get_virtual_pipeline_model_parallel_world_size():
"""Return the virtual pipeline-parallel world size."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = torch.distributed.get_rank()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
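    # Example (illustrative): with a tensor model parallel world size of 2,
    # global ranks 6 and 7 both map to source rank 6.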
def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first stage in the caller's pipeline model parallel group."""
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
return _PIPELINE_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last stage in the caller's pipeline model parallel group."""
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_next_rank():
    """Return the global rank of the next stage in the caller's pipeline model parallel group (wrapping around)."""
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
def get_pipeline_model_parallel_prev_rank():
    """Return the global rank of the previous stage in the caller's pipeline model parallel group (wrapping around)."""
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
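    # Example (illustrative): in the pipeline group [g1, g5, g9, g13] the rank after
    # g13 wraps around to g1, and the rank before g1 wraps around to g13.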
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return torch.distributed.get_world_size(group=get_data_parallel_group())
def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return torch.distributed.get_rank(group=get_data_parallel_group())
def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None
global _TENSOR_MODEL_PARALLEL_GROUP
_TENSOR_MODEL_PARALLEL_GROUP = None
global _PIPELINE_MODEL_PARALLEL_GROUP
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
global _EMBEDDING_GROUP
_EMBEDDING_GROUP = None
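# Usage sketch (illustrative only, assuming torch.distributed has been initialized
# and the initializer above has built the groups):
#
#   if model_parallel_is_initialized():
#       tp_rank = get_tensor_model_parallel_rank()
#       pp_rank = get_pipeline_model_parallel_rank()
#       dp_rank = get_data_parallel_rank()
#   ...
#   destroy_model_parallel()  # reset before re-initializing with different sizes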
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model parallel utility interface."""
from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data
from .layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes,
)
from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region,
)
from .random import (
checkpoint,
get_cuda_rng_tracker,
init_checkpointed_activations_memory_buffer,
model_parallel_cuda_manual_seed,
reset_checkpointed_activations_memory_buffer,
gather_split_1d_tensor,
split_tensor_into_1d_equal_chunks,
)
from .utils import divide, split_tensor_along_last_dim
__all__ = [
# cross_entropy.py
"vocab_parallel_cross_entropy",
# data.py
"broadcast_data",
# layers.py
"ColumnParallelLinear",
"RowParallelLinear",
"VocabParallelEmbedding",
"set_tensor_model_parallel_attributes",
"set_defaults_if_not_set_tensor_model_parallel_attributes",
"copy_tensor_model_parallel_attributes",
# mappings.py
"copy_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",
"reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region",
# random.py
"checkpoint",
"get_cuda_rng_tracker",
"init_checkpointed_activations_memory_buffer",
"model_parallel_cuda_manual_seed",
"reset_checkpointed_activations_memory_buffer",
"gather_split_1d_tensor",
"split_tensor_into_1d_equal_chunks",
# utils.py
"divide",
"split_tensor_along_last_dim",
]
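# Usage sketch (illustrative; the import path below is an assumption and may differ
# in this repository, and the constructor arguments follow the Megatron-style
# signature of these layers):
#
#   from apex.transformer.tensor_parallel import ColumnParallelLinear
#
#   # Requires torch.distributed and the model parallel groups to be set up first.
#   linear = ColumnParallelLinear(1024, 4096, gather_output=False)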