Unverified commit 96850dfa authored by Jithun Nair, committed by GitHub

Merge pull request #80 from ROCmSoftwarePlatform/IFU-master-2022-07-29

IFU-master-2022-07-29
parents 87fc4125 cc5f83b5
from collections import OrderedDict
from .utility import Utility
from .base import OperatorLayerBase
#TODO: Add support for other optimizers.
class Adam(OperatorLayerBase):
def __init__(self, d):
marker = eval(d.argMarker[0])
mod = marker['mod']
op = marker['op']
args = marker['args']
self.marker = marker
self.mod_ = mod
self.op_ = op
self.args = args
assert(op == "adam")
assert (len(args) == 12) or (len(args) == 14)
w, hw, m, v, g = args[0:5]
assert (w['shape'] == m['shape'] == v['shape'] == g['shape'])
assert (hw['shape'] == w['shape']) or (hw['shape'] == (0,)) #hw could be null
assert (w['type'] == m['type'] == v['type'] == g['type'] == hw['type'] == "tensor")
assert (w['dtype'] == m['dtype'] == v['dtype'] == "float32")
self.w = w
self.g = g
def params(self):
p = OrderedDict([('T',self.w['shape']), ('wtype',self.w['dtype']), ('gtype',self.g['dtype'])])
return p
def flops(self):
return 0
def bytes(self):
wshape = self.w['shape']
wtype = self.w['dtype']
gtype = self.g['dtype']
b = 0
elems = Utility.numElems(wshape)
#Bytes to stream (read + write) w, m, v
b += 6 * elems * Utility.typeToBytes(wtype)
#Bytes to read "g"
b += elems * Utility.typeToBytes(gtype)
if wtype != gtype: #mixed precision
#Bytes to write "hw"
b += elems * Utility.typeToBytes(gtype)
return b
def tc(self):
return "-"
def op(self):
return self.op_
def mod(self):
return self.mod_
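# Illustrative arithmetic for bytes() (a hedged sketch; the tensor shape and dtypes
# below are assumptions chosen only to show how the terms add up in mixed precision):
# for w of shape (1024, 1024), wtype "float32", gtype "float16":
#   elems = 1048576
#   read/write w, m, v : 6 * 1048576 * 4 = 25165824 bytes
#   read g             :     1048576 * 2 =  2097152 bytes
#   write hw           :     1048576 * 2 =  2097152 bytes
#   total              : 29360128 bytes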
import errno, os, sys
class Output():
"""
This class handles printing of a columned output and a CSV.
"""
# The table below is organized as
# user_option: [output_header, attribute_in_Data_class, type, min_width_in_columned_output]
table = {
"idx": ["Idx", "index", int, 7],
"seq": ["SeqId", "seqId", str, 7],
"altseq": ["AltSeqId", "altSeqId", str, 7],
"tid": ["TId", "tid", int, 12],
"layer": ["Layer", "layer", str, 10],
"trace": ["Trace", "trace", str, 25],
"dir": ["Direction", "dir", str, 5],
"sub": ["Sub", "sub", int, 3],
"mod": ["Module", "mod", str, 15],
"op": ["Op", "op", str, 15],
"kernel": ["Kernel", "name", str, 0],
"params": ["Params", "params", str, 0],
"sil": ["Sil(ns)", "sil", int, 10],
"tc": ["TC", "tc", str, 2],
"device": ["Device", "device", int, 3],
"stream": ["Stream", "stream", int, 3],
"grid": ["Grid", "grid", str, 12],
"block": ["Block", "block", str, 12],
"flops": ["FLOPs", "flops", int, 12],
"bytes": ["Bytes", "bytes", int, 12]
}
def __init__(self, args):
self.cols = args.c
self.csv = args.csv
self.col = (args.w > 0)
self.width = args.w
w = 0
for col in self.cols:
assert col in Output.table.keys()
w += Output.table[col][3]
if ((self.col) and (w > self.width)):
print("Minimum width required to print {} = {}. Exiting.".format(",".join(self.cols), w))
sys.exit(1)
remainder = self.width - w
if ("kernel" in self.cols) and ("params" in self.cols):
Output.table["kernel"][3] = int(remainder/2)
Output.table["params"][3] = int(remainder/2)
elif ("kernel" in self.cols):
Output.table["kernel"][3] = remainder
elif ("params" in self.cols):
Output.table["params"][3] = remainder
#header format
cadena = ""
for col in self.cols:
_,_,t,w = Output.table[col]
cadena += "%-{}.{}s ".format(w,w)
self.hFormat = cadena
#data format
cadena = ""
for col in self.cols:
_,_,t,w = Output.table[col]
if (t == str):
cadena += "%-{}.{}s ".format(w,w)
elif (t == int):
cadena += "%{}d ".format(w)
self.dFormat = cadena
def foo(self, cadena, pformat):
if self.csv:
cadena = ",".join(map(lambda x : '"' + str(x) + '"', cadena))
elif self.col:
cadena = pformat % cadena
else:
cadena = " ".join(map(str,cadena))
try:
print(cadena)
except IOError as e:
#gracefully handle pipes
if e.errno == errno.EPIPE:
# Python flushes standard streams on exit; redirect remaining output
# to devnull to avoid another BrokenPipeError at shutdown
devnull = os.open(os.devnull, os.O_WRONLY)
os.dup2(devnull, sys.stdout.fileno())
sys.exit(0)
else:
sys.exit(-1)
def header(self):
cadena = ()
for col in self.cols:
h = Output.table[col][0]
cadena = cadena + (h,)
self.foo(cadena, self.hFormat)
def data(self, a):
if a.dir == "":
direc = "na"
else:
direc = a.dir
if a.op == "":
op = "na"
else:
op = a.op
if a.mod == "":
mod = "na"
else:
mod = a.mod
cadena = ()
for col in self.cols:
attr = Output.table[col][1]
val = getattr(a, attr)
if col == "layer":
assert(type(val) == list)
val = ":".join(val)
val = "-" if val == "" else val
if col == "trace":
assert(type(val) == list)
if self.col and len(val):
val = val[-1]
val = val.split("/")[-1]
else:
val = ",".join(val)
val = "-" if val == "" else val
if col in ["seq", "altseq"]:
assert(type(val) == list)
val = ",".join(map(str,val))
val = "-" if val == "" else val
cadena = cadena + (val,)
self.foo(cadena, self.dFormat)
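# Hedged usage sketch: Output expects an argparse-style namespace with `c`
# (list of column names), `csv` (bool) and `w` (total width); the values below
# are assumptions chosen only to illustrate the call pattern.
# >>> from argparse import Namespace
# >>> out = Output(Namespace(c=["idx", "op", "kernel", "sil"], csv=False, w=120))
# >>> out.header()   # prints the column headers using the computed format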
import numpy as np
from collections import OrderedDict
from .utility import Utility
from .base import OperatorLayerBase
class Pointwise(OperatorLayerBase):
ops = []
ops += ["__abs__", "__neg__", "__invert__"]
ops += ["__add__", "__sub__", "__mul__", "__floordiv__", "__truediv__", "__pow__", "__mod__"]
ops += ["__radd__", "__rsub__", "__rmul__", "__rdiv__", "__rtruediv__", "__rfloordiv__", "__rpow__"]
ops += ["__iadd__", "__isub__", "__imul__", "__itruediv__",]
ops += ["__lt__", "__gt__", "__ge__", "__le__", "__eq__", "__ne__",]
ops += ["lt", "lt_", "gt", "gt_", "ge", "ge_", "le", "le_", "eq", "eq_", "ne", "ne_",]
ops += ["__and__", "__or__", "__xor__", "__lshift__", "__rshift__"]
ops += ["__iand__", "__ior__", "__ixor__", "__ilshift__", "__irshift__"]
ops += ["abs", "abs_", "neg", "neg_"]
ops += ["add", "add_", "div", "div_", "mul", "mul_", "reciprocal", "reciprocal_", "remainder", "remainder_", "sub", "sub_",]
ops += ["addcdiv", "addcdiv_", "addcmul", "addcmul_"]
ops += ["exp", "exp_", "exp1m", "exp1m_", "log", "log_", "log10", "log10_", "log1p", "log1p_", "log2", "log2_", "pow", "pow_", "rsqrt", "rsqrt_", "sqrt", "sqrt_",]
ops += ["ceil", "ceil_", "clamp", "clamp_", "floor", "floor_", "fmod", "fmod_", "frac", "frac_", "round", "round_", "sign", "sign_", "trunc", "trunc_"]
ops += ["acos", "acos_", "asin", "asin_", "atan", "atan_", "atan2", "atan2_", "cos", "cos_", "cosh", "cosh_", "sin", "sin_", "sinh", "sinh_", "tan", "tan_", "sigmoid", "sigmoid_", "tanh", "tanh_"]
ops += ["digamma", "erf", "erf_", "erfc", "erfc_", "erfinv", "erfinv_", "lerp", "lerp_", "mvlgamma",]
@staticmethod
def foo(d):
return d['name'],d['type'],d['shape'],d['dtype']
def __init__(self, d):
marker = eval(d.argMarker[0])
mod = marker['mod']
op = marker['op']
args = marker['args']
self.marker = marker
self.mod_ = mod
self.op_ = op
self.args = args
self.dir = d.dir
assert (d.dir in ["fprop", "bprop"])
assert (op in Pointwise.ops)
#Filter out all named parameters (kwargs).
#This might require revisiting in future.
args = list(filter(lambda x : x['name'] == "", args))
#Filter out non tensors
args = list(filter(lambda x : x['type'] == "tensor", args))
if (len(args) == 0):
self.shape = [(1,)]
self.type = "float32" #FIX
elif (len(args) == 1):
in0 = args[0]
_,t0,s0,dt0 = Pointwise.foo(in0)
assert (t0 == "tensor")
self.shape = [s0,]
self.type = dt0
elif (len(args) == 2):
in0,in1 = args
_,t0,s0,dt0 = Pointwise.foo(in0)
_,t1,s1,dt1 = Pointwise.foo(in1)
assert (t0 == t1 == "tensor")
assert (dt0 == dt1)
self.shape = [s0,s1]
self.type = dt0
elif (len(args) == 3):
in0,in1,in2 = args
_,t0,s0,dt0 = Pointwise.foo(in0)
_,t1,s1,dt1 = Pointwise.foo(in1)
_,t2,s2,dt2 = Pointwise.foo(in2)
assert (t0 == t1 == t2 == "tensor")
assert (dt0 == dt1 == dt2)
self.shape = [s0,s1,s2]
self.type = dt0
else:
assert False
return
def params(self):
p = OrderedDict([('T',self.shape), ('type', self.type)])
return p
def tc(self):
return "-"
def op(self):
return self.op_
def mod(self):
return self.mod_
def elems(self):
tensor = self.shape
t = self.type
if (len(tensor) == 1):
elems = 2 * Utility.numElems(tensor[0])
elif (len(tensor) == 2):
if (tensor[0] == tensor[1]): # same shape
elems = Utility.numElems(tensor[0])
if self.dir == "fprop":
elems *= 3
else:
if (self.op_ in ["add", "__add__", "sub", "__sub__", "__isub__"]):
elems *= 2
elif (self.op_ in ["__mul__", "__rmul__", "div", "__truediv__"]):
elems *= 3
else:
assert False
else: #check for broadcast conditions
array1 = np.empty(list(tensor[0]))
array2 = np.empty(list(tensor[1]))
try:
out = np.broadcast(array1, array2).shape
except:
assert False
elems = Utility.numElems(tensor[0])
elems += Utility.numElems(tensor[1])
elems += Utility.numElems(out)
#TODO bprop
elif (len(tensor) == 3):
if (tensor[0] == tensor[1] == tensor[2]): #same shape
elems = Utility.numElems(tensor[0])
elems *= 4
else:
assert False
else:
assert False
return elems
def bytes(self):
return self.elems() * Utility.typeToBytes(self.type)
def flops(self):
# Note: some cases may still be missing.
f = 0
if self.op_ in ["__abs__", "__neg__", "__add__", "__sub__", "__mul__",
"__radd__", "__rmul__", "__iadd__", "__isub__", "__imul__", "__itruediv__",
"abs", "abs_", "neg", "neg_", "add", "add_", "div", "div_", "mul", "mul_",
"sub", "sub_", "exp", "exp_", "sign", "sign_", "trunc", "trunc_",
"sin", "sin_", "cos", "cos_", "sinh", "sinh_", "cosh", "cosh_",
"sqrt", "sqrt_", "rsqrt", "rsqrt_", "__lt__", "__gt__", "__ge__", "__le__",
"__eq__", "__ne__", "lt", "lt_", "gt", "gt_", "ge", "ge_", "le", "le_",
"eq", "eq_", "ne", "ne_", "ceil", "ceil_", "clamp", "clamp_", "floor", "floor_",
"round", "sign", "sign_", "trunc", "trunc_"]:
# We're counting only one operand, not two (2 operands, 1 op)
f = self.elems() / 2
elif self.op_ in ["fmod", "fmod_"]:
f = self.elems()
elif self.op_ in ["tanh", "tanh_", "sigmoid", "sigmoid_", "log", "log_", "log2",
"log2_", "log10", "log10_"]:
f = self.elems() * 2
elif self.op_ in ["asin", "asin_", "acos", "acos_", "atan", "atan_"]:
# no intrinsic, hence slow execution
# surprisingly, asin/acos and atan were all the same (via nvprof measurement)
f = self.elems() * 10
return f
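# Worked example for the broadcast branch of elems() (hedged; the shapes are
# assumptions for illustration): inputs of shape (4, 1) and (1, 5) broadcast
# to (4, 5), so elems = 4 + 5 + 20 = 29 and bytes() = 29 * 4 = 116 for float32.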
from collections import OrderedDict
from .utility import Utility
# Work in progress.
#poolFuncs = ["max_pool2d_with_indices_forward", "max_pool2d_with_indices"]
class MaxPool2d(object):
def parse(marker):
def convert2Tuple(arg):
assert (arg['type'] in ["int", "tuple"])
if arg['type'] == "int":
return (arg['value'], arg['value'])
else:
return arg['value']
mod = marker['mod']
op = marker['op']
args = marker['args']
assert (mod == "torch.nn.functional")
assert (op == "max_pool2d")
assert (len(args) >= 2)
#input
assert (args[0]['name'] == "")
inp = args[0]
assert (inp['type'] == "tensor")
i = inp['shape']
t = inp['dtype']
assert (len(i) == 4) #nchw tensor
#kernel
if (args[1]['name'] == ""):
k = args[1]
else:
k = list(filter(lambda x : x['name'] == "kernel_size", args))[0]
k = convert2Tuple(k)
#stride
s = k #default value
if ((len(args) >= 3) and args[2]['name'] == ""):
s = args[2]
s = convert2Tuple(s)
elif any(x['name'] == "stride" for x in args):
s = list(filter(lambda x : x['name'] == "stride", args))[0]
s = convert2Tuple(s)
#padding
p = (0,0)
if ((len(args) >= 4) and args[3]['name'] == ""):
p = args[3]
p = convert2Tuple(p)
elif any(x['name'] == "padding" for x in args):
p = list(filter(lambda x : x['name'] == "padding", args))[0]
p = convert2Tuple(p)
params = OrderedDict([('T', i), ('K', k), ('s',s), ('p',p), ('type', t)])
return params
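# Hedged sketch of the marker format parse() expects; the shapes and values
# below are assumptions for illustration only.
# >>> marker = {'mod': 'torch.nn.functional', 'op': 'max_pool2d',
# ...           'args': [{'name': '', 'type': 'tensor', 'shape': (32, 64, 56, 56), 'dtype': 'float32'},
# ...                    {'name': '', 'type': 'int', 'value': 2}]}
# >>> MaxPool2d.parse(marker)
# OrderedDict([('T', (32, 64, 56, 56)), ('K', (2, 2)), ('s', (2, 2)), ('p', (0, 0)), ('type', 'float32')])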
#!/usr/bin/env python3
"""
This script reads the output (Python dictionary) created by parse.py.
For every kernel (line) in the input it determines
module / class name e.g. torch.nn.functional
operator name e.g. linear
kernel parameters e.g. GEMM M, N, K, datatype
bytes
flops
tensor core usage
direction (fprop, bprop)
and other things. Please see the tool usage.
"""
from .usage import parseArgs
from .output import Output
from .utility import Utility
from .pointwise import Pointwise
from .convert import Convert
from .blas import *
from .embedding import Embedding
from .reduction import *
from .dropout import Dropout
from .softmax import *
#from pooling import * # work in progress
from .linear import Linear
from .optim import Adam
from .misc import *
from .conv import Conv
from .activation import Activation
from .index_slice_join_mutate import Cat, Reshape, MaskedScatter, Gather, Nonzero, IndexSelect, MaskedSelect
from .recurrentCell import RNNCell
from .normalization import BatchNorm
from .randomSample import RandPerm
from .loss import MSELoss
from .data import Data
def findFpropKernel(seq):
#Find the last fprop kernel with the same seqId
#First look at seqId and then at altSeqId
for idx in reversed(range(len(kernels))):
k = kernels[idx]
if (seq in k['seqId']) and (k['dir'] == "fprop"):
return idx
for idx in reversed(range(len(kernels))):
k = kernels[idx]
if (seq in k['altSeqId']) and (k['dir'] == "fprop"):
return idx
return -1
#print("Error: seqId {} not found.".format(seq), file=sys.stderr)
#assert False
def foo(mod, op, d):
if (op[0] == "linear"):
xx = Linear(d)
# rnncell, lstmcell, grucell
elif (mod[0] in["LSTMCell", "GRUCell"]) and (op[0] == "forward"):
xx = RNNCell(d)
elif op[0] in ["conv1d", "conv2d",]:
xx = Conv(d)
elif (op[0] in Pointwise.ops):
xx = Pointwise(d)
elif (op[0] in Convert.ops):
xx = Convert(d)
elif op[0] in ["__matmul__", "matmul"]:
xx = Matmul(d)
elif op[0] == "embedding":
xx = Embedding(d)
#reduction
elif op[0] == "sum":
xx = Sum(d)
elif op[0] == "mean":
xx = Mean(d)
elif op[0] == "norm":
xx = Norm(d)
elif op[0] == "dropout":
xx = Dropout(d)
#Index, Slice, Join, Mutate
elif (op[0] == "cat"):
xx = Cat(d)
elif (op[0] == "reshape"):
xx = Reshape(d)
elif (op[0] == "masked_scatter_"):
xx = MaskedScatter(d)
elif (op[0] == "gather"):
xx = Gather(d)
elif (op[0] == "nonzero"):
xx = Nonzero(d)
elif (op[0] == "index_select"):
xx = IndexSelect(d)
elif (op[0] == "masked_select"):
xx = MaskedSelect(d)
#blas
elif op[0] in ["addmm", "addmm_"]:
xx = Addmm(d)
elif op[0] == "mm":
xx = Mm(d)
elif op[0] == "bmm":
xx = Bmm(d)
#softmax
elif op[0] == "softmax":
xx = Softmax(d)
elif op[0] == "log_softmax":
xx = LogSoftmax(d)
#loss
elif op[0] == "mse_loss":
xx = MSELoss(d)
#optimizers
elif op[0] == "adam":
xx = Adam(d)
#normalization
elif op[0] == "batch_norm":
xx = BatchNorm(d)
#random
elif op[0] == "randperm":
xx = RandPerm(d)
#misc
elif op[0] == "copy_":
xx = Copy(d)
elif op[0] == "clone":
xx = Clone(d)
elif op[0] == "contiguous":
xx = Contiguous(d)
elif op[0] == "any":
xx = Any(d)
elif (op[0] in Activation.ops):
xx = Activation(d)
elif op[0] == "to":
xx = Convert(d)
else:
xx = Foo(d)
return xx
def main():
#Read cmd line arguments
cmdArgs = parseArgs()
output = Output(cmdArgs)
output.header()
idx = -1
#Read in all the kernel info
for line in cmdArgs.file:
idx += 1
kernel = eval(line)
assert(kernel)
kernels.append(kernel)
k = kernel
d = Data(k)
mod = k['mod']
op = k['op']
flops = 0
params = {"na":"na"}
tc = "na"
bytes = 0
if (d.dir == "bprop"):
d.seqMarker = k['seqMarker']
seq = k['seqId']
if len(seq) > 1:
pass
seq = k['seqId'][:1]
assert (len(seq) == 1), seq
#assert (seq[0] != 0)
assert (len(d.seqMarker) > 0)
#If there is no useful marker associated, use the
#sequence number to find the kernel from fprop
if len(d.argMarker) == 0:
index = findFpropKernel(seq[0])
if index >= 0:
d.argMarker = kernels[index]['marker']
d.modMarker = kernels[index]['reprMarkers']
mod = kernels[index]['mod']
op = kernels[index]['op']
d.layer = kernels[index]['layer']
d.trace = kernels[index]['trace']
# Check if marker has our annotations
if len(d.argMarker) and Utility.hasNVTX(d.argMarker[0]):
xx = foo(mod, op, d)
bytes = xx.bytes()
flops = xx.flops()
op = xx.op()
params = xx.params()
tc = xx.tc()
if type(op) is list:
if len(op):
op = op[0]
else:
op = ""
if type(mod) is list:
if len(mod):
mod = mod[0]
else:
mod = ""
d.index = idx+1
# The following fields come from the operator class functions.
d.setParams(params)
d.tc = tc
d.flops = flops
d.bytes = bytes
d.mod = mod
d.op = op
output.data(d)
kernels = []
if __name__ == '__main__':
main()
from collections import OrderedDict
from .utility import Utility
from .base import OperatorLayerBase
class RandPerm(OperatorLayerBase):
def __init__(self, d):
marker = eval(d.argMarker[0])
mod = marker['mod']
op = marker['op']
args = marker['args']
self.marker = marker
self.mod_ = mod
self.op_ = op
self.args = args
assert (mod == "torch")
assert (op == "randperm")
assert (len(args) == 1)
n = args[0]
assert n['type'] == "int"
self.n = n['value']
def params(self):
p = OrderedDict([('N', self.n)])
return p
def tc(self):
return "-"
def op(self):
return self.op_
def mod(self):
return self.mod_
def bytes(self):
return self.n * Utility.typeToBytes("int64")
def flops(self):
# Depends on RNG but this is probably a reasonable assumption.
return self.n * 3
from collections import OrderedDict
from .utility import Utility
from .base import OperatorLayerBase
def hasTileSize(name):
if ("sgemm" in name) or ("884gemm" in name) or ("hgemm" in name):
return True
else:
return False
def ctaTile(name):
name = name.split("_")
name = list(filter(lambda x : "x" in x, name))
name = list(filter(lambda x : "slice" not in x, name))
assert(len(name) == 1)
name = name[0].split("x")
assert(len(name) == 2)
name = list(map(int, name))
return name[0], name[1]
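# Hedged example of the kernel-name convention these helpers assume; the name
# below is made up for illustration, not taken from a real trace.
# >>> hasTileSize("volta_sgemm_128x64_nn")
# True
# >>> ctaTile("volta_sgemm_128x64_nn")
# (128, 64)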
class RNNCell(OperatorLayerBase):
"""
This class supports RNNCell, LSTMCell and GRUCell.
"""
def __init__(self, d):
marker = eval(d.argMarker[0])
mod = marker['mod']
op = marker['op']
args = marker['args']
self.marker = marker
self.mod_ = mod
self.op_ = op
self.args = args
self.name = d.name
self.dir = d.dir
self.sub = d.sub
self.grid = d.grid
assert (op == "forward")
assert (mod in ["LSTMCell", "GRUCell", "RNNCell"])
assert (len(args) in [2,3])
x,h = args[0],args[1]
b1,ii = x['shape']
b2,hh = h['shape']
assert b1 == b2
assert x['dtype'] == h['dtype']
t = x['dtype']
self.cell = mod
self.inp = ii
self.hid = hh
self.b = b1
self.type = t
self.multiple = 1
if self.cell == "LSTMCell":
self.multiple = 4
elif self.cell == "GRUCell":
self.multiple = 3
self.gemm = None
self.m = None
self.n = None
self.k = None
self.elems = 0
self.bar()
def params(self):
if self.gemm is None:
p = OrderedDict([('cell', self.cell), ('X', self.inp), ('H', self.hid), ('B', self.b), ('type', self.type)])
else:
assert self.m is not None
assert self.n is not None
assert self.k is not None
p = OrderedDict([('gemm', self.gemm), ('M', self.m), ('N', self.n), ('K', self.k), ('type', self.type)])
return p
def tc(self):
if "gemm" in self.name:
return 1 if "884gemm" in self.name else 0
else:
return "-"
def op(self):
return self.op_
def mod(self):
return self.mod_
def bytes(self):
if self.gemm is not None:
m, n, k, t = self.m, self.n, self.k, self.type
b = (m*k + k*n + m*n) * Utility.typeToBytes(t)
elif self.elems != 0:
b = self.elems * Utility.typeToBytes(self.type)
else:
b = 0
return b
def flops(self):
if self.gemm is not None:
m, n, k = self.m, self.n, self.k
f = 2*m*n*k
elif self.elems != 0:
f = 0 #TODO
else:
f = 0
return f
def bar(self):
cell = self.cell
X = self.inp
H = self.hid
B = self.b
t = self.type
subseqId = self.sub
direc = self.dir
name = self.name
grid = self.grid
multiple = self.multiple
if direc == "fprop":
subseqId = subseqId % 3
if subseqId == 0: #layer gemm
self.gemm = "layer"
self.m = multiple*H
self.n = B
self.k = X
elif subseqId == 1: #recurrent gemm
self.gemm = "recur"
self.m = multiple*H
self.n = B
self.k = H
else:
layerGemmElems = multiple*H*B
recurGemmElems = multiple*H*B
cElems = H*B
hElems = H*B
totElems = layerGemmElems + recurGemmElems + 2*cElems + hElems
self.elems = totElems
else:
if ("gemm" in name) and hasTileSize(name): #gemm
#Get cta tile size
tileX, tileY = ctaTile(name)
#Get grid dimensions
grid = grid.split(",")
gridX,gridY,gridZ = map(lambda x : int(x), grid)
gemmM = tileX * gridX
gemmN = tileY * gridY
if name[-3:] == "_nn": # dgrad
if (gemmM == H): # recurrent dgrad
#Ideally gemmN = B, but we have a limited set of tile sizes.
gemmN = B
gemmK = multiple*H
self.gemm = "recur"
self.m = gemmM
self.n = gemmN
self.k = gemmK
elif (gemmM == X): # layer dgrad
#assert(gemmN % B == 0)
gemmK = multiple*H
self.gemm = "layer"
self.m = gemmM
self.n = gemmN
self.k = gemmK
else:
pass
elif name[-3:] == "_nt": #wgrad
if (gemmM == H): #recurrent wgrad
assert (gemmN == multiple*H)
gemmK = B
self.gemm = "recur"
self.m = gemmM
self.n = gemmN
self.k = gemmK
elif (gemmM == X): #layer wgrad
assert (gemmN == multiple*H)
gemmK = B
self.gemm = "layer"
self.m = gemmM
self.n = gemmN
self.k = gemmK
else:
pass
else:
pass
else:
pass
return
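# Worked example (hedged; the sizes are assumptions for illustration):
# an LSTMCell (multiple = 4) with X = 256, H = 512, B = 32 in fprop yields
#   sub-sequence 0 (layer gemm)    : M = 4*H = 2048, N = B = 32, K = X = 256
#   sub-sequence 1 (recurrent gemm): M = 4*H = 2048, N = B = 32, K = H = 512
# and flops() reports 2*M*N*K per gemm, e.g. 2*2048*32*256 = 33554432 for the
# layer gemm.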
from collections import OrderedDict
from .utility import Utility
from .base import OperatorLayerBase
class Mean(OperatorLayerBase):
def __init__(self, d):
marker = eval(d.argMarker[0])
mod = marker['mod']
op = marker['op']
args = marker['args']
self.marker = marker
self.mod_ = mod
self.op_ = op
self.args = args
assert (mod in ["torch", "Tensor"])
assert (op == "mean")
#Filter out named parameters
args = list(filter(lambda x : x['name'] == '', args))
assert (len(args) <= 2)
i = args[0]
self.shape = i['shape']
self.type = i['dtype']
self.dir = d.dir
self.sub = d.sub
def params(self):
p = OrderedDict([('T', self.shape), ('type', self.type)])
return p
def tc(self):
return "-"
def op(self):
return self.op_
def mod(self):
return self.mod_
def elems(self):
return Utility.numElems(self.shape)
def bytes(self):
if self.sub == 0:
return self.elems() * Utility.typeToBytes(self.type)
else:
return 0
def flops(self):
if self.sub == 0:
return self.elems() + 1
else:
return 0
class Sum(OperatorLayerBase):
def __init__(self, d):
marker = eval(d.argMarker[0])
mod = marker['mod']
op = marker['op']
args = marker['args']
self.marker = marker
self.mod_ = mod
self.op_ = op
self.args = args
assert (mod in ["torch", "Tensor"])
assert (op == "sum")
assert (len(args) >= 1)
#Get input
if (args[0]['name'] == ""):
i = args[0]
else:
i = list(filter(lambda x : x['name'] == "input", args))[0]
self.shape = i['shape']
self.type = i['dtype']
def params(self):
p = OrderedDict([('T', self.shape), ('type', self.type)])
return p
def tc(self):
return "-"
def op(self):
return self.op_
def mod(self):
return self.mod_
def elems(self):
return Utility.numElems(self.shape)
def flops(self):
# Note: This is incorrect, need to calculate actual flops (say via nvprof)
return self.elems()
def bytes(self):
return self.elems() * Utility.typeToBytes(self.type)
class Norm(OperatorLayerBase):
def __init__(self, d):
marker = eval(d.argMarker[0])
mod = marker['mod']
op = marker['op']
args = marker['args']
self.marker = marker
self.mod_ = mod
self.op_ = op
self.args = args
assert (mod in ["torch", "Tensor"])
assert (op == "norm")
#assert (len(args) == 1)
i = args[0]
self.shape = i['shape']
self.type = i['dtype']
def params(self):
p = OrderedDict([('T', self.shape), ('type', self.type)])
return p
def elems(self):
return Utility.numElems(self.shape)
def bytes(self):
return self.elems() * Utility.typeToBytes(self.type)
def flops(self):
# square and add plus sqrt
return 2 * self.elems() + 1
def tc(self):
return "-"
def op(self):
return self.op_
def mod(self):
return self.mod_
from collections import OrderedDict
from .utility import Utility
from .base import OperatorLayerBase
class Softmax(OperatorLayerBase):
def __init__(self, d):
marker = eval(d.argMarker[0])
mod = marker['mod']
op = marker['op']
args = marker['args']
self.marker = marker
self.mod_ = mod
self.op_ = op
self.args = args
assert (mod == "torch.nn.functional")
assert (op == "softmax")
#Filter out named parameters
args = list(filter(lambda x : x['name'] == '', args))
assert (len(args) <= 2)
self.shape = args[0]['shape']
self.type = args[0]['dtype']
self.dir = d.dir
return
def op(self):
return self.op_
def mod(self):
return self.mod_
def tc(self):
return "-"
def params(self):
p = OrderedDict([('T', self.shape), ('type', self.type)])
return p
def elems(self):
return Utility.numElems(self.shape)
def flops(self):
# Note: exp, sum-reduce, divide
#flops = elems * 3
return 0
def bytes(self):
b = self.elems() * Utility.typeToBytes(self.type)
b *= 3 if self.dir == "fprop" else 5 #verify
return b
class LogSoftmax(OperatorLayerBase):
def __init__(self, d):
marker = eval(d.argMarker[0])
mod = marker['mod']
op = marker['op']
args = marker['args']
self.marker = marker
self.mod_ = mod
self.op_ = op
self.args = args
assert (mod == "torch.nn.functional")
assert (op == "log_softmax")
#Filter out named parameters
args = list(filter(lambda x : x['name'] == '', args))
assert (len(args) <= 2)
#Get input
if (args[0]['name'] == ""):
i = args[0]
else:
i = list(filter(lambda x : x['name'] == "input", args))[0]
t = i['dtype']
self.shape = i['shape']
self.type = i['dtype']
self.dir = d.dir
return
def op(self):
return self.op_
def mod(self):
return self.mod_
def tc(self):
return "-"
def params(self):
p = OrderedDict([('T', self.shape), ('type', self.type)])
return p
def elems(self):
return Utility.numElems(self.shape)
def flops(self):
# Note: exp, sum-reduce, divide, log
#flops = elems * 4
return 0
def bytes(self):
b = self.elems() * Utility.typeToBytes(self.type)
b *= 3 if self.dir == "fprop" else 5 #verify
return b
import sys
import argparse
def parseArgs():
"""
Print usage and parse arguments.
"""
def check_cols(value):
valid = ["idx", "seq", "altseq", "tid", "layer", "trace", "dir", "sub", "mod", "op", "kernel", "params", "sil", "tc", "device", "stream", "grid", "block", "flops", "bytes"]
cols = value.split(",")
for col in cols:
if col not in valid:
raise argparse.ArgumentTypeError("{} is not a valid column name. Valid column names are {}.".format(col, ",".join(valid)))
return cols
def openFile(f):
try:
d = open(f, "r")
return d
except IOError:
print("Error opening file {}. Exiting.".format(f), file=sys.stderr)
sys.exit(1)
parser = argparse.ArgumentParser(prog=sys.argv[0], description="PyTorch Profiler", formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("file",
nargs='?',
type=str,
default=None,
help="Output of parse.py (Python dictionary).")
parser.add_argument("-c",
type=check_cols,
default="idx,dir,sub,mod,op,kernel,params,sil",
help='''Comma separated names of columns to print.
idx: Index
seq: PyTorch Sequence Id
altseq: PyTorch Alternate Sequence Id
tid: Thread Id
layer: User annotated NVTX string (can be nested)
trace: Function Call Trace
dir: Direction
sub: Sub Sequence Id
mod: Module
op: Operation
kernel: Kernel Name
params: Parameters
sil: Silicon Time (in ns)
tc: Tensor Core Usage
device: GPU Device Id
stream: Stream Id
grid: Grid Dimensions
block: Block Dimensions
flops: Floating point ops (FMA = 2 FLOPs)
bytes: Number of bytes in and out of DRAM
e.g. -c idx,kernel,sil''')
group = parser.add_mutually_exclusive_group()
group.add_argument("--csv",
action="store_true",
default=False,
help="Print a CSV output.")
group.add_argument("-w",
type=int,
default=0,
help="Width of columnated output.")
args = parser.parse_args()
if args.file is None:
args.file = sys.stdin
else:
args.file = openFile(args.file)
return args
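# Hedged invocation sketch: the module path and the input file name below are
# assumptions; adjust them to however the package is installed. "net.dict" is
# assumed to be the dictionary produced by parse.py.
#
#   python -m apex.pyprof.prof net.dict -c idx,dir,op,kernel,params,sil -w 150
#   python -m apex.pyprof.prof net.dict --csv > prof.csv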
from functools import reduce
class Utility(object):
@staticmethod
def numElems(shape):
assert (type(shape) == tuple)
return reduce(lambda x,y: x*y, shape, 1)
@staticmethod
def typeToBytes(t):
if (t in ["uint8", "int8", "byte", "char", "bool"]):
return 1
elif (t in ["float16", "half", "int16", "short"]):
return 2
elif (t in ["float32", "float", "int32", "int"]):
return 4
elif (t in ["int64", "long", "float64", "double"]):
return 8
assert False
@staticmethod
def typeToString(t):
if (t in ["uint8", "byte", "char",]):
return "uint8"
elif (t in ["int8",]):
return "int8"
elif (t in ["int16", "short",]):
return "int16"
elif (t in ["float16", "half"]):
return "fp16"
elif (t in ["float32", "float"]):
return "fp32"
elif (t in ["int32", "int",]):
return "int32"
elif (t in ["int64", "long"]):
return "int64"
elif (t in ["float64", "double",]):
return "fp64"
elif (t in ["bool",]):
return "bool"
assert False
@staticmethod
def hasNVTX(marker):
if type(marker) is str:
try:
marker = eval(marker)
except:
return False
if type(marker) is dict:
keys = marker.keys()
return ("mod" in keys) and ("op" in keys) and ("args" in keys)
else:
return False
@staticmethod
def isscalar(t):
return (t in ["float", "int"])
from .weight_norm import WeightNorm
from .reparameterization import Reparameterization
def apply_weight_norm(module, name='', dim=0, hook_child=True):
r"""
Applies weight normalization to a parameter in the given module.
If no parameter is provided, applies weight normalization to all
parameters in the model (except 1-d vectors and scalars).
.. math::
\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
Weight normalization is a reparameterization that decouples the magnitude
of a weight tensor from its direction. This replaces the parameter specified
by `name` (e.g. "weight") with two parameters: one specifying the magnitude
(e.g. "weight_g") and one specifying the direction (e.g. "weight_v").
Weight normalization is implemented via a hook that recomputes the weight
tensor from the magnitude and direction before every :meth:`~Module.forward`
call.
By default, with `dim=0`, the norm is computed independently per output
channel/plane. To compute a norm over the entire weight tensor, use
`dim=None`.
See https://arxiv.org/abs/1602.07868
Args:
module (nn.Module): containing module
name (str, optional): name of weight parameter
dim (int, optional): dimension over which to compute the norm
hook_child (boolean, optional): adds reparameterization hook to direct parent of the
parameters. If False, it's added to `module` instead. Default: True
Returns:
The original module with the weight norm hook
Example::
>>> m = apply_weight_norm(nn.Linear(20, 40), name='weight')
Linear (20 -> 40)
>>> m.weight_g.size()
torch.Size([40, 1])
>>> m.weight_v.size()
torch.Size([40, 20])
"""
return apply_reparameterization(module, reparameterization=WeightNorm, hook_child=hook_child,
name=name, dim=dim)
def remove_weight_norm(module, name='', remove_all=False):
"""
Removes the weight normalization reparameterization of a parameter from a module.
If no parameter is supplied then all weight norm parameterizations are removed.
Args:
module (nn.Module): containing module
name (str, optional): name of weight parameter
Example:
>>> m = apply_weight_norm(nn.Linear(20, 40))
>>> remove_weight_norm(m)
"""
return remove_reparameterization(module, reparameterization=WeightNorm,
name=name, remove_all=remove_all)
def apply_reparameterization(module, reparameterization=None, name='', dim=0, hook_child=True):
"""
Applies a given weight reparameterization (such as weight normalization) to
a parameter in the given module. If no parameter is given, applies the reparameterization
to all parameters in the model (except 1-d vectors and scalars).
Args:
module (nn.Module): containing module
reparameterization (Reparameterization): reparameterization class to apply
name (str, optional): name of weight parameter
dim (int, optional): dimension over which to perform reparameterization op
hook_child (boolean, optional): adds reparameterization hook to direct parent of the
parameters. If False, it's added to `module` instead. Default: True
Returns:
The original module with the reparameterization hook
Example::
>>> m = apply_reparameterization(nn.Linear(20, 40), WeightNorm)
Linear (20 -> 40)
"""
assert reparameterization is not None
if name != '':
Reparameterization.apply(module, name, dim, reparameterization, hook_child)
else:
names = list(module.state_dict().keys())
for name in names:
apply_reparameterization(module, reparameterization, name, dim, hook_child)
return module
def remove_reparameterization(module, reparameterization=Reparameterization,
name='', remove_all=False):
"""
Removes the given reparameterization of a parameter from a module.
If no parameter is supplied then all reparameterizations are removed.
Args:
module (nn.Module): containing module
reparameterization (Reparameterization): reparameterization class to apply
name (str, optional): name of weight parameter
remove_all (bool, optional): if True, remove all reparameterizations of given type. Default: False
Example:
>>> m = apply_reparameterization(nn.Linear(20, 40),WeightNorm)
>>> remove_reparameterization(m)
"""
if name != '' or remove_all:
to_remove = []
for k, hook in module._forward_pre_hooks.items():
if isinstance(hook, reparameterization) and (hook.name == name or remove_all):
hook.remove(module)
to_remove.append(k)
if len(to_remove) > 0:
for k in to_remove:
del module._forward_pre_hooks[k]
return module
if not remove_all:
raise ValueError("reparameterization of '{}' not found in {}"
.format(name, module))
else:
modules = [module]+[x for x in module.modules()]
for m in modules:
remove_reparameterization(m, reparameterization=reparameterization, remove_all=True)
return module
import torch
from torch.nn.parameter import Parameter
import sys
class Reparameterization(object):
"""
Class interface for performing weight reparameterizations
Arguments:
name (str): name of weight parameter
dim (int): dimension over which to compute the norm
module (nn.Module): parent module to which param `name` is registered to
retain_forward (bool, optional): if False deletes weight on call to
module.backward. Used to avoid memory leaks with DataParallel. Default: True
Attributes:
reparameterization_names (list, str): contains names of all parameters
needed to compute reparameterization.
backward_hook_key (int): torch.utils.hooks.RemovableHandle.id for hook used in module backward pass.
"""
def __init__(self, name, dim, module, retain_forward=True):
self.name = name
self.dim = dim
self.evaluated = False
self.retain_forward = retain_forward
self.reparameterization_names = []
self.backward_hook_key = None
self.module = module
def compute_weight(self, module=None, name=None):
"""
Computes reparameterized weight value to assign value to module attribute
with name `name`.
See WeightNorm class for example.
Arguments:
module (nn.Module): module with weight we'd like to reparameterize
Returns:
w (Tensor): Tensor object containing value of reparameterized weight
"""
raise NotImplementedError
def reparameterize(self, name, weight, dim):
"""
Creates Parameters to be used for reparameterization and creates names for
the module attributes these Parameters will correspond to.
The parameters will be registered according to the names provided.
See WeightNorm class for example.
Arguments:
module (nn.Module): module with weight we'd like to reparameterize
name (str, optional): name of weight parameter
dim (int, optional): dimension over which to compute parameterization
Returns:
names (list, str): names of Parameters to be used for reparameterization
params (list, Parameter): Parameters to be used for reparameterization
"""
raise NotImplementedError
@staticmethod
def apply(module, name, dim, reparameterization=None, hook_child=True):
"""
Applies reparameterization to module's `name` parameter and modifies instance attributes as appropriate.
`hook_child` adds reparameterization hook to direct parent of the parameters. If False, it's added to `module` instead.
"""
if reparameterization is None:
reparameterization = Reparameterization
module2use, name2use = Reparameterization.get_module_and_name(module, name)
# does not work on sparse
if name2use is None or isinstance(module2use, (torch.nn.Embedding, torch.nn.EmbeddingBag)):
return
if hook_child:
fn = reparameterization(name2use, dim, module2use)
else:
fn = reparameterization(name, dim, module)
weight = getattr(module2use, name2use)
if weight.dim() <= 1:
return
# remove weight from parameter list
del module2use._parameters[name2use]
# add parameters of reparameterization of parameter to module
names, params = fn.reparameterize(name2use, weight, dim)
for n, p in zip(names, params):
module2use.register_parameter(n, p)
# add parameters to reparameterization so they can be removed later
fn.reparameterization_names = names
setattr(module2use, name2use, None)
hook_module = module2use
if not hook_child:
hook_module = module
# recompute weight before every forward()
hook_module.register_forward_pre_hook(fn)
# remove weight during backward
handle = hook_module.register_backward_hook(fn.backward_hook)
# get hook key so we can delete it later
fn.backward_hook_key = handle.id
return fn
@staticmethod
def get_module_and_name(module, name):
"""
Recursively fetches the (possibly nested) child module and the name of the weight to be reparameterized.
"""
name2use = None
module2use = None
names = name.split('.')
if len(names) == 1 and names[0] != '':
name2use = names[0]
module2use = module
elif len(names) > 1:
module2use = module
name2use = names[0]
for i in range(len(names)-1):
module2use = getattr(module2use, name2use)
name2use = names[i+1]
return module2use, name2use
def get_params(self, module):
"""gets params of reparameterization based on known attribute names"""
return [getattr(module, n) for n in self.reparameterization_names]
def remove(self, module):
"""removes reparameterization and backward hook (does not remove forward hook)"""
module2use, name2use = Reparameterization.get_module_and_name(module, self.name)
for p in self.get_params(module2use):
p.requires_grad = False
weight = self.compute_weight(module2use, name2use)
delattr(module2use, name2use)
for n in self.reparameterization_names:
del module2use._parameters[n]
module2use.register_parameter(name2use, Parameter(weight.data))
del module._backward_hooks[self.backward_hook_key]
def __call__(self, module, inputs):
"""callable hook for forward pass"""
module2use, name2use = Reparameterization.get_module_and_name(module, self.name)
_w = getattr(module2use, name2use)
if not self.evaluated or _w is None:
setattr(module2use, name2use, self.compute_weight(module2use, name2use))
self.evaluated = True
def backward_hook(self, module, grad_input, grad_output):
"""callable hook for backward pass"""
module2use, name2use = Reparameterization.get_module_and_name(module, self.name)
wn = getattr(module2use, name2use)
self.evaluated = False
import torch
from torch.nn.parameter import Parameter
from ..fp16_utils import Fused_Weight_Norm
import time
from .reparameterization import Reparameterization
def _norm(p, dim):
"""Computes the norm over all dimensions except dim"""
if dim is None:
return p.norm()
elif dim == 0:
output_size = (p.size(0),) + (1,) * (p.dim() - 1)
return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size)
elif dim == p.dim() - 1:
output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
return p.contiguous().view(-1, p.size(-1)).norm(dim=0).view(*output_size)
return _norm(p.transpose(0, dim), 0).transpose(0, dim)
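# Hedged example: for a weight p of shape (40, 20), _norm(p, 0) returns a
# per-output-channel norm of shape (40, 1), while _norm(p, None) returns the
# scalar norm of the whole tensor.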
HALF_TYPES = (torch.cuda.HalfTensor, torch.HalfTensor)
class WeightNorm(Reparameterization):
r"""
Weight normalization is a reparameterization that decouples the magnitude
of a weight tensor from its direction. This replaces the parameter specified
by `name` (e.g. "weight") with two parameters: one specifying the magnitude
(e.g. "weight_g") and one specifying the direction (e.g. "weight_v").
Weight normalization is implemented via a hook that recomputes the weight
tensor from the magnitude and direction before every :meth:`~Module.forward`
call.
.. math::
\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
By default, with `dim=0`, the norm is computed independently per output
channel/plane. To compute a norm over the entire weight tensor, use
`dim=None`.
"""
def compute_weight(self, module=None, name=None):
"""
Computes weight normalized weight value to assign value to module attribute
with name `name`.
Arguments:
module (nn.Module): module with weight we'd like to reparameterize
Returns:
w (Tensor): Tensor object containing value of reparameterized weight
"""
if module is None:
module = self.module
if name is None:
name = self.name
module, name = Reparameterization.get_module_and_name(module, name)
g = getattr(module, name + '_g')
v = getattr(module, name + '_v')
fused_weight_norm = Fused_Weight_Norm.apply
v = v.contiguous()
w = fused_weight_norm(v, g, self.dim)
return w
def reparameterize(self, name, weight, dim):
"""
Creates Parameters v and g to be used for weight normalization
and creates names for the module attributes these Parameters
will correspond to. The parameters will be registered according to the names
provided.
Arguments:
module (nn.Module): module with weight we'd like to reparameterize
name (str, optional): name of weight parameter
dim (int, optional): dimension over which to compute parameterization
Returns:
names (list, str): names of Parameters to be used for reparameterization
params (list, Parameter): Parameters to be used for reparameterization
"""
names = [name + '_g', name + '_v']
params = [Parameter(_norm(weight, dim).data), Parameter(weight.data)]
return names, params
@@ -10,6 +10,7 @@ import unittest
TEST_WITH_ROCM = os.getenv('APEX_TEST_WITH_ROCM', '0') == '1'
SKIP_FLAKY_TEST = os.getenv('APEX_SKIP_FLAKY_TEST', '0') == '1'
## Wrapper to skip the unit tests.
def skipIfRocm(fn):
@@ -20,3 +21,13 @@ def skipIfRocm(fn):
else:
fn(*args, **kwargs)
return wrapper
## Wrapper to skip the flaky unit tests.
def skipFlakyTest(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if SKIP_FLAKY_TEST:
raise unittest.SkipTest("Test is flaky.")
else:
fn(*args, **kwargs)
return wrapper
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import torch
......
@@ -28,3 +28,8 @@ class AttnType(enum.Enum):
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
@@ -31,7 +31,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
inputs, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@@ -41,7 +43,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None
@@ -65,10 +69,10 @@ def scaled_upper_triang_masked_softmax(inputs, _, scale):
# 2. Apply the mask.
# 3. Perform softmax.
class ScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, mask, scale):
import scaled_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
@@ -81,7 +85,9 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
input_grads = scaled_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None, None
@@ -120,7 +126,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
if self.input_in_fp16 and self.input_in_bf16:
raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.")
raise RuntimeError(
"both fp16 and bf16 flags cannot be active at the same time."
)
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
@@ -154,9 +162,13 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
if (
self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16
and mask is not None # mask tensor must not be None
and (
self.attn_mask_type == AttnMaskType.causal
or (self.attn_mask_type == AttnMaskType.padding and mask is not None)
)
and 16 < sk <= 2048 # sk must be 16 ~ 2048
and sq % 4 == 0 # sq must be divisor of 4
and sk % 4 == 0 # sk must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
):
if 0 <= sk <= 2048:
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from apex.transformer.layers.layer_norm import FastLayerNorm
from apex.transformer.layers.layer_norm import FusedLayerNorm
from apex.transformer.layers.layer_norm import MixedFusedLayerNorm
__all__ = [
"FastLayerNorm",
"FusedLayerNorm",
"MixedFusedLayerNorm",
]