Unverified commit 23cfb576 authored by Gil Shomron, committed by GitHub

Conv-Bias-ReLU fusion (#1332)



* Enabled Conv-Bias-ReLU fusion

The following modules are enabled using cuDNN runtime fusion:
1) Conv-Bias-ReLU (+backward)
2) Conv-Bias (+backward)
3) Conv-Bias-Mask-ReLU (+backward)

* Cast cleanup and autocast in unit test

- Remove redundant dtype casts
- Simulate real usage in the unit test by wrapping the fused ops in torch.cuda.amp.autocast
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>

* Fixed save_for_backward
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
Co-authored-by: root <root@luna-0277.selene.nvidia.com>
parent 3c88451a
# apex/contrib/conv_bias_relu/__init__.py
from .conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU


# apex/contrib/conv_bias_relu/conv_bias_relu.py
import torch

import fused_conv_bias_relu  # C++/cuDNN runtime-fusion extension built by setup.py (see below)
class ConvBiasReLU_(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, x, weight, bias, padding, stride):
        # Fused Conv2d + bias + ReLU through the cuDNN runtime-fusion extension.
        outputs = fused_conv_bias_relu.forward([x, weight, bias], padding, stride)
        # Save the input, weight, and post-ReLU output for the backward pass.
        ctx.save_for_backward(x, weight, outputs[0])
        ctx.padding = padding
        ctx.stride = stride
        return outputs[0]

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        bwd_args = [*ctx.saved_tensors, grad_output]
        padding = ctx.padding
        stride = ctx.stride
        grads = fused_conv_bias_relu.backward(bwd_args, padding, stride)
        # Gradients for x, weight, and bias; padding and stride take no gradient.
        return grads[0], grads[1], grads[2], None, None

class ConvBiasMaskReLU_(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, x, weight, bias, mask, padding, stride):
        # Fused Conv2d + bias + elementwise mask + ReLU.
        outputs = fused_conv_bias_relu.forward_mask([x, weight, bias, mask], padding, stride)
        # The mask itself is not saved: masked positions are already zero in the
        # saved post-ReLU output, so the shared backward kernel can be reused.
        ctx.save_for_backward(x, weight, outputs[0])
        ctx.padding = padding
        ctx.stride = stride
        return outputs[0]

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        bwd_args = [*ctx.saved_tensors, grad_output]
        padding = ctx.padding
        stride = ctx.stride
        grads = fused_conv_bias_relu.backward(bwd_args, padding, stride)
        # No gradients for the mask, padding, or stride inputs.
        return grads[0], grads[1], grads[2], None, None, None

class ConvBias_(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, x, weight, bias, padding, stride):
        # Fused Conv2d + bias, without the ReLU epilogue.
        outputs = fused_conv_bias_relu.forward_no_relu([x, weight, bias], padding, stride)
        # Only x and weight are needed for the backward pass here.
        ctx.save_for_backward(x, weight)
        ctx.padding = padding
        ctx.stride = stride
        return outputs[0]

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        bwd_args = [*ctx.saved_tensors, grad_output]
        padding = ctx.padding
        stride = ctx.stride
        grads = fused_conv_bias_relu.backward_no_relu(bwd_args, padding, stride)
        return grads[0], grads[1], grads[2], None, None


ConvBiasReLU = ConvBiasReLU_.apply
ConvBiasMaskReLU = ConvBiasMaskReLU_.apply
ConvBias = ConvBias_.apply
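
A minimal usage sketch of the fused functions, modeled on the unit test below. Assumptions: the fused_conv_bias_relu extension has been built, inputs are CUDA tensors in channels-last layout, and the shapes are illustrative only (custom_fwd(cast_inputs=torch.half) handles the FP16 cast under autocast).

import torch
from apex.contrib.conv_bias_relu import ConvBiasReLU, ConvBias

# Illustrative sizes; any channels-last-friendly shapes work.
x = torch.randn(8, 64, 32, 32, device="cuda").to(memory_format=torch.channels_last)
conv = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1).cuda().to(memory_format=torch.channels_last)

with torch.cuda.amp.autocast(dtype=torch.half):
    # Bias is passed broadcast as [1, C, 1, 1]; padding and stride are plain ints.
    y = ConvBiasReLU(x, conv.weight, conv.bias.reshape(1, -1, 1, 1), 1, 1)
    y_no_relu = ConvBias(x, conv.weight, conv.bias.reshape(1, -1, 1, 1), 1, 1)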
[Collapsed diff not shown.]
import torch
import torch.nn.functional as F
import unittest
import copy
import random
import math
from apex.contrib.conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU
class ConvBiasReLUTest(unittest.TestCase):
    def setUp(self, seed=0):
        torch.manual_seed(seed)

        # Randomized problem sizes; channel counts are kept multiples of 8
        # (likely to satisfy FP16/cuDNN alignment requirements).
        self.batch_size = random.randint(1, 64)
        self.in_channels = random.randint(1, 64) * 8
        self.out_channels = random.randint(1, 64) * 8
        self.in_height = self.in_width = random.randint(5, 100)
        self.conv_kernel_size = random.randint(1, 5)
        self.conv_pad = random.randint(0, int(self.conv_kernel_size / 2))
        self.conv_stride = random.randint(1, 5)
        self.conv_dilation = 1

        # Standard Conv2d output-size formula.
        self.out_height = self.out_width = \
            math.floor((self.in_height + 2 * self.conv_pad -
                        self.conv_dilation * (self.conv_kernel_size - 1) - 1) / self.conv_stride + 1)

        # Two copies of the input: self.x feeds the fused path, self.x_ the reference path.
        self.x = torch.randint(low=-16, high=16,
                               size=[self.batch_size, self.in_channels, self.in_height, self.in_width]) \
                      .cuda().to(memory_format=torch.channels_last).float()
        self.x_ = self.x.clone()
        self.x.requires_grad_()
        self.x_.requires_grad_()

        # Binary mask for the Conv-Bias-Mask-ReLU variant.
        self.mask = torch.randn([self.batch_size, self.out_channels, self.out_height, self.out_width]).cuda().to(memory_format=torch.channels_last)
        self.mask = (self.mask > 0).to(torch.int8)
        self.mask_ = self.mask.clone()

        # conv1's weight/bias feed the fused path; conv1_ (a deep copy) is run as the reference module.
        self.conv1 = torch.nn.Conv2d(self.in_channels, self.out_channels, self.conv_kernel_size,
                                     stride=self.conv_stride, padding=self.conv_pad).cuda().to(memory_format=torch.channels_last)
        self.conv1_ = copy.deepcopy(self.conv1)

        print()
        print('> input=[{}, {}, {}, {}]'.format(self.batch_size, self.in_channels, self.in_height, self.in_width))
        print('> kernel=[{}, {}, {}, {}], stride={}, pad={}'.format(self.out_channels, self.in_channels,
                                                                    self.conv_kernel_size, self.conv_kernel_size,
                                                                    self.conv_stride, self.conv_pad))
    def test_conv_bias_relu(self):
        # Fused path.
        with torch.cuda.amp.autocast(dtype=torch.half):
            out = ConvBiasReLU(self.x, self.conv1.weight, self.conv1.bias.reshape(1, -1, 1, 1),
                               self.conv_pad, self.conv_stride)
            loss = (out.float()**2).sum() / out.numel()
        loss.backward()

        # Reference path: plain Conv2d followed by ReLU.
        with torch.cuda.amp.autocast(dtype=torch.half):
            out_ = F.relu(self.conv1_(self.x_))
            loss_ = (out_**2).sum() / out_.numel()
        loss_.backward()

        # Inputs must be untouched and all gradients must match the reference.
        self.assertTrue(torch.allclose(self.x_, self.x, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(self.conv1_.bias.grad, self.conv1.bias.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(self.conv1_.weight.grad, self.conv1.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(self.x_.grad, self.x.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
    def test_conv_bias(self):
        # Fused path (no ReLU).
        with torch.cuda.amp.autocast(dtype=torch.half):
            out = ConvBias(self.x, self.conv1.weight, self.conv1.bias.reshape(1, -1, 1, 1),
                           self.conv_pad, self.conv_stride)
            loss = (out.float()**2).sum() / out.numel()
        loss.backward()

        # Reference path: plain Conv2d.
        with torch.cuda.amp.autocast(dtype=torch.half):
            out_ = self.conv1_(self.x_)
            loss_ = (out_**2).sum() / out_.numel()
        loss_.backward()

        self.assertTrue(torch.allclose(self.x_, self.x, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(self.conv1_.bias.grad, self.conv1.bias.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(self.conv1_.weight.grad, self.conv1.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(self.x_.grad, self.x.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
    def test_conv_bias_mask_relu(self):
        # Fused path with elementwise mask.
        with torch.cuda.amp.autocast(dtype=torch.half):
            out = ConvBiasMaskReLU(self.x, self.conv1.weight, self.conv1.bias.reshape(1, -1, 1, 1),
                                   self.mask, self.conv_pad, self.conv_stride)
            loss = (out.float()**2).sum() / out.numel()
        loss.backward()

        # Reference path: Conv2d, mask, then ReLU.
        with torch.cuda.amp.autocast(dtype=torch.half):
            out_ = F.relu(self.conv1_(self.x_) * self.mask_)
            loss_ = (out_**2).sum() / out_.numel()
        loss_.backward()

        self.assertTrue(torch.allclose(self.x_, self.x, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(self.conv1_.bias.grad, self.conv1.bias.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(self.conv1_.weight.grad, self.conv1.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(self.x_.grad, self.x.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
if __name__ == '__main__':
    unittest.main()
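
As an illustration only (not part of this PR), the fused function could be wrapped in a drop-in module along the lines of the sketch below. The class name and initialization scheme are hypothetical; it simply follows the channels-last and autocast conventions used in the test above.

import torch
from apex.contrib.conv_bias_relu import ConvBiasReLU

class FusedConvBiasReLU2d(torch.nn.Module):
    """Hypothetical drop-in replacement for Conv2d + ReLU (illustration only)."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # Borrow Conv2d's default initialization, keeping channels-last parameters.
        ref = torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding).to(memory_format=torch.channels_last)
        self.weight = torch.nn.Parameter(ref.weight.detach().clone())
        self.bias = torch.nn.Parameter(ref.bias.detach().clone())
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        # The fused op expects the bias broadcast as [1, C, 1, 1].
        return ConvBiasReLU(x, self.weight, self.bias.reshape(1, -1, 1, 1),
                            self.padding, self.stride)

# Example (shapes illustrative):
#   layer = FusedConvBiasReLU2d(64, 128, 3, padding=1).cuda()
#   x = torch.randn(8, 64, 32, 32, device="cuda").to(memory_format=torch.channels_last)
#   with torch.cuda.amp.autocast(dtype=torch.half):
#       y = layer(x)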
@@ -646,6 +646,20 @@ if "--fast_bottleneck" in sys.argv:
    )

if "--fused_conv_bias_relu" in sys.argv:
    sys.argv.remove("--fused_conv_bias_relu")
    raise_if_cuda_home_none("--fused_conv_bias_relu")
    subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])
    ext_modules.append(
        CUDAExtension(
            name="fused_conv_bias_relu",
            sources=["apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp"],
            include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")],
            extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
        )
    )

setup(
    name="apex",
    version="0.1",
    ...
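
Build note: the new extension is gated behind the --fused_conv_bias_relu flag shown above. Following Apex's usual --global-option pattern for contrib extensions, it would typically be built with something like pip install -v --no-cache-dir --global-option="--fused_conv_bias_relu" ./ from the Apex source root; the exact invocation depends on the Apex version and build environment.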