Commit 753c427a authored by Marek Kolodziej, committed by mcarilli

Fixed tensor core lookup for Turing (#534)

parent e87b5799
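Background for the change below: on Volta, the cuBLAS/cuDNN kernels that use Tensor Cores carry "884" in their names (after the 8x8x4 HMMA shape), so pyprof detected Tensor Core usage by searching kernel names for the substrings "884gemm" / "884cudnn". Turing Tensor Core kernels carry "1688" instead, so they were reported as not using Tensor Cores. This commit turns the single-substring check into a lookup over a list of name patterns. A minimal sketch of that lookup follows; the helper name and the example kernel names are hypothetical illustrations, not code from the commit.

# Illustrative sketch only (not part of the commit); uses_tensor_cores and the
# kernel names below are hypothetical.
TC_GEMMS = ["884gemm", "1688gemm"]  # Volta- and Turing-style Tensor Core GEMM name patterns

def uses_tensor_cores(kernel_name):
    # Return 1 if any known Tensor Core pattern appears in the kernel name, else 0.
    for s in TC_GEMMS:
        if s in kernel_name:
            return 1
    return 0

assert uses_tensor_cores("volta_fp16_s884gemm_fp16_128x128_nn") == 1    # Volta TC kernel
assert uses_tensor_cores("turing_fp16_s1688gemm_fp16_128x256_nn") == 1  # Turing TC kernel
assert uses_tensor_cores("maxwell_sgemm_128x64_nn") == 0                # no Tensor Cores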
@@ -12,8 +12,7 @@ import torch.cuda.profiler as profiler
 import argparse

 from apex import pyprof
-from apex.optimizers import FusedAdam, FP16_Optimizer
-import fused_adam_cuda
+from apex.optimizers import FusedAdam

 def parseArgs():
     parser = argparse.ArgumentParser(prog=sys.argv[0], description="Run popular imagenet models.")
@@ -91,7 +90,7 @@ def main():
     args = parseArgs()

     pyprof.nvtx.init()
-    pyprof.nvtx.wrap(fused_adam_cuda, 'adam')
+    # pyprof.nvtx.wrap(fused_adam_cuda, 'adam')

     N = args.b
     C = 3
@@ -112,7 +111,6 @@ def main():
         optimizer = torch.optim.SGD(net.parameters(), lr = 0.01, momentum=0.9)
     elif (args.o == "adam"):
         optimizer = FusedAdam(net.parameters())
-        optimizer = FP16_Optimizer(optimizer)
     else:
         assert False
...
@@ -3,6 +3,8 @@ from .utility import Utility
 from .base import OperatorLayerBase
 import numpy as np

+TC_GEMMS = ["884gemm", "1688gemm"]
+
 class Addmm(OperatorLayerBase):

     def __init__(self, d):
@@ -59,7 +61,10 @@ class Addmm(OperatorLayerBase):
         return

     def tc(self):
-        return 1 if "884gemm" in self.name else 0
+        for s in TC_GEMMS:
+            if s in self.name:
+                return 1
+        return 0

     def bytes(self):
         m, n, k = self.m, self.n, self.k
@@ -116,7 +121,10 @@ class Bmm(OperatorLayerBase):
         self.name = d.name

     def tc(self):
-        return 1 if "884gemm" in self.name else 0
+        for s in TC_GEMMS:
+            if s in self.name:
+                return 1
+        return 0

     def params(self):
         #p = OrderedDict([('A', A['shape']), ('B', B['shape']), ('type', t1)])
@@ -248,7 +256,10 @@ class Matmul(OperatorLayerBase):
         if self.name in Matmul.NON_TC:
             return "-"
         else:
-            return 1 if "884gemm" in self.name else 0
+            for s in TC_GEMMS:
+                if s in self.name:
+                    return 1
+            return 0

     def bytes(self):
         # TODO: check bytes for non-GEMM cases
@@ -310,7 +321,10 @@ class Mm(OperatorLayerBase):
         return p

     def tc(self):
-        return 1 if "884gemm" in self.name else 0
+        for s in TC_GEMMS:
+            if s in self.name:
+                return 1
+        return 0

     def bytes(self):
         m, n, k = self.m, self.n, self.k
...
@@ -18,7 +18,7 @@ class Conv(OperatorLayerBase):
     fftAuxList = ["compute_gemm_pointers", "flip_filter", "fft2d_r2c_", "fft2d_c2r_", "fft1d_r2c", "fft1d_c2r"]
     miscAuxList = ["scaleTensor_kernel",]

-    convList = ["_s884cudnn_", "_scudnn_", "2d_grouped_direct_kernel", "cudnn::detail::implicit_convolve_sgemm", "cudnn::detail::dgrad2d_alg1_1", "cudnn::detail::wgrad_alg0_engine", "cudnn::detail::dgrad_engine", "dgrad_1x1_stride_2x2", "spatialDepthwiseConvolutionUpdateOutput"]
+    convList = ["_s884cudnn_", "_s1688cudnn_", "_scudnn_", "2d_grouped_direct_kernel", "cudnn::detail::implicit_convolve_sgemm", "cudnn::detail::dgrad2d_alg1_1", "cudnn::detail::wgrad_alg0_engine", "cudnn::detail::dgrad_engine", "dgrad_1x1_stride_2x2", "spatialDepthwiseConvolutionUpdateOutput"]
     winoList = ["winograd3x3Kernel", "_sgemm_"]
     fftList = ["fermiPlusCgemmLDS128_batched", "_gcgemm_",]
     miscList = []
@@ -224,7 +224,10 @@ class Conv(OperatorLayerBase):
         return f

     def tc(self):
-        return 1 if "884cudnn" in self.name else "-"
+        for s in ["884cudnn", "1688cudnn"]:
+            if s in self.name:
+                return 1
+        return "-"

     def op(self):
         return self.op_
...
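The convolution path gets the same treatment: "_s1688cudnn_" is added to the list of name patterns that classify a kernel as a convolution in the first place, and tc() now matches both "884cudnn" and "1688cudnn", still returning "-" (rather than 0) when no Tensor Core pattern is found. A hedged sketch of that check, with a hypothetical helper name and kernel names:

# Illustrative sketch only; conv_uses_tensor_cores and the kernel names are hypothetical.
def conv_uses_tensor_cores(kernel_name):
    # Convolutions report "-" instead of 0 when no Tensor Core pattern matches.
    for s in ["884cudnn", "1688cudnn"]:
        if s in kernel_name:
            return 1
    return "-"

print(conv_uses_tensor_cores("turing_s1688cudnn_fp16_128x128_relu_f2f_nhwc"))  # 1
print(conv_uses_tensor_cores("cudnn::detail::implicit_convolve_sgemm"))        # "-"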