Commit 753c427a authored by Marek Kolodziej's avatar Marek Kolodziej Committed by mcarilli
Browse files

Fixed tensor core lookup for Turing (#534)

parent e87b5799
......@@ -12,8 +12,7 @@ import torch.cuda.profiler as profiler
import argparse
from apex import pyprof
from apex.optimizers import FusedAdam, FP16_Optimizer
import fused_adam_cuda
from apex.optimizers import FusedAdam
def parseArgs():
parser = argparse.ArgumentParser(prog=sys.argv[0], description="Run popular imagenet models.")
......@@ -91,7 +90,7 @@ def main():
args = parseArgs()
pyprof.nvtx.init()
pyprof.nvtx.wrap(fused_adam_cuda, 'adam')
# pyprof.nvtx.wrap(fused_adam_cuda, 'adam')
N = args.b
C = 3
......@@ -112,7 +111,6 @@ def main():
optimizer = torch.optim.SGD(net.parameters(), lr = 0.01, momentum=0.9)
elif (args.o == "adam"):
optimizer = FusedAdam(net.parameters())
optimizer = FP16_Optimizer(optimizer)
else:
assert False
......
......@@ -3,6 +3,8 @@ from .utility import Utility
from .base import OperatorLayerBase
import numpy as np
TC_GEMMS = ["884gemm", "1688gemm"]
class Addmm(OperatorLayerBase):
def __init__(self, d):
......@@ -59,7 +61,10 @@ class Addmm(OperatorLayerBase):
return
def tc(self):
    """Return 1 when the kernel name contains a Tensor Core GEMM tag, else 0.

    The tags (Volta ``884gemm``, Turing ``1688gemm``) come from the
    module-level ``TC_GEMMS`` list.
    """
    return 1 if any(tag in self.name for tag in TC_GEMMS) else 0
def bytes(self):
m, n, k = self.m, self.n, self.k
......@@ -116,7 +121,10 @@ class Bmm(OperatorLayerBase):
self.name = d.name
def tc(self):
    """Return 1 when the kernel name contains a Tensor Core GEMM tag, else 0.

    Matches any substring from the module-level ``TC_GEMMS`` list
    (Volta ``884gemm``, Turing ``1688gemm``).
    """
    return 1 if any(tag in self.name for tag in TC_GEMMS) else 0
def params(self):
#p = OrderedDict([('A', A['shape']), ('B', B['shape']), ('type', t1)])
......@@ -248,7 +256,10 @@ class Matmul(OperatorLayerBase):
if self.name in Matmul.NON_TC:
return "-"
else:
return 1 if "884gemm" in self.name else 0
for s in TC_GEMMS:
if s in self.name:
return 1
return 0
def bytes(self):
# TODO: check bytes for non-GEMM cases
......@@ -310,7 +321,10 @@ class Mm(OperatorLayerBase):
return p
def tc(self):
    """Return 1 when the kernel name contains a Tensor Core GEMM tag, else 0.

    Tags are taken from the module-level ``TC_GEMMS`` list
    (Volta ``884gemm``, Turing ``1688gemm``).
    """
    return 1 if any(tag in self.name for tag in TC_GEMMS) else 0
def bytes(self):
m, n, k = self.m, self.n, self.k
......
......@@ -18,7 +18,7 @@ class Conv(OperatorLayerBase):
fftAuxList = ["compute_gemm_pointers", "flip_filter", "fft2d_r2c_", "fft2d_c2r_", "fft1d_r2c", "fft1d_c2r"]
miscAuxList = ["scaleTensor_kernel",]
convList = ["_s884cudnn_", "_scudnn_", "2d_grouped_direct_kernel", "cudnn::detail::implicit_convolve_sgemm", "cudnn::detail::dgrad2d_alg1_1", "cudnn::detail::wgrad_alg0_engine", "cudnn::detail::dgrad_engine", "dgrad_1x1_stride_2x2", "spatialDepthwiseConvolutionUpdateOutput"]
convList = ["_s884cudnn_", "_s1688cudnn_", "_scudnn_", "2d_grouped_direct_kernel", "cudnn::detail::implicit_convolve_sgemm", "cudnn::detail::dgrad2d_alg1_1", "cudnn::detail::wgrad_alg0_engine", "cudnn::detail::dgrad_engine", "dgrad_1x1_stride_2x2", "spatialDepthwiseConvolutionUpdateOutput"]
winoList = ["winograd3x3Kernel", "_sgemm_"]
fftList = ["fermiPlusCgemmLDS128_batched", "_gcgemm_",]
miscList = []
......@@ -224,7 +224,10 @@ class Conv(OperatorLayerBase):
return f
def tc(self):
    """Return 1 when the cuDNN kernel name indicates Tensor Core use, else "-".

    Recognizes the Volta (``884cudnn``) and Turing (``1688cudnn``)
    Tensor Core kernel-name tags.
    """
    return 1 if any(tag in self.name for tag in ("884cudnn", "1688cudnn")) else "-"
def op(self):
return self.op_
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment