Unverified commit 96850dfa authored by Jithun Nair, committed by GitHub

Merge pull request #80 from ROCmSoftwarePlatform/IFU-master-2022-07-29

IFU-master-2022-07-29
parents 87fc4125 cc5f83b5
---
name: Bug report
about: Create a report to help us improve apex
title: ''
labels: bug
assignees: ''
---
**Describe the Bug**
**Minimal Steps/Code to Reproduce the Bug**
<!--
Please list the *minimal* steps or provide a code snippet for us to be able to reproduce the bug.
A helpful guide on how to craft a minimal bug report: http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports.
-->
**Expected Behavior**
<!-- A clear and concise description of what you expected to happen. -->
**Environment**
<!-- OS, version of Python, CUDA, PyTorch; collect these via `python -m torch.utils.collect_env` -->
# Introduction
This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch.
Some of the code here will be included in upstream Pytorch eventually.
The intent of Apex is to make up-to-date utilities available to users as quickly as possible.
## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)
...@@ -98,30 +96,22 @@ amp.load_state_dict(checkpoint['amp'])
Note that we recommend restoring the model using the same `opt_level`. Also note that we recommend calling the `load_state_dict` methods after `amp.initialize`.
# Installation
## Containers
NVIDIA PyTorch Containers are available on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch.
The containers come with all the custom extensions available at the moment.
See [the NGC documentation](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html) for details such as:
- how to pull a container
- how to run a pulled container
- release notes
## From Source
To install Apex from source, we recommend using the nightly Pytorch obtainable from https://github.com/pytorch/pytorch.
The latest stable release obtainable from https://pytorch.org should also work.
## On ROCm:
* Python 3.6
* Pytorch 1.5 or newer; the HIP extensions require 1.5 or newer.
* We recommend following the instructions from [ROCm-Pytorch](https://github.com/ROCmSoftwarePlatform/pytorch) to install Pytorch on ROCm.
* Note: For Pytorch versions < 1.8, building from source is no longer supported; please use the release package [ROCm-Apex v0.3](https://github.com/ROCmSoftwarePlatform/apex/releases/tag/v0.3).
# Quick Start
### Rocm
Apex on ROCm supports both python only build and extension build.
...@@ -145,29 +135,27 @@ pip install -v --install-option="--cpp_ext" --install-option="--cuda_ext" 'git+h
```
### Linux
For performance and full functionality, we recommend installing Apex with
CUDA and C++ extensions via
```bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
Apex also supports a Python-only build via
```bash
pip install -v --disable-pip-version-check --no-cache-dir ./
```
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
- Fused kernels required to use `apex.normalization.FusedLayerNorm` and `apex.normalization.FusedRMSNorm`.
- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.
`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.
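If you need to check at runtime whether the fused extensions were actually built, a minimal sketch (the extension module names `amp_C` and `fused_layer_norm_cuda` are assumed from a standard `--cpp_ext`/`--cuda_ext` build):
```python
def fused_extensions_available():
    # A Python-only build installs apex but not the compiled extensions.
    try:
        import amp_C                   # fused multi-tensor optimizer kernels
        import fused_layer_norm_cuda   # backs apex.normalization.FusedLayerNorm
        return True
    except ImportError:
        return False

print("fused extensions available:", fused_extensions_available())
```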
Pyprof support has been moved to its own [dedicated repository](https://github.com/NVIDIA/PyProf).
The codebase is deprecated in Apex and will be removed soon.
### [Experimental] Windows
`pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .` may work if you were able to build Pytorch from source
on your system. A Python-only build via `pip install -v --no-cache-dir .` is more likely to work.
If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.
import logging
import warnings
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import torch
...@@ -19,14 +20,10 @@ from . import fp16_utils
# load time) the error message is timely and visible.
from . import optimizers
from . import normalization
from . import pyprof
# common utilities to run tests on ROCm.
from . import testing
from . import transformer
# Logging utilities for apex.transformer module
class RankInfoFormatter(logging.Formatter):
def format(self, record):
...@@ -37,6 +34,18 @@ class RankInfoFormatter(logging.Formatter):
_library_root_logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(RankInfoFormatter("%(asctime)s - PID:%(process)d - rank:%(rank_info)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s", "%y-%m-%d %H:%M:%S"))
_library_root_logger.addHandler(handler)
_library_root_logger.propagate = False
def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
cudnn_available = torch.backends.cudnn.is_available()
cudnn_version = torch.backends.cudnn.version() if cudnn_available else None
if not (cudnn_available and (cudnn_version >= required_cudnn_version)):
warnings.warn(
f"`{global_option}` depends on cuDNN {required_cudnn_version} or later, "
f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}"
)
return False
return True
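# Usage sketch (added for illustration, not part of the diff): a cuDNN-backed
# contrib module can guard itself at import time with the helper above, the way
# apex.contrib.bottleneck does with `assert check_cudnn_version_and_warn(__name__, 8400)`.
#
#     from apex import check_cudnn_version_and_warn
#     if not check_cudnn_version_and_warn("my_cudnn_module", 8400):  # hypothetical module name
#         raise ImportError("cuDNN 8.4.0 or later is required")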
from typing import Optional, Sequence
import torch
def _get_autocast_dtypes() -> Sequence[torch.dtype]:
if torch.cuda.is_bf16_supported():
return [torch.half, torch.bfloat16]
return [torch.half]
def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
if not torch.is_autocast_enabled():
return torch.float or dtype
......
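# Illustration (added, not part of the diff) of the helpers above, assuming they
# live in apex/_autocast_utils.py as in upstream Apex:
#
#     import torch
#     from apex._autocast_utils import _get_autocast_dtypes, _get_current_dtype
#
#     # With autocast disabled, _get_current_dtype() falls back to torch.float
#     # (note that `torch.float or dtype` always evaluates to torch.float).
#     assert _get_current_dtype(torch.half) == torch.float
#
#     # Candidate autocast dtypes depend on bf16 support of the current GPU.
#     print(_get_autocast_dtypes())  # [torch.float16] or [torch.float16, torch.bfloat16]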
from .bottleneck import Bottleneck, SpatialBottleneck
from .halo_exchangers import HaloExchangerNoComm, HaloExchangerAllGather, HaloExchangerSendRecv, HaloExchangerPeer
import functools as func
import torch
import torch.distributed as dist
from torch import nn
from apex import check_cudnn_version_and_warn
import fast_bottleneck
import nccl_p2p_cuda as inc
assert check_cudnn_version_and_warn(__name__, 8400)
def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
weight_tensor_nchw = tensor
nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity)
def compute_scale_bias_one(nhwc, weight, bias, running_mean, running_var, w_scale, w_bias):
scale = weight * running_var.rsqrt()
bias = bias - running_mean * scale
w_scale.copy_(scale)
w_bias.copy_(bias)
def compute_scale_bias_method(nhwc, args):
for arg in args:
# arg is tuple of (weight, bias, running_mean, running_var, w_scale, w_bias)
compute_scale_bias_one(nhwc, *arg)
class FrozenBatchNorm2d(torch.jit.ScriptModule):
""" """
BatchNorm2d where the batch statistics and the affine parameters are fixed BatchNorm2d where the batch statistics and the affine parameters are fixed
""" """
...@@ -18,7 +38,9 @@ class FrozenBatchNorm2d(torch.nn.Module): ...@@ -18,7 +38,9 @@ class FrozenBatchNorm2d(torch.nn.Module):
self.register_buffer("running_mean", torch.zeros(n)) self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n)) self.register_buffer("running_var", torch.ones(n))
def get_scale_bias(self, nhwc=False): @torch.jit.script_method
def get_scale_bias(self, nhwc):
# type: (bool) -> List[torch.Tensor]
scale = self.weight * self.running_var.rsqrt() scale = self.weight * self.running_var.rsqrt()
bias = self.bias - self.running_mean * scale bias = self.bias - self.running_mean * scale
if nhwc: if nhwc:
...@@ -29,21 +51,21 @@ class FrozenBatchNorm2d(torch.nn.Module): ...@@ -29,21 +51,21 @@ class FrozenBatchNorm2d(torch.nn.Module):
bias = bias.reshape(1, -1, 1, 1) bias = bias.reshape(1, -1, 1, 1)
return scale, bias return scale, bias
@torch.jit.script_method
def forward(self, x):
scale, bias = self.get_scale_bias(False)
return x * scale + bias
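# Note (added for illustration, not part of the diff): with frozen statistics,
# batch norm reduces to a per-channel affine transform, which is what
# get_scale_bias folds into `scale` and `bias`:
#
#     scale = weight / sqrt(running_var)   # eps is ignored here, matching the code above
#     bias  = bias - running_mean * scale
#     y     = x * scale + bias             # == (x - running_mean) / sqrt(running_var) * weight + bias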
@torch.jit.script
def drelu_dscale1(grad_o, output, scale1):
relu_mask = (output>0)
dx_relu = relu_mask * grad_o
g1 = dx_relu * scale1
return g1, dx_relu
@torch.jit.script
def drelu_dscale2(grad_o, output, scale1, scale2):
relu_mask = (output>0)
dx_relu = relu_mask * grad_o
g1 = dx_relu * scale1
g2 = dx_relu * scale2
...@@ -147,6 +169,7 @@ class Bottleneck(torch.nn.Module):
self.bn1 = norm_func(bottleneck_channels)
self.bn2 = norm_func(bottleneck_channels)
self.bn3 = norm_func(out_channels)
self.w_scale = None
self.use_cudnn = use_cudnn
...@@ -170,10 +193,33 @@ class Bottleneck(torch.nn.Module):
for p in self.parameters():
with torch.no_grad():
p.data = p.data.permute(0,2,3,1).contiguous()
return
# Returns single callable that recomputes scale and bias for all frozen batch-norms.
# This method must be called before cuda graphing.
# The callable it returns can be called anytime.
# Calling this method will prevent these from being computed every forward call.
def get_scale_bias_callable(self):
self.w_scale, self.w_bias, args = [], [], []
batch_norms = [self.bn1, self.bn2, self.bn3]
if self.downsample is not None:
batch_norms.append(self.downsample[1])
for bn in batch_norms:
s = torch.empty_like(bn.weight)
b = torch.empty_like(s)
args.append( (bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b) )
if self.explicit_nhwc:
self.w_scale.append( s.reshape(1, 1, 1, -1) )
self.w_bias.append( b.reshape(1, 1, 1, -1) )
else:
self.w_scale.append( s.reshape(1, -1, 1, 1) )
self.w_bias.append( b.reshape(1, -1, 1, 1) )
return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)
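# Usage sketch (added for illustration, not part of the diff; `bottleneck` and `x`
# are assumed to be a constructed Bottleneck and a matching input tensor):
#
#     recompute_scale_bias = bottleneck.get_scale_bias_callable()
#     recompute_scale_bias()   # fill the fixed w_scale/w_bias buffers before capture
#     graphed = torch.cuda.make_graphed_callables(bottleneck, (x,))
#     out = graphed(x)
#     # if the frozen BN statistics are ever reloaded, refresh the buffers:
#     recompute_scale_bias()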
def forward(self, x):
if self.use_cudnn:
if self.w_scale is None:
# calculate scale/bias from registered buffers
# TODO: make this better
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
...@@ -185,8 +231,9 @@ class Bottleneck(torch.nn.Module):
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_scale.append(s4)
w_bias.append(b4)
out = bottleneck_function(self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
else:
out = bottleneck_function(self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, x, *self.w_conv)
return out
if self.explicit_nhwc:
...@@ -217,7 +264,12 @@ class Bottleneck(torch.nn.Module):
class SpatialBottleneckFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel, explicit_nhwc, stride_1x1, scale, bias, thresholdTop, thresholdBottom, x, *conv):
if spatial_group_size > 1:
stream1 = spatial_halo_exchanger.stream1
stream2 = spatial_halo_exchanger.stream2
stream3 = spatial_halo_exchanger.stream3
# TODO: clean up order of tensors
args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
ctx.downsample = len(conv) > 3
...@@ -226,59 +278,152 @@ class SpatialBottleneckFunction(torch.autograd.Function):
args.append(scale[3])
args.append(bias[3])
# weight buffers are always in explicit_nhwc while shape can be explicit_nhwc or channels_last
# here we pass in flag and let c++ handle it
# alternatively, we can put all sizes into a fixed format and pass it in
outputs = fast_bottleneck.forward_init(explicit_nhwc, stride_1x1, args)
fast_bottleneck.forward_out1(explicit_nhwc, stride_1x1, args, outputs)
fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
# do halo exchange for outputs[0] (out1)
# compute halo cells for outputs[1]
if spatial_group_size > 1:
out1 = outputs[0]
if explicit_nhwc:
N,Hs,W,C = list(out1.shape)
memory_format = torch.contiguous_format
out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
else:
N,C,Hs,W = list(out1.shape)
memory_format = torch.channels_last if out1.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
out1_pad = torch.empty([N,C,Hs+2,W], dtype=out1.dtype, device='cuda', memory_format=memory_format)
stream1.wait_stream(torch.cuda.current_stream())
if spatial_method != 2: stream3.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream1):
if explicit_nhwc:
top_out1_halo = out1_pad[:,:1,:,:]
btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
else:
top_out1_halo = out1_pad[:,:,:1,:]
btm_out1_halo = out1_pad[:,:,Hs+1:Hs+2,:]
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
if spatial_method == 1:
# overlap mid convolution with halo transfer
if spatial_group_rank < spatial_group_size-1:
stream2.wait_stream(stream1)
with torch.cuda.stream(stream2):
if explicit_nhwc:
btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
else:
btm_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
if spatial_group_rank > 0:
with torch.cuda.stream(stream1):
if explicit_nhwc:
top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
else:
top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
if use_delay_kernel: inc.add_delay(10)
elif spatial_method != 2 and spatial_method != 3:
assert(False), "spatial_method must be 1, 2 or 3"
if spatial_group_size <= 1:
fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
elif spatial_method == 1:
fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
with torch.cuda.stream(stream3):
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
elif spatial_method == 2:
# wait for halo transfer to finish before doing a full convolution of padded x
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
torch.cuda.current_stream().wait_stream(stream1)
fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
elif spatial_method == 3:
fast_bottleneck.forward_out2_mask(explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom)
with torch.cuda.stream(stream3):
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
# compute halo cells for outputs[1] (out2)
if spatial_group_size > 1:
out2 = outputs[1]
if explicit_nhwc:
top_out2_halo = out2[:,:1,:,:]
btm_out2_halo = out2[:,Hs-1:,:,:]
else:
top_out2_halo = out2[:,:,:1,:]
btm_out2_halo = out2[:,:,Hs-1:,:]
if spatial_method == 1:
if spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(stream1)
top_out2_halo.copy_(top_out2)
if spatial_group_rank < spatial_group_size-1:
torch.cuda.current_stream().wait_stream(stream2)
btm_out2_halo.copy_(btm_out2)
elif spatial_method == 3:
# Note
# out2 halo correction cannot overlap with anything since it has
# to wait for out2_mask to finish, but itself has to finish before
# the first kernel of _forward_rest can launch.
# At least we can overlap the two halo correction kernels.
if spatial_group_rank < spatial_group_size-1:
stream2.wait_stream(stream1) # wait for halo transfers to finish
stream2.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
with torch.cuda.stream(stream2):
w1by3 = args[2][:,2:3,:,:].clone()
btm_out1_halo = btm_out1_halo.clone()
btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.clone())
btm_out2_halo.copy_(btm_out2)
if spatial_group_rank > 0:
stream1.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
with torch.cuda.stream(stream1):
w1by3 = args[2][:,:1,:,:].clone()
top_out1_halo = top_out1_halo.clone()
top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone())
top_out2_halo.copy_(top_out2)
if spatial_group_rank < spatial_group_size-1:
torch.cuda.current_stream().wait_stream(stream2)
if spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(stream1)
fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
# save halos for backward pass
if spatial_group_size > 1:
if spatial_method != 2:
# make sure copy of mid-section of out1 into out1_pad is done before exiting
torch.cuda.current_stream().wait_stream(stream3)
ctx.save_for_backward(*(args+outputs+[out1_pad,]))
else:
ctx.save_for_backward(*(args+outputs))
# save relu outputs for drelu
ctx.explicit_nhwc = explicit_nhwc
ctx.stride_1x1 = stride_1x1
ctx.spatial_group_size = spatial_group_size
if spatial_group_size > 1:
ctx.spatial_group_rank = spatial_group_rank
ctx.spatial_halo_exchanger = spatial_halo_exchanger
ctx.spatial_method = spatial_method
ctx.use_delay_kernel = use_delay_kernel
ctx.thresholdTop = thresholdTop
ctx.thresholdBottom = thresholdBottom
ctx.stream1 = stream1
ctx.stream2 = stream2
ctx.stream3 = stream3
return outputs[2]
# backward relu is not exposed, MUL with mask used now
...@@ -286,9 +431,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_o):
if ctx.spatial_group_size > 1:
out1_pad = ctx.saved_tensors[-1]
outputs = ctx.saved_tensors[-4:-1]
else:
outputs = ctx.saved_tensors[-3:]
...@@ -310,58 +454,79 @@ class SpatialBottleneckFunction(torch.autograd.Function):
if ctx.downsample:
t_list.append(ctx.saved_tensors[10])
grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list) grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list)
grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads) wgrad3_stream = torch.cuda.Stream()
wgrad3_stream.wait_stream(torch.cuda.current_stream())
# compute wgrad2 for internal cells grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2) wgrad2_stream = torch.cuda.Stream()
wgrad2_stream.wait_stream(torch.cuda.current_stream())
# apply wgrad2 halos
if ctx.spatial_group_size > 1:
if ctx.local_rank > 0:
top_grad2_halo = grad_out2[:,:1,:,:]
top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
if ctx.local_rank < ctx.spatial_group_size-1:
btm_grad2_halo = grad_out2[:,-1:,:,:]
btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
# do halo exchange of grad_out2 here
# compute halo cells for grad_out1
if ctx.spatial_group_size > 1:
if ctx.explicit_nhwc:
N,Hs,W,C = list(grad_out2.shape)
else:
N,C,Hs,W = list(grad_out2.shape)
relu1 = t_list[12]
ctx.stream1.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(ctx.stream1):
top_halo, btm_halo = ctx.spatial_halo_exchanger.left_right_halo_exchange(grad_out2[:,:1,:,:], grad_out2[:,Hs-1:,:,:])
# copy halos to send buffer # copy halos to send buffer
send_halos = torch.empty((N,2,W,C),dtype=grad_out2.dtype,device=grad_out2.device) if ctx.spatial_method == 1 or ctx.spatial_method == 2:
send_halos[:,:1,:,:].copy_(grad_out2[:,:1,:,:]) # 1 -> halo recompute approach
send_halos[:,1:,:,:].copy_(grad_out2[:,Hs-1:,:,:]) # 2 -> wait for concatenated halos, then do single conv on full input (not implemented yet for bprop)
all_halos = torch.empty((N,2*ctx.spatial_group_size,W,C),dtype=grad_out2.dtype,device=grad_out2.device) if ctx.spatial_group_rank < ctx.spatial_group_size-1:
all_halos = [all_halos[:,i*2:(i+1)*2,:,:] for i in range(ctx.spatial_group_size)] ctx.stream2.wait_stream(ctx.stream1)
dist.all_gather(all_halos,send_halos,group=ctx.comm) with torch.cuda.stream(ctx.stream2):
relu1 = t_list[12] if ctx.explicit_nhwc:
fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device) btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device) btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
if ctx.local_rank > 0: btm_fat_halo[:,2:,:,:].copy_(btm_halo)
top_halo = all_halos[ctx.local_rank-1][:,1:,:,:] btm_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
fat_halo[:,:1,:,:].copy_(top_halo) btm_fat_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:]) btm_fat_relu_halo[:,2:,:,:].zero_()
relu_halo[:,:1,:,:].zero_() else:
relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:]) btm_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, fat_halo, relu_halo) btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,:,Hs-2:,:])
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:] btm_fat_halo[:,:,2:,:].copy_(btm_halo)
if ctx.local_rank < ctx.spatial_group_size-1: btm_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
btm_halo = all_halos[ctx.local_rank+1][:,:1,:,:] btm_fat_relu_halo[:,:,:2,:].copy_(relu1[:,:,Hs-2:,:])
fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:]) btm_fat_relu_halo[:,:,2:,:].zero_()
fat_halo[:,2:,:,:].copy_(btm_halo) btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_fat_relu_halo)
relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:]) if ctx.explicit_nhwc:
relu_halo[:,2:,:,:].zero_()
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, fat_halo, relu_halo)
btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
else:
btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
if ctx.spatial_group_rank > 0:
with torch.cuda.stream(ctx.stream1):
if ctx.explicit_nhwc:
top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_halo[:,:1,:,:].copy_(top_halo)
top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
top_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_relu_halo[:,:1,:,:].zero_()
top_fat_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
else:
top_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_halo[:,:,:1,:].copy_(top_halo)
top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:,:2,:])
top_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_relu_halo[:,:,:1,:].zero_()
top_fat_relu_halo[:,:,1:,:].copy_(relu1[:,:,:2,:])
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_fat_relu_halo)
if ctx.explicit_nhwc:
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
else:
top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
if ctx.use_delay_kernel: inc.add_delay(10)
elif ctx.spatial_method != 3:
assert(False), "spatial_method must be 1, 2 or 3"
# compute grad_out1 for internal cells
if ctx.spatial_group_size <= 1 or ctx.spatial_method == 1 or ctx.spatial_method == 2:
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
elif ctx.spatial_group_size > 1 and ctx.spatial_method == 3:
grad_out1 = fast_bottleneck.backward_grad_out1_mask(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, ctx.thresholdTop, ctx.thresholdBottom)
# apply halo cells to grad_out1
if ctx.spatial_group_size > 1:
...@@ -369,17 +534,69 @@ class SpatialBottleneckFunction(torch.autograd.Function):
z = t_list[4]
relu1 = t_list[12]
#print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
if ctx.spatial_method == 1 or ctx.spatial_method == 2:
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
torch.cuda.current_stream().wait_stream(ctx.stream2)
if ctx.explicit_nhwc:
grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
else:
grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
#print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
if ctx.spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(ctx.stream1)
if ctx.explicit_nhwc:
grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
else:
grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
#print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
elif ctx.spatial_method == 3:
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
if ctx.explicit_nhwc:
btm_relu_halo = relu1[:,Hs-1:,:,:].clone()
btm_grad_out1 = grad_out1[:,Hs-1:,:,:]
else:
btm_relu_halo = relu1[:,:,Hs-1:,:].clone()
btm_grad_out1 = grad_out1[:,:,Hs-1:,:]
w1by3 = w[:,:1,:,:].clone()
ctx.stream2.wait_stream(ctx.stream1) # wait for halo transfers to finish
ctx.stream2.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
with torch.cuda.stream(ctx.stream2):
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, btm_halo, btm_relu_halo, btm_grad_out1.clone())
btm_grad_out1.copy_(btm_grad_out1_halo)
if ctx.spatial_group_rank > 0:
if ctx.explicit_nhwc:
top_relu_halo = relu1[:,:1,:,:].clone()
top_grad_out1 = grad_out1[:,:1,:,:]
else:
top_relu_halo = relu1[:,:,:1,:].clone()
top_grad_out1 = grad_out1[:,:,:1,:]
w1by3 = w[:,2:,:,:].clone()
ctx.stream1.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
with torch.cuda.stream(ctx.stream1):
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, top_halo, top_relu_halo, top_grad_out1.clone())
top_grad_out1.copy_(top_grad_out1_halo)
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
torch.cuda.current_stream().wait_stream(ctx.stream2) # wait for halo correction to finish
if ctx.spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(ctx.stream1)
wgrad1_stream = torch.cuda.Stream()
wgrad1_stream.wait_stream(torch.cuda.current_stream())
fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1)
with torch.cuda.stream(wgrad3_stream):
fast_bottleneck.backward_wgrad3(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
with torch.cuda.stream(wgrad2_stream):
if ctx.spatial_group_size > 1:
fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
else:
fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
with torch.cuda.stream(wgrad1_stream):
fast_bottleneck.backward_wgrad1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out1)
torch.cuda.current_stream().wait_stream(wgrad3_stream)
torch.cuda.current_stream().wait_stream(wgrad2_stream)
torch.cuda.current_stream().wait_stream(wgrad1_stream)
return (None, None, None, None, None, None, None, None, None, None, None, None, *grads)
spatial_bottleneck_function = SpatialBottleneckFunction.apply
...@@ -393,7 +610,7 @@ class SpatialBottleneck(torch.nn.Module):
def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False,
spatial_parallel_args=None):
super(SpatialBottleneck, self).__init__()
if groups != 1:
raise RuntimeError('Only support groups == 1')
...@@ -422,6 +639,7 @@ class SpatialBottleneck(torch.nn.Module):
self.bn1 = norm_func(bottleneck_channels)
self.bn2 = norm_func(bottleneck_channels)
self.bn3 = norm_func(out_channels)
self.w_scale = None
self.use_cudnn = use_cudnn
...@@ -434,6 +652,8 @@ class SpatialBottleneck(torch.nn.Module):
for w in self.w_conv:
kaiming_uniform_(w, a=1)
self.thresholdTop, self.thresholdBottom = None, None
# TODO: prevent unsupported case usage
# support cases
# native cudnn
...@@ -447,30 +667,45 @@ class SpatialBottleneck(torch.nn.Module):
p.data = p.data.permute(0,2,3,1).contiguous()
# spatial communicator
if spatial_parallel_args is None:
self.spatial_parallel_args = (1, 0, None, None, 0, False)
else:
self.spatial_parallel_args = spatial_parallel_args
assert(num_groups*spatial_group_size == world_size), "torch.distributed.get_world_size() must be multiple of group_size"
rank = dist.get_rank()
self.local_rank = rank % spatial_group_size
if communicator is None:
for group in range(num_groups):
ranks = list(range(group*spatial_group_size,(group+1)*spatial_group_size))
comm = torch.distributed.new_group(ranks=ranks)
if rank in ranks:
self.communicator = comm
else:
self.communicator = communicator
self.stream1 = torch.cuda.Stream()
self.spatial_args = self.spatial_group_size, self.local_rank, self.communicator, self.stream1
else:
self.spatial_args = 1, 0, None, None
return
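# Illustration (added, not part of the diff): spatial_parallel_args is the tuple
# (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger,
#  spatial_method, use_delay_kernel); the test at the bottom of this diff builds it like:
#
#     halex = HaloExchangerPeer(ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1)
#     spatial_parallel_args = (world_size, rank, None, halex, 1, False)  # method 1: overlap halo with mid conv
#     bottleneck = SpatialBottleneck(C, C, C, use_cudnn=True, explicit_nhwc=True,
#                                    spatial_parallel_args=spatial_parallel_args)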
# Returns single callable that recomputes scale and bias for all frozen batch-norms.
# This method must be called before cuda graphing.
# The callable it returns can be called anytime.
# Calling this method will prevent these from being computed every forward call.
def get_scale_bias_callable(self):
self.w_scale, self.w_bias, args = [], [], []
batch_norms = [self.bn1, self.bn2, self.bn3]
if self.downsample is not None:
batch_norms.append(self.downsample[1])
for bn in batch_norms:
s = torch.empty_like(bn.weight)
b = torch.empty_like(s)
args.append( (bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b) )
if self.explicit_nhwc:
self.w_scale.append( s.reshape(1, 1, 1, -1) )
self.w_bias.append( b.reshape(1, 1, 1, -1) )
else:
self.w_scale.append( s.reshape(1, -1, 1, 1) )
self.w_bias.append( b.reshape(1, -1, 1, 1) )
return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)
def forward(self, x):
if self.use_cudnn:
if self.thresholdTop is None:
spatial_group_size, spatial_group_rank, _, _, _, _ = self.spatial_parallel_args
if self.explicit_nhwc:
N,H,W,C = list(x.shape)
else:
N,C,H,W = list(x.shape)
self.thresholdTop = torch.tensor([1 if spatial_group_rank > 0 else 0], dtype=torch.int32, device='cuda')
self.thresholdBottom = torch.tensor([H-2 if spatial_group_rank < spatial_group_size - 1 else H-1], dtype=torch.int32, device='cuda')
if self.w_scale is None:
# calculate scale/bias from registered buffers
# TODO: make this better
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
...@@ -482,8 +717,9 @@ class SpatialBottleneck(torch.nn.Module):
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_scale.append(s4)
w_bias.append(b4)
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
else:
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
return out
if self.explicit_nhwc:
...@@ -510,3 +746,4 @@ class SpatialBottleneck(torch.nn.Module):
out = self.relu(out)
return out
import os
import torch
from apex.contrib.bottleneck import Bottleneck, SpatialBottleneck
from apex.contrib.bottleneck import HaloExchangerNoComm, HaloExchangerAllGather, HaloExchangerSendRecv, HaloExchangerPeer
from apex.contrib.peer_memory import PeerMemoryPool
from apex.contrib.bottleneck import Bottleneck as FastBottleneck
from apex.contrib.bottleneck import SpatialBottleneck
def ground_truth_bottleneck(C, dtype, explicit_nhwc):
bottleneck = Bottleneck(C,C,C,use_cudnn=True,explicit_nhwc=explicit_nhwc)
bottleneck.to(dtype=dtype, device='cuda')
for p in bottleneck.parameters():
torch.distributed.broadcast(p, 0)
for b in bottleneck.buffers():
torch.distributed.broadcast(b, 0)
return bottleneck
def print_bottleneck_p_and_b(bottleneck):
with torch.no_grad():
for n,p in bottleneck.named_parameters():
print("%s :: %s" % (n, str(p.norm(p=2,dtype=torch.float32))))
for n,p in bottleneck.named_buffers():
print("%s :: %s" % (n, str(p.norm(p=2,dtype=torch.float32))))
x.requires_grad = True
print(x.shape, x.stride())
def has_nan(x):
if isinstance(x, list) or isinstance(x, tuple):
for xx in x:
if torch.any(torch.isnan(xx)):
return True
return False
elif isinstance(x, dict):
for k,v in x.items():
if torch.any(torch.isnan(v)):
return True
bottleneck_channels=bottleneck_channels,
out_channels=out_channels,
stride=stride,
dilation=dilation,
explicit_nhwc=nhwc,
use_cudnn=True)
else:
return torch.any(torch.isnan(x))
in_channels=in_channels,
bottleneck_channels=bottleneck_channels,
def rel_diff_t(xx1, xx2):
return ((xx1 - xx2).norm(p=2,dtype=torch.float32) / (xx1 + xx2).norm(p=2,dtype=torch.float32)).item()
dilation=dilation,
explicit_nhwc=nhwc,
def rel_diff(x1, x2):
if isinstance(x1, list) or isinstance(x1, tuple):
return [rel_diff_t(xx1,xx2) for xx1,xx2 in zip(x1,x2)]
elif isinstance(x1, dict):
return [rel_diff_t(xx1, xx2) for (k1,xx1), (k2,xx2) in zip(x1.items(),x2.items())]
else:
return rel_diff_t(x1,x2)
in_channels,
bottleneck_channels,
def graph_it(bottleneck, x):
print("Graphing")
stride_in_1x1,
stride,
dilation,
norm_func,
nhwc,
spatial_group_size)
bottleneck = bottleneck.to(dtype=numtype,device=device)
weights = dict(bottleneck.named_parameters())
if ref is not None:
ref_x, _, ref_weights = ref
Hs,H = x.shape[1], ref_x.shape[1]
assert(Hs*spatial_group_size == H), "Hs not a multiple of H"
ref_x = ref_x[:,rank*Hs:(rank+1)*Hs,:,:]
x.copy_(ref_x)
assert(len(weights) == len(ref_weights)), "Reference weights and weights don't match"
for k in weights.keys():
weights[k].copy_(ref_weights[k])
# forward
out = bottleneck(x)
# gradient output
with torch.no_grad():
x = x.clone()
x.grad = None
x.requires_grad = True
return torch.cuda.make_graphed_callables(bottleneck, (x,))
assert(Hs*spatial_group_size == H), "Hs not a multiple of H"
ref_grad_out = ref_grad_out[:,rank*Hs:(rank+1)*Hs,:,:]
grad_out.copy_(ref_grad_out)
# backward
out.backward(grad_out)
def clone_inputs(bottleneck, x, dy=None):
with torch.no_grad():
x = x.clone()
x.grad = None
x.requires_grad = True
if dy is None:
y = bottleneck(x)
dy = torch.randn_like(y) / 1e2
torch.distributed.broadcast(dy, 0)
return x, dy
def fprop_and_bprop(bottleneck, x, dy):
y = bottleneck(x)
y.backward(dy)
dgrad = x.grad.detach()
wgrad = {}
for n,p in bottleneck.named_parameters():
wgrad[n] = p.grad.detach()
return x, y, dy, dgrad, wgrad
def ground_truth(N, C, H, W, dtype, memory_format, bottleneck):
if memory_format == 1:
# 1 -> explicit nhwc
explicit_nhwc = True
with torch.no_grad():
x = torch.randn([N,H,W,C], dtype=dtype, device='cuda')
torch.distributed.broadcast(x, 0)
x, dy = clone_inputs(bottleneck, x)
return fprop_and_bprop(bottleneck, x, dy)
else:
# 2 -> native nhwc
# 3 -> nchw
explicit_nhwc = False
assert(False), "Not implemented yet"
dgrad_tensors = [dgrad_gathered[:,i*Hs:(i+1)*Hs,:,:] for i in range(spatial_group_size)]
torch.distributed.all_gather(dgrad_tensors, dgrad)
dgrad = dgrad_gathered
N,Hs,W,C = list(out.shape)
H = Hs * spatial_group_size
out_gathered = torch.empty((N,H,W,C),dtype=dgrad.dtype,device=dgrad.device)
out_tensors= [out_gathered[:,i*Hs:(i+1)*Hs,:,:] for i in range(spatial_group_size)]
torch.distributed.all_gather(out_tensors, out)
out = out_gathered
for k in wgrad.keys():
w = wgrad[k].to(dtype=torch.float64)
torch.distributed.all_reduce(w)
wgrad[k].copy_(w.to(dtype=wgrad[k].dtype))
#torch.distributed.all_reduce(wgrad[k])
return x, out, grad_out, weights, dgrad, wgrad
def module_tests(rank, world_size, numtype, device, fast, spatial_group_sizes, init_args):
r = []
for ia in init_args:
shape = ia[0:4]
args = ia[4:]
rr = []
ref = None
for spatial_group_size in spatial_group_sizes:
N,H,W,C = shape
H = H//spatial_group_size
x, out, grad_out, weights, dgrad, wgrad = single_module_test(ref, rank, world_size, numtype, device, [H,W], fast, spatial_group_size, *args)
if ref is None:
assert(spatial_group_size == 1), "Wrong reference weights"
ref = x, grad_out, weights
if rank == 0:
rr.append( (out, dgrad, wgrad) )
if world_size > 1: torch.distributed.barrier()
r.append(rr)
return r
def print_ground_truth(gt):
x, y, dy, dgrad, wgrad = gt
if has_nan(y) or has_nan(dgrad) or has_nan(wgrad):
print("Error! Ground truth has NAN")
else:
print("Ok! No NAN found in ground truth")
if distributed:
torch.distributed.init_process_group("nccl") def apply_to_different_bottleneck(gt, bottleneck):
rank, world_size = torch.distributed.get_rank(), torch.distributed.get_world_size() with torch.no_grad():
is_master = True if rank == 0 else False x, _, dy, _, _ = gt
local_rank = rank % ngpus x, dy = clone_inputs(bottleneck, x, dy)
torch.cuda.set_device(local_rank) return fprop_and_bprop(bottleneck, x, dy)
spatial_group_size = total_num_gpus
def compare_single_field(results, f1, f2, l0, l1, l2):
if has_nan(f1) and has_nan(f2):
results[l0] = "both NAN"
elif has_nan(f1):
results[l0] = "%s.%s NAN" % (l1, l0)
elif has_nan(f2):
results[l0] = "%s.%s NAN" % (l2, l0)
else:
results[l0] = "%s" % (str(rel_diff(f1,f2)))
def compare(gt, bt):
x1, y1, dy1, dgrad1, wgrad1 = gt
x2, y2, dy2, dgrad2, wgrad2 = bt
results = {}
compare_single_field(results, y1, y2, "y", "gt", "bt")
compare_single_field(results, dy1, dy2, "dy", "gt", "bt")
compare_single_field(results, dgrad1, dgrad2, "dgrad", "gt", "bt")
compare_single_field(results, wgrad1, wgrad2, "wgrad", "gt", "bt")
for i in range(torch.distributed.get_world_size()):
if i == torch.distributed.get_rank():
print(i,results)
torch.distributed.barrier()
def spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args):
spatial_bottleneck = SpatialBottleneck(C,C,C,use_cudnn=True,explicit_nhwc=explicit_nhwc,spatial_parallel_args=spatial_parallel_args)
spatial_bottleneck.to(dtype=dtype, device='cuda')
with torch.no_grad():
sp = {}
for n,p in spatial_bottleneck.named_parameters():
sp[n] = p
for n,p in gt_bottleneck.named_parameters():
sp[n].copy_(p)
sb = {}
for n,b in spatial_bottleneck.named_buffers():
sb[n] = b
for n,b in gt_bottleneck.named_buffers():
sb[n].copy_(b)
return spatial_bottleneck
def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=False):
assert(explicit_nhwc), "Only tested for explicit nhwc"
x, _, dy, _, _ = gt
N, H, W, C = list(x.shape) # Tensor is already shaped properly for n-way parallel
dtype = x.dtype
spatial_group_size = world_size
spatial_group_rank = rank
spatial_communicator = None
spatial_halo_exchanger = halex
spatial_method = 1 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x
use_delay_kernel = False
spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel)
spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args)
with torch.no_grad():
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
norm_func = FrozenBatchNorm2d_NHWC
init_args = [
(1, 200, 336, 64, 64, 64, 256, 1, True, 1, 1, norm_func, True),
(1, 200, 336, 256, 256, 64, 256, 1, True, 1, 1, norm_func, True),
(1, 200, 336, 256, 256, 128, 512, 1, True, 2, 1, norm_func, True),
(1, 100, 168, 512, 512, 128, 512, 1, True, 1, 1, norm_func, True),
(1, 100, 168, 512, 512, 256, 1024, 1, True, 2, 1, norm_func, True),
(1, 50, 84, 1024, 1024, 256, 1024, 1, True, 1, 1, norm_func, True),
(1, 50, 84, 1024, 1024, 512, 2048, 1, True, 2, 1, norm_func, True),
(1, 25, 42, 2048, 2048, 512, 2048, 1, True, 1, 1, norm_func, True),
(1, 336, 200, 64, 64, 64, 256, 1, True, 1, 1, norm_func, True),
(1, 336, 200, 256, 256, 64, 256, 1, True, 1, 1, norm_func, True),
(1, 336, 200, 256, 256, 128, 512, 1, True, 2, 1, norm_func, True),
(1, 168, 100, 512, 512, 128, 512, 1, True, 1, 1, norm_func, True),
(1, 168, 100, 512, 512, 256, 1024, 1, True, 2, 1, norm_func, True),
(1, 84, 50, 1024, 1024, 256, 1024, 1, True, 1, 1, norm_func, True),
(1, 84, 50, 1024, 1024, 512, 2048, 1, True, 2, 1, norm_func, True),
(1, 42, 25, 2048, 2048, 512, 2048, 1, True, 1, 1, norm_func, True),
]
init_args = init_args[0:1]
# pad H to account for spatial distribution
padded_init_args = []
for ia in init_args:
N,H,W,C = ia[0:4]
m = spatial_group_size * H // (25 if H < W else 42)
H = ((H + m - 1) // m) * m
args = tuple( [N,H,W,C] + list(ia[4:]) )
padded_init_args.append(args)
init_args = padded_init_args
if rank == 0:
for ia in init_args:
print(ia)
spatial_group_sizes = [1]
if spatial_group_size > 1:
spatial_group_sizes.append(spatial_group_size)
numtype, device, fast = torch.float16, 'cuda', True
r = module_tests(rank, world_size, numtype, device, fast, spatial_group_sizes, init_args)
if world_size > 1: torch.distributed.barrier()
if rank == 0:
for rr in r:
print("***")
for out, dgrad, wgrad in rr:
gr = [("out",out.norm(p=2,dtype=torch.float64).item())]
gr = gr + [("dgrad",dgrad.norm(p=2,dtype=torch.float64).item())]
gr = gr + [(k+".wgrad",wgrad[k].norm(p=2,dtype=torch.float64).item()) for k in wgrad.keys()]
print(gr)
if len(rr) == 2:
out1, dgrad1, wgrad1 = rr[0]
out2, dgrad2, wgrad2 = rr[1]
rtol = 1e-1
out_atol = out1.abs().max().item() * rtol
dgrad_atol = dgrad1.abs().max().item() * rtol
wgrad_atol = {}
for k in wgrad1.keys():
wgrad_atol[k] = wgrad1[k].abs().max().item() * rtol
gr = [("out",torch.allclose(out1,out2,rtol,out_atol,equal_nan=True))]
gr = gr + [("dgrad",torch.allclose(dgrad1,dgrad2,rtol,dgrad_atol,equal_nan=True))]
gr = gr + [(k+".wgrad",torch.allclose(wgrad1[k],wgrad2[k],rtol,wgrad_atol[k],equal_nan=True)) for k in wgrad1.keys()]
print(gr)
gr = [("out",(out1-out2).norm(p=2,dtype=torch.float64).item())]
gr = gr + [("dgrad",(dgrad1-dgrad2).norm(p=2,dtype=torch.float64).item())]
gr = gr + [(k+".wgrad",(wgrad1[k]-wgrad2[k]).norm(p=2,dtype=torch.float64).item()) for k in wgrad1.keys()]
print(gr)
N,H,W,C = out1.shape
Hs = H // spatial_group_size
Ht = Hs-2
print("out1@%d:%d=%s" % (Ht,H,str(out1[0,Ht,:8,:5])))
print("out2@%d:%d=%s" % (Ht,H,str(out2[0,Ht,:8,:5])))
Ht = Hs-1
print("out1@%d:%d=%s" % (Ht,H,str(out1[0,Ht,:8,:5])))
print("out2@%d:%d=%s" % (Ht,H,str(out2[0,Ht,:8,:5])))
Ht = Hs
print("out1@%d:%d=%s" % (Ht,H,str(out1[0,Ht,:8,:5])))
print("out2@%d:%d=%s" % (Ht,H,str(out2[0,Ht,:8,:5])))
Ht = Hs+1
print("out1@%d:%d=%s" % (Ht,H,str(out1[0,Ht,:8,:5])))
print("out2@%d:%d=%s" % (Ht,H,str(out2[0,Ht,:8,:5])))
N,H,W,C = dgrad1.shape
Hs = H // spatial_group_size
xs = x[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:].clone()
dys = dy[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:].clone()
xs.requires_grad = True
spatial_bottleneck = graph_it(spatial_bottleneck, xs)
_, y, _, dgrad, wgrad = fprop_and_bprop(spatial_bottleneck, xs, dys)
# gather output pieces
for n,p in wgrad.items():
if fp32_reduce:
p32 = p.float()
torch.distributed.all_reduce(p32)
p.copy_(p32.half())
else:
if world_size > 1: torch.distributed.barrier() torch.distributed.all_reduce(p)
ys = [torch.empty_like(y) for _ in range(spatial_group_size)]
torch.distributed.all_gather(ys,y)
y = torch.cat(ys,dim=1)
dgrads = [torch.empty_like(dgrad) for _ in range(spatial_group_size)]
torch.distributed.all_gather(dgrads,dgrad)
dgrad = torch.cat(dgrads,dim=1)
return x, y, dy, dgrad, wgrad
def main():
torch.use_deterministic_algorithms(True)
torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank)
explicit_nhwc = True
dtype = torch.float16
N, C, H, W = 1, 64, 200, 336
Hs = ((H+8*world_size-1) // (8*world_size)) * 8
H = Hs*world_size
gt_bottleneck = ground_truth_bottleneck(C, dtype, explicit_nhwc)
gt = ground_truth(N, C, H, W, dtype, 1, gt_bottleneck)
# verify that spatial bottleneck with group_size 1 produces same results as ground truth bottleneck
spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, None)
bt = apply_to_different_bottleneck(gt, spatial_bottleneck)
compare(gt, bt)
#print_bottleneck_p_and_b(gt_bottleneck)
#print_bottleneck_p_and_b(spatial_bottleneck)
group_size = world_size
group = rank // group_size
ranks = [group*group_size+i for i in range(group_size)]
rank_in_group = rank % group_size
spatial_group_size = world_size
spatial_communicator = None
peer_pool = PeerMemoryPool(64*1024*1024, 2*1024*1024, ranks)
#class HaloExchangerNoComm(HaloExchanger):
# def __init__(self, ranks, rank_in_group):
#class HaloExchangerAllGather(HaloExchanger):
# def __init__(self, ranks, rank_in_group, comm):
#class HaloExchangerSendRecv(HaloExchanger):
# def __init__(self, ranks, rank_in_group):
#class HaloExchangerPeer(HaloExchanger):
# def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1):
#halex = HaloExchangerAllGather(ranks, rank_in_group)
#halex = HaloExchangerSendRecv(ranks, rank_in_group)
halex = HaloExchangerPeer(ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1)
#print("halex.signals = %s" % (str(halex.signals)))
# Make sure peer memory halo exchanger has finished initializing flags on all ranks before proceeding
#torch.cuda.synchronize()
#torch.distributed.barrier()
bt2 = n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=True)
compare(gt, bt2)
if __name__ == "__main__":
......
import torch
import torch.distributed as dist
from torch import nn
import nccl_p2p_cuda as inc
import peer_memory_cuda as pm
# Communication-free halo exchanger.
# NB! This halo exchanger does not exchange halos with neighbors as it should; it merely swaps its own inputs.
# NB! This is only useful for performance testing.
# NB! Do not use for actual production runs.
class HaloExchanger(object):
def __init__(self, ranks, rank_in_group):
self.stream1 = torch.cuda.Stream()
self.stream2 = torch.cuda.Stream()
self.stream3 = torch.cuda.Stream()
self.group_size = len(ranks)
self.ranks = ranks
self.rank_in_group = rank_in_group
self.wrap_around_left_rank_in_group = (rank_in_group + self.group_size - 1) % self.group_size
self.wrap_around_right_rank_in_group = (rank_in_group + 1) % self.group_size
self.left_rank = ranks[rank_in_group-1] if rank_in_group > 0 else -1
self.left_zero = True if rank_in_group == 0 else False
self.right_rank = ranks[rank_in_group+1] if rank_in_group < self.group_size - 1 else -1
self.right_zero = True if rank_in_group == self.group_size - 1 else False
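# --- illustration only; not part of the original module ------------------------------
# A minimal sketch of the neighbor bookkeeping HaloExchanger.__init__ sets up, for a
# hypothetical 1d split across ranks=[4,5,6,7]. left_rank/right_rank are global ranks
# (-1 together with the left_zero/right_zero flags marks an open end whose halo is
# zero-filled), while the wrap_around_* values are indices within the group, used only
# to pick per-rank buffers and signals.
def _sketch_neighbors(ranks, rank_in_group):
    group_size = len(ranks)
    return {
        "left_rank": ranks[rank_in_group - 1] if rank_in_group > 0 else -1,
        "right_rank": ranks[rank_in_group + 1] if rank_in_group < group_size - 1 else -1,
        "wrap_left": (rank_in_group + group_size - 1) % group_size,
        "wrap_right": (rank_in_group + 1) % group_size,
    }
# _sketch_neighbors([4, 5, 6, 7], 0) -> {'left_rank': -1, 'right_rank': 5, 'wrap_left': 3, 'wrap_right': 1}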
class HaloExchangerNoComm(HaloExchanger):
def __init__(self, ranks, rank_in_group):
super(HaloExchangerNoComm, self).__init__(ranks, rank_in_group)
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
if left_input_halo is None:
return right_output_halo, left_output_halo
else:
left_input_halo.copy_(right_output_halo)
right_input_halo.copy_(left_output_halo)
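# --- illustration only; not part of the original module ------------------------------
# A hedged sketch of the call convention shared by the exchangers in this file: the
# out-of-place form returns (left_input_halo, right_input_halo); the in-place form
# fills the two preallocated input-halo tensors instead. With HaloExchangerNoComm the
# "received" halos are just this rank's own output halos swapped, which is why it is
# only suitable for timing runs.
def _sketch_exchange(halex, left_output_halo, right_output_halo):
    left_input_halo, right_input_halo = halex.left_right_halo_exchange(
        left_output_halo, right_output_halo)
    return left_input_halo, right_input_halo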
class HaloExchangerAllGather(HaloExchanger):
def __init__(self, ranks, rank_in_group, comm):
super(HaloExchangerAllGather, self).__init__(ranks, rank_in_group)
# self.comm must be NCCL process_group created with torch.distributed.new_group(ranks=ranks)
self.comm = comm
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
N,Hh,W,C = list(left_output_halo.shape)
send_halos = torch.empty((N,2*Hh,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
send_halos[:,:Hh,:,:].copy_(left_output_halo)
send_halos[:,Hh:,:,:].copy_(right_output_halo)
all_halos = torch.empty((N,2*Hh*self.group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.group_size)]
torch.distributed.all_gather(all_halos,send_halos,group=self.comm,no_copy=True)
ag_left_input_halo = all_halos[self.wrap_around_left_rank_in_group][:,Hh:,:,:]
ag_right_input_halo = all_halos[self.wrap_around_right_rank_in_group][:,:Hh,:,:]
if left_input_halo is None:
if self.left_zero:
ag_left_input_halo.zero_()
if self.right_zero:
ag_right_input_halo.zero_()
return ag_left_input_halo, ag_right_input_halo
else:
if self.left_zero:
left_input_halo.zero_()
else:
left_input_halo.copy_(ag_left_input_halo)
if self.right_zero:
right_input_halo.zero_()
else:
right_input_halo.copy_(ag_right_input_halo)
class HaloExchangerSendRecv(HaloExchanger):
def __init__(self, ranks, rank_in_group):
super(HaloExchangerSendRecv, self).__init__(ranks, rank_in_group)
nccl_id = inc.get_unique_nccl_id(1).cuda()
torch.distributed.broadcast(nccl_id, 0)
nccl_id = nccl_id.cpu()
print("%d :: nccl_id = %s" % (torch.distributed.get_rank(), str(nccl_id)))
# Create another global nccl communicator in addition to the one created by torch.distributed.init_process_group("nccl")
# This is unavoidable because the underlying NCCL communicator torch.distributed creates is a protected variable, hence
# it cannot be accessed from another class.
# TODO: Figure out a way to avoid creating a second global communicator
assert(torch.distributed.get_rank() == self.ranks[self.rank_in_group]), "ranks[%d](%d) != torch.distributed.get_rank()(%d)" % (self.rank_in_group, self.ranks[self.rank_in_group], torch.distributed.get_rank())
self.handle = inc.init_nccl_comm(nccl_id, torch.distributed.get_rank(), torch.distributed.get_world_size())
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
if left_input_halo is None:
left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_rank, self.right_rank , left_output_halo, right_output_halo)
return left_input_halo, right_input_halo
else:
inc.left_right_halo_exchange_inplace(self.handle, self.left_rank, self.right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo)
class HaloExchangerPeer(HaloExchanger):
def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1):
super(HaloExchangerPeer, self).__init__(ranks, rank_in_group)
self.diagnostics = False
self.explicit_nhwc = explicit_nhwc
self.numSM = numSM
self.peer_pool = peer_pool
self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
self.signals[self.rank_in_group].zero_()
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
inplace = False if left_input_halo is None and right_input_halo is None else True
if not inplace:
left_input_halo = torch.empty_like(right_output_halo)
right_input_halo = torch.empty_like(left_output_halo)
channels_last = left_output_halo.is_contiguous(memory_format=torch.channels_last) and not self.explicit_nhwc
left_tx = self.peer_pool.allocate_peer_tensors(list(left_output_halo.shape), left_output_halo.dtype, channels_last, True)
right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
pm.push_pull_halos_1d(
self.diagnostics, self.explicit_nhwc, self.numSM,
left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
self.signals[self.wrap_around_left_rank_in_group], self.signals[self.wrap_around_right_rank_in_group], self.signals[self.rank_in_group]
)
# TODO: Add to push_pull_halos_1d kernel
if self.left_zero:
left_input_halo.zero_()
if self.right_zero:
right_input_halo.zero_()
if not inplace:
return left_input_halo, right_input_halo
# Class that combines input volume with halos from neighbors (1d).
class HaloPadder:
def __init__(self, halo_ex):
self.halo_ex = halo_ex
self.stream1 = torch.cuda.Stream()
self.stream2 = torch.cuda.Stream()
def __call__(self, y, half_halo, explicit_nhwc, H_split):
channels_last = not explicit_nhwc and y.is_contiguous(memory_format=torch.channels_last)
if explicit_nhwc:
N,H,W,C = list(y.shape)
if H_split:
padded_shape = [N,H+2*half_halo,W,C]
ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
yleft = ypad[:,:half_halo,:,:]
ymid = ypad[:,half_halo:H+half_halo,:,:]
yright = ypad[:,H+half_halo:H+2*half_halo,:,:]
oleft = y[:,:half_halo,:,:]
oright = y[:,H-half_halo:,:,:]
else:
padded_shape = [N,H,W+2*half_halo,C]
ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
yleft = ypad[:,:,:half_halo,:]
ymid = ypad[:,:,half_halo:W+half_halo,:]
yright = ypad[:,:,W+half_halo:W+2*half_halo,:]
oleft = y[:,:,:half_halo,:]
oright = y[:,:,W-half_halo:,:]
else:
N,C,H,W = list(y.shape)
if H_split:
padded_shape = [N,C,H+2*half_halo,W]
ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
yleft = ypad[:,:,:half_halo,:]
ymid = ypad[:,:,half_halo:H+half_halo,:]
yright = ypad[:,:,H+half_halo:H+2*half_halo,:]
oleft = y[:,:,:half_halo,:]
oright = y[:,:,H-half_halo:,:]
else:
padded_shape = [N,C,H,W+2*half_halo]
ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
yleft = ypad[:,:,:,:half_halo]
ymid = ypad[:,:,:,half_halo:W+half_halo]
yright = ypad[:,:,:,W+half_halo:W+2*half_halo]
oleft = y[:,:,:,:half_halo]
oright = y[:,:,:,W-half_halo:]
with torch.cuda.stream(self.stream1):
self.halo_ex(oleft, oright, yleft, yright)
with torch.cuda.stream(self.stream2):
ymid.copy_(y)
return ypad
def wait(self):
current_stream = torch.cuda.current_stream()
current_stream.wait_stream(self.stream1)
current_stream.wait_stream(self.stream2)
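# --- illustration only; not part of the original module ------------------------------
# A minimal sketch of how HaloPadder is meant to be driven (argument choices here are
# assumptions for illustration): pad an explicit-NHWC activation that is split along H,
# let the halo exchange and the bulk copy run on the padder's side streams, then make
# the current stream wait before consuming the padded tensor.
def _sketch_halo_pad(halo_ex, y, half_halo=1):
    padder = HaloPadder(halo_ex)
    ypad = padder(y, half_halo, explicit_nhwc=True, H_split=True)
    padder.wait()   # current stream now sees the filled halos and the copied interior
    return ypad     # explicit NHWC shape: [N, H + 2*half_halo, W, C]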
from .clip_grad import clip_grad_norm_
import torch
from torch._six import inf
from typing import Union, Iterable
_kernel_import_succeeded = False
try:
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
_kernel_import_succeeded = True
except ImportError:
_kernel_import_succeeded = False
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
def clip_grad_norm_(
parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
error_if_nonfinite: bool = False) -> torch.Tensor:
r"""Clips gradient norm of an iterable of parameters.
The norm is computed over all gradients together, as if they were
concatenated into a single vector. Gradients are modified in-place.
This is identical to torch.nn.utils.clip_grad_norm_, except it
uses a fused CUDA kernel when computing the 2-norm of GPU tensors
in float32 and float16.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
error_if_nonfinite (bool): if True, an error is thrown if the total
norm of the gradients from :attr:`parameters` is ``nan``,
``inf``, or ``-inf``. Default: False (will switch to True in the future)
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]
max_norm = float(max_norm)
norm_type = float(norm_type)
# Trivial case
if len(parameters) == 0:
return torch.tensor(0.)
# Fallback implementation
if not (_kernel_import_succeeded
and norm_type == 2.0
and any(p.is_cuda for p in parameters)):
return torch.nn.utils.clip_grad_norm_(
parameters,
max_norm,
norm_type=norm_type,
error_if_nonfinite = error_if_nonfinite,
)
# Find fp32 and fp16 gradients on GPU
device = next(p.device for p in parameters if p.is_cuda)
grads_fp32, grads_fp16, grads_misc = [], [], []
for p in parameters:
grad = p.grad.detach()
if p.dtype == torch.float32 and p.device == device:
grads_fp32.append(grad)
elif p.dtype == torch.float16 and p.device == device:
grads_fp16.append(grad)
else:
grads_misc.append(grad)
# Compute gradient L2 norms
norms = []
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=device)
if grads_fp32:
norms.append(
multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads_fp32],
False,
)[0]
)
if grads_fp16:
norms.append(
multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads_fp16],
False,
)[0],
)
for g in grads_misc:
norms.append(torch.linalg.norm(g).unsqueeze(0).to(device))
total_norm = torch.linalg.norm(torch.cat(norms))
# Check for non-finite values
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
f'The total norm of order {norm_type} for gradients from '
'`parameters` is non-finite, so it cannot be clipped. To disable '
'this error and scale the gradients by the non-finite norm anyway, '
'set `error_if_nonfinite=False`')
# Scale gradients
clip_coef = max_norm / (total_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
if grads_fp32:
multi_tensor_applier(
amp_C.multi_tensor_scale,
dummy_overflow_buf,
[grads_fp32, grads_fp32],
clip_coef_clamped,
)
if grads_fp16:
multi_tensor_applier(
amp_C.multi_tensor_scale,
dummy_overflow_buf,
[grads_fp16, grads_fp16],
clip_coef_clamped,
)
for g in grads_misc:
g.mul_(clip_coef_clamped.to(g.device))
return total_norm
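# --- illustration only; not part of the original module ------------------------------
# A minimal usage sketch (model, optimizer and loss are placeholders): clip after
# backward and before the optimizer step, exactly as with torch.nn.utils.clip_grad_norm_.
# The fused multi-tensor path is taken automatically for fp16/fp32 CUDA gradients when
# norm_type is 2.0; everything else falls back to the upstream implementation.
def _sketch_clip_step(model, optimizer, loss, max_norm=1.0):
    loss.backward()
    total_norm = clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    optimizer.zero_grad()
    return total_norm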
from .conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU
import pdb
import torch
from torch.autograd import gradcheck
from apex import check_cudnn_version_and_warn
import fused_conv_bias_relu
check_cudnn_version_and_warn(__name__, 8400)
class ConvBiasReLU_(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
def forward(ctx, x, weight, bias, padding, stride):
outputs = fused_conv_bias_relu.forward([x, weight, bias], padding, stride)
ctx.save_for_backward(x, weight, outputs[0])
ctx.padding = padding
ctx.stride = stride
return outputs[0]
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, grad_output):
bwd_args = [*ctx.saved_tensors, grad_output]
padding = ctx.padding
stride = ctx.stride
grads = fused_conv_bias_relu.backward(bwd_args, padding, stride)
return grads[0], grads[1], grads[2], None, None
class ConvBiasMaskReLU_(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
def forward(ctx, x, weight, bias, mask, padding, stride):
outputs = fused_conv_bias_relu.forward_mask([x, weight, bias, mask], padding, stride)
ctx.save_for_backward(x, weight, outputs[0])
ctx.padding = padding
ctx.stride = stride
return outputs[0]
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, grad_output):
bwd_args = [*ctx.saved_tensors, grad_output]
padding = ctx.padding
stride = ctx.stride
grads = fused_conv_bias_relu.backward(bwd_args, padding, stride)
return grads[0], grads[1], grads[2], None, None, None
class ConvBias_(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
def forward(ctx, x, weight, bias, padding, stride):
outputs = fused_conv_bias_relu.forward_no_relu([x, weight, bias], padding, stride)
ctx.save_for_backward(x, weight)
ctx.padding = padding
ctx.stride = stride
return outputs[0]
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, grad_output):
bwd_args = [*ctx.saved_tensors, grad_output]
padding = ctx.padding
stride = ctx.stride
grads = fused_conv_bias_relu.backward_no_relu(bwd_args, padding, stride)
return grads[0], grads[1], grads[2], None, None
ConvBiasReLU = ConvBiasReLU_.apply
ConvBiasMaskReLU = ConvBiasMaskReLU_.apply
ConvBias = ConvBias_.apply
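# --- illustration only; not part of the original module ------------------------------
# A hedged sketch of calling the fused op. The tensor shapes, channels_last layout,
# (1, C_out, 1, 1) bias shape and the scalar padding/stride arguments below are
# assumptions for illustration, not a documented contract of fused_conv_bias_relu.
def _sketch_conv_bias_relu():
    x = torch.randn(8, 64, 56, 56, dtype=torch.half, device='cuda').to(memory_format=torch.channels_last)
    w = torch.randn(128, 64, 3, 3, dtype=torch.half, device='cuda').to(memory_format=torch.channels_last)
    b = torch.randn(1, 128, 1, 1, dtype=torch.half, device='cuda')
    return ConvBiasReLU(x, w, b, 1, 1)   # padding=1, stride=1 for the 3x3 convolution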
@@ -237,7 +237,7 @@ create_conv_bias_add_act_descriptors(int64_t* x_dim_padded,
.setVirtual()
.setId('A') // after add
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
@@ -245,7 +245,7 @@ create_conv_bias_add_act_descriptors(int64_t* x_dim_padded,
.setVirtual()
.setId('B') // after bias
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
@@ -253,7 +253,7 @@ create_conv_bias_add_act_descriptors(int64_t* x_dim_padded,
.setId('C') // after conv
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
@@ -268,7 +268,7 @@ create_conv_bias_add_act_descriptors(int64_t* x_dim_padded,
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build());
}
@@ -358,7 +358,7 @@ create_dconv_descriptors(int64_t* x_dim_padded,
.setVirtual()
.setId('A') // after dconv
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
@@ -366,7 +366,7 @@ create_dconv_descriptors(int64_t* x_dim_padded,
.setVirtual()
.setId('B') // after drelu
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build());
}
@@ -621,7 +621,7 @@ run_conv_scale_bias_add_activation(int64_t* x_dim_padded,
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
@@ -749,7 +749,7 @@ run_conv_scale_bias(int64_t* x_dim_padded,
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
@@ -877,7 +877,7 @@ run_dconv_drelu_dscale(int64_t* x_dim_padded,
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
@@ -983,7 +983,7 @@ run_dconv(int64_t* x_dim_padded,
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
@@ -1093,7 +1093,7 @@ run_dconv_add(int64_t* x_dim_padded,
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
@@ -1608,76 +1608,1387 @@ std::vector<at::Tensor> bottleneck_backward(bool explicit_nhwc, int stride_1X1,
namespace {

enum {
X_TENSOR,
Y_TENSOR,
W_TENSOR,
Z_TENSOR,
B_TENSOR,
AFTERADD_TENSOR,
AFTERBIAS_TENSOR,
AFTERCONV_TENSOR,
OPTIONAL,
AFTEROPT_TENSOR,
AFTERACT_TENSOR,
GEN_INDEX_TENSOR,
MASK_TOP_TENSOR,
MASK_BOTTOM_TENSOR,
MASK_TENSOR,
THRESHOLD_TOP_TENSOR,
THRESHOLD_BOTTOM_TENSOR,
};
using masked_convbias_descriptors = std::tuple<cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor>;
masked_convbias_descriptors
create_conv_bias_add_act_mask_descriptors(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType) {
const int convDim = 2;
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = y_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
int64_t threshold_stride[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC);
return masked_convbias_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('z')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('b')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setVirtual()
.setId('A') // after add
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setVirtual()
.setId('B') // after bias
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('C') // after conv
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('i')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('E') // after act for masked
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('I') // output of the gen index operation
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('m') // top half of the mask created after the less than
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('n') // bottom half of the mask
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('M') // OR of the top and bottom masks
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('t') // threshold for creating the top mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('u') // threshold for creating the bottom mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build());
}
// tensor descriptors used for dgrad
enum {
X_OR_DX_TENSOR,
DY_TENSOR,
W_OR_DW_TENSOR,
SCALE_TENSOR,
RELU_TENSOR,
AFTER_DCONV_TENSOR,
AFTER_DRELU_TENSOR,
DGRAD_INPUT_TENSOR,
DGRAD_OPTIONAL_TENSOR,
DGRAD_GEN_INDEX_TENSOR,
DGRAD_MASK_TOP_TENSOR,
DGRAD_MASK_BOTTOM_TENSOR,
DGRAD_MASK_TENSOR,
DGRAD_THRESHOLD_TOP_TENSOR,
DGRAD_THRESHOLD_BOTTOM_TENSOR,
};
using dconv_add_descriptors = std::tuple<cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor>;

dconv_add_descriptors
create_dconv_add_descriptors(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType) {
const int convDim = 2;
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = x_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
return dconv_add_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('s')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('r')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('A') // after dconv
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('B') // after drelu
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('i')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build());
}
using dconv_mask_descriptors = std::tuple<cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor>;
dconv_mask_descriptors
create_dconv_mask_descriptors(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType) {
const int convDim = 2;
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = x_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
int64_t threshold_stride[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC);
return dconv_mask_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('s')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('r')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('A') // after dconv
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('B') // after drelu
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('i')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('I') // output of the gen index operation
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('m') // top half of the mask created after the less than
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('n') // bottom half of the mask
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('M') // OR of the top and bottom masks
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('t') // threshold for creating the top mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('u') // threshold for creating the bottom mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build());
}
void
run_conv_add_scale_bias_activation(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrB,
at::Half* devPtrI) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTEROPT_TENSOR>(tensors).describe());
// Define the scale operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the bias operation
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// optional add
auto addDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_FWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(std::get<X_TENSOR>(tensors))
.setwDesc(std::get<W_TENSOR>(tensors))
.setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// create an add node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(std::get<OPTIONAL>(tensors))
.setyDesc(std::get<AFTEROPT_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// Create a Add Node with scaling parameters.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(add_op.getOutputTensor())
.setbDesc(std::get<Z_TENSOR>(tensors))
.setyDesc(std::get<AFTERADD_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Bias Node.
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(scale_op.getOutputTensor())
.setbDesc(std::get<B_TENSOR>(tensors))
.setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create an Activation Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(bias_op.getOutputTensor())
.setyDesc(std::get<Y_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create an Operation Graph. In this case it is convolution, add, scale, bias and activation
std::array<cudnn_frontend::Operation const*, 5> ops = {&conv_op, &add_op, &scale_op, &bias_op, &act_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(6, data_ptrs)
.setUids(6, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_conv_scale_bias_add_activation_mask(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrB,
at::Half* devPtrI,
int* devPtrT,
int* devPtrU,
int axis) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
masked_convbias_descriptors tensors = create_conv_bias_add_act_mask_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERACT_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<GEN_INDEX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<MASK_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<MASK_BOTTOM_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<MASK_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<THRESHOLD_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<THRESHOLD_BOTTOM_TENSOR>(tensors).describe());
// Define the scale operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the bias operation
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// optional add
auto addDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_FWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the genIndex descriptor
auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setMathPrecision(CUDNN_DATA_FLOAT)
.setAxis(axis)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe());
// Define the lessThan descriptor
auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_LT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe());
// Define the greaterThan descriptor
auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_GT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe());
// Define the logical_or descriptor
auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_LOGICAL_OR)
.setMathPrecision(CUDNN_DATA_BOOLEAN)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe());
// Define the binary_selection descriptor
auto selectionDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_BINARY_SELECT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(std::get<X_TENSOR>(tensors))
.setwDesc(std::get<W_TENSOR>(tensors))
.setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create a Add Node with scaling parameters.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(std::get<Z_TENSOR>(tensors))
.setyDesc(std::get<AFTERADD_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Bias Node.
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(scale_op.getOutputTensor())
.setbDesc(std::get<B_TENSOR>(tensors))
.setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create a optional add Node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(bias_op.getOutputTensor())
.setbDesc(std::get<OPTIONAL>(tensors))
.setyDesc(std::get<AFTEROPT_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// Create an Activation Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor())
.setyDesc(std::get<AFTERACT_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Gen_Index Node.
auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTERACT_TENSOR>(tensors))
.setyDesc(std::get<GEN_INDEX_TENSOR>(tensors))
.setpwDesc(genIndexDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe());
// Create a LessThan Node.
auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<THRESHOLD_TOP_TENSOR>(tensors))
.setyDesc(std::get<MASK_TOP_TENSOR>(tensors))
.setpwDesc(lessThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe());
// Create a GreaterThan Node.
auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<THRESHOLD_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<MASK_BOTTOM_TENSOR>(tensors))
.setpwDesc(greaterThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe());
// Create a LogicalOr Node.
auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<MASK_TOP_TENSOR>(tensors))
.setbDesc(std::get<MASK_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<MASK_TENSOR>(tensors))
.setpwDesc(logicalOrDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe());
// Create a Binary_Selection Node.
auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setbDesc(std::get<AFTERACT_TENSOR>(tensors))
.settDesc(std::get<MASK_TENSOR>(tensors))
.setyDesc(std::get<Y_TENSOR>(tensors))
.setpwDesc(selectionDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, selection_op.describe());
// Create an Operation Graph. In this case it is convolution, scale, bias, optional add, activation and masking
if (devPtrI) {
std::array<cudnn_frontend::Operation const*, 10> ops = {&conv_op, &scale_op, &bias_op, &add_op, &act_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI, devPtrT, devPtrU};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i', 't', 'u'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(8, data_ptrs)
.setUids(8, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} else {
std::array<cudnn_frontend::Operation const*, 9> ops = {&conv_op, &scale_op, &bias_op, &act_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrT, devPtrU};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 't', 'u'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(7, data_ptrs)
.setUids(7, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
}
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_dconv_add_drelu_dscale(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrR,
at::Half* devPtrI) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
dconv_add_descriptors tensors = create_dconv_add_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_INPUT_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_OPTIONAL_TENSOR>(tensors).describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// optional add
auto addDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation backward operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_BWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the scale backward operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
.setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create add Node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_INPUT_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
    // TODO: do we need getOutputTensor(), and what does it return in the backward case?
    // Create a relu backward Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setdyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setxDesc(std::get<RELU_TENSOR>(tensors))
.setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Scale Node.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setbDesc(std::get<SCALE_TENSOR>(tensors))
.setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
    // Create an Operation Graph. In this case it is dgrad, add, relu backward, and scale
std::array<cudnn_frontend::Operation const*, 4> ops = {&conv_op, &add_op, &act_op, &scale_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrI};
int64_t uids[] = {'x', 'y', 'w', 's', 'r', 'i'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(6, data_ptrs)
.setUids(6, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_dconv_drelu_dscale_mask(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrR,
int* devPtrT,
int* devPtrU,
int axis) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
dconv_mask_descriptors tensors = create_dconv_mask_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_OPTIONAL_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_GEN_INDEX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_THRESHOLD_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_THRESHOLD_BOTTOM_TENSOR>(tensors).describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the activation backward operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_BWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the scale backward operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the genIndex descriptor
auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setMathPrecision(CUDNN_DATA_FLOAT)
.setAxis(axis)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe());
// Define the lessThan descriptor
auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_LT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe());
// Define the greaterThan descriptor
auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_GT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe());
// Define the logical_or descriptor
auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_LOGICAL_OR)
.setMathPrecision(CUDNN_DATA_BOOLEAN)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe());
// Define the binary_selection descriptor
auto selectionDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_BINARY_SELECT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
.setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
    // TODO: do we need getOutputTensor(), and what does it return in the backward case?
    // Create a relu backward Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setdyDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setxDesc(std::get<RELU_TENSOR>(tensors))
.setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Scale Node.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setbDesc(std::get<SCALE_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Gen_Index Node.
auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))
.setpwDesc(genIndexDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe());
// Create a LessThan Node.
auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_THRESHOLD_TOP_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_MASK_TOP_TENSOR>(tensors))
.setpwDesc(lessThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe());
// Create a GreaterThan Node.
auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_THRESHOLD_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors))
.setpwDesc(greaterThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe());
// Create a LogicalOr Node.
auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_MASK_TOP_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_MASK_TENSOR>(tensors))
.setpwDesc(logicalOrDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe());
// Create a Binary_Selection Node.
auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.settDesc(std::get<DGRAD_MASK_TENSOR>(tensors))
.setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setpwDesc(selectionDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, selection_op.describe());
    // Create an Operation Graph. In this case it is dgrad, relu backward, scale, and the gen-index / compare / select masking chain
std::array<cudnn_frontend::Operation const*, 8> ops = {&conv_op, &act_op, &scale_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrT, devPtrU};
int64_t uids[] = {'x', 'y', 'w', 's', 'r', 't', 'u'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(7, data_ptrs)
.setUids(7, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
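// Rough summary of the masked variant above: CUDNN_POINTWISE_GEN_INDEX emits each
// element's index along `axis`, the two comparisons against the top/bottom thresholds
// are OR'd into a halo mask, and the final binary-select chooses between the raw dgrad
// output and the drelu+dscale path according to that mask.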
struct bottleneck_forward_status {
int64_t dimA[4];
int64_t filterdimA1[4];
int64_t filterdimA2[4];
int64_t filterdimA2hh[4];
int64_t filterdimA3[4];
int64_t filterdimA4[4];
int64_t threshdim[4];
int axis[4];
int64_t outdimA0[4];
int64_t outdimA1[4];
int64_t outdimA1b[4]; // out1_pad
int64_t outdimA2[4];
int64_t outdimA3[4];
int64_t outdimA4[4];
int64_t padA[2];
int64_t padA1[2];
int64_t padA2[2]; // halo padding
int64_t dilationA[2];
int64_t convstrideA[2];
int64_t convstride1X1[2];
int64_t outdim0[4]; // halo input shape
int64_t outdim1[4];
int64_t outdim1b[4];
int64_t outdim2[4];
int64_t outdim3[4];
int64_t outdim4[4]; // halo output shape
void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0;
filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0;
filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;
filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0;
filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;
threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1;
// All dim calculation after this order of n,c,h,w
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 3;
axis[2] = 1;
axis[3] = 2;
} else {
axis[0] = 0;
axis[1] = 1;
axis[2] = 2;
axis[3] = 3;
}
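    // Note: axis[] maps the backend's n,c,h,w ordering onto PyTorch tensor dimensions.
    // With explicit NHWC the tensor shape is literally (N,H,W,C), so C is read from
    // dim 3 and H,W from dims 1 and 2; otherwise the tensor is already (N,C,H,W).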
for (int dim=0;dim<4;dim++) {
dimA[dim] = inputs[0].size(axis[dim]);
filterdimA1[dim] = inputs[1].size(axis[dim]);
filterdimA2[dim] = inputs[2].size(axis[dim]);
filterdimA3[dim] = inputs[3].size(axis[dim]);
}
if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
for (int dim=0;dim<4;dim++) {
filterdimA4[dim] = inputs[10].size(axis[dim]);
}
}
for (int dim=0;dim<4;dim++) {
if (dim == 2) {
filterdimA2hh[dim] = 1;
} else {
filterdimA2hh[dim] = filterdimA2[dim];
}
}
// output dim in n,c,h,w used by backend
outdimA0[0] = outdimA0[1] = outdimA0[2] = outdimA0[3] = 0;
outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0;
outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
outdimA4[0] = outdimA4[1] = outdimA4[2] = outdimA4[3] = 0;
    // use these fixed values for the test run
    padA[0] = 0; padA[1] = 0;
    padA1[0] = 1; padA1[1] = 1;
    padA2[0] = 0; padA2[1] = 1;
    dilationA[0] = 1; dilationA[1] = 1;
@@ -1690,6 +3001,13 @@ struct bottleneck_forward_status {
    for (int dim = 0; dim < 2; dim++) {
      outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
    }
for (int dim = 0; dim < 4; dim++) {
if (dim == 2) {
outdimA1b[dim] = outdimA1[dim] + 2;
} else {
outdimA1b[dim] = outdimA1[dim];
}
}
    outdimA2[0] = outdimA1[0];
    outdimA2[1] = filterdimA2[0];
@@ -1715,6 +3033,7 @@ struct bottleneck_forward_status {
    // Create output tensor in the correct shape in pytorch's view
    outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
    outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;
    outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
    outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
    if (explicit_nhwc) {
@@ -1726,6 +3045,7 @@ struct bottleneck_forward_status {
    for (int dim=0;dim<4;dim++) {
      outdim0[dim] = outdimA0[axis[dim]];
      outdim1[dim] = outdimA1[axis[dim]];
      outdim1b[dim] = outdimA1b[axis[dim]];
      outdim2[dim] = outdimA2[axis[dim]];
      outdim3[dim] = outdimA3[axis[dim]];
      outdim4[dim] = outdimA4[axis[dim]];
@@ -1821,6 +3141,41 @@ at::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_
  return halo_y2;
}
// compute halo correction term (top or bottom) from slim halo input (N,C,1,W).
// slim halo input is 1 pixel wide in H.
at::Tensor bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::Tensor slim_halo_y1, std::vector<at::Tensor> inputs, at::Tensor w1by3, at::Tensor out2_part_halo) {
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// run
at::Half* w = w1by3.data_ptr<at::Half>(); // C,C,1,3
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* b = inputs[8].data_ptr<at::Half>();
at::Half* y1 = slim_halo_y1.data_ptr<at::Half>();
at::Half* prev_out2 = out2_part_halo.data_ptr<at::Half>();
auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format);
at::Half* y2 = halo_y2.data_ptr<at::Half>();
run_conv_add_scale_bias_activation(forward_state.outdimA4,
forward_state.padA2,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA2hh,
forward_state.outdimA4,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
prev_out2);
return halo_y2;
}
void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
  std::cout << std::fixed;
@@ -1859,6 +3214,86 @@ void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at:
  DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}
void bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor thresholdTop, at::Tensor thresholdBottom) {
std::cout << std::fixed;
// from _out1 method
at::Half* x = inputs[0].data_ptr<at::Half>();
auto out1 = outputs[0];
at::Half* y1 = out1.data_ptr<at::Half>();
// run
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* b = inputs[8].data_ptr<at::Half>();
auto out2 = outputs[1];
at::Half* y2 = out2.data_ptr<at::Half>();
//printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
//printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
//printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
//printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
//printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
//printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
run_conv_scale_bias_add_activation_mask(forward_state.outdimA1,
forward_state.padA1,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA2,
forward_state.outdimA2,
forward_state.threshdim,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
nullptr,
thresholdTop.data_ptr<int>(),
thresholdBottom.data_ptr<int>(),
2); // axis == 1 -> Does this assume explicit NHWC?
DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}
void bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor out1_pad) {
std::cout << std::fixed;
// from _out1 method
at::Half* x = inputs[0].data_ptr<at::Half>();
auto out1 = outputs[0];
at::Half* y1 = out1_pad.data_ptr<at::Half>();
// run
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* b = inputs[8].data_ptr<at::Half>();
auto out2 = outputs[1];
at::Half* y2 = out2.data_ptr<at::Half>();
//printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
//printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
//printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
//printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
//printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
//printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
run_conv_scale_bias_add_activation(forward_state.outdimA1b,
forward_state.padA2,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA2,
forward_state.outdimA2,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
nullptr);
DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}
void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
  std::cout << std::fixed;
@@ -1932,10 +3367,12 @@ struct bottleneck_backward_state {
  int64_t filterdimA3[4];
  int64_t filterdimA4[4];
  int64_t filterdimA2hh[4]; // Cin,Cout,1,3
  int64_t threshdim[4];
  int axis[4];
  int64_t outdimA1[4]; // grad_out1
  int64_t outdimA1b[4]; // out1_pad
  int64_t outdimA2[4]; // grad_out2
  int64_t outdimA3[4];
  int64_t outdimA1h[4]; // output: grad_out1 halo (H=3)
@@ -1953,9 +3390,11 @@ struct bottleneck_backward_state {
  int64_t filterdim2hh[4]; // Cin,1,3,Cout
  int64_t outdim1[4];
  int64_t outdim1b[4];
  int64_t outdim2[4];
  int64_t outdim3[4];
  int64_t outdim1h[4];
int64_t outdim1hh[4];
  void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
    // setup dimensions
@@ -1965,6 +3404,7 @@ struct bottleneck_backward_state {
    filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
    filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;
    filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0;
    threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1;
    // All dim calculation after this order of n,c,h,w
    if (explicit_nhwc) {
@@ -2001,6 +3441,7 @@ struct bottleneck_backward_state {
    // output dim in n,c,h,w used by backend
    outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
    outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0;
    outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
    outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
    outdimA1h[0] = outdimA1h[1] = outdimA1h[2] = outdimA1h[3] = 0;
@@ -2022,6 +3463,13 @@ struct bottleneck_backward_state {
    for (int dim = 0; dim < 2; dim++) {
      outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
    }
for (int dim = 0; dim < 4; dim++) {
if (dim == 2) {
outdimA1b[dim] = outdimA1[dim] + 2;
} else {
outdimA1b[dim] = outdimA1[dim];
}
}
    outdimA2[0] = outdimA1[0];
    outdimA2[1] = filterdimA2[0];
@@ -2051,9 +3499,11 @@ struct bottleneck_backward_state {
    // Create output tensor in the correct shape in pytorch's view
    outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
    outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;
    outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
    outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
    outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0;
    outdim1hh[0] = outdim1hh[1] = outdim1hh[2] = outdim1hh[3] = 0;
    filterdim2hh[0] = filterdim2hh[1] = filterdim2hh[2] = filterdim2hh[3] = 0;
    if (explicit_nhwc) {
      axis[0] = 0;
@@ -2063,9 +3513,11 @@ struct bottleneck_backward_state {
    }
    for (int dim=0;dim<4;dim++) {
      outdim1[dim] = outdimA1[axis[dim]];
outdim1b[dim] = outdimA1b[axis[dim]];
      outdim2[dim] = outdimA2[axis[dim]];
      outdim3[dim] = outdimA3[axis[dim]];
      outdim1h[dim] = outdimA1h[axis[dim]];
      outdim1hh[dim] = outdimA1hh[axis[dim]];
      filterdim2hh[dim] = filterdimA2hh[axis[dim]];
    }
  }
@@ -2102,19 +3554,12 @@ std::vector<at::Tensor> bottleneck_backward_init(bool explicit_nhwc, int stride_
  return outputs;
}
void bottleneck_backward_wgrad3(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
  // dconv3+drelu2+dscale2
  at::Half* conv_in = inputs[13].data_ptr<at::Half>();
  at::Half* dy3 = inputs[10].data_ptr<at::Half>();
  DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());
  // wgrad
  auto wgrad3 = outputs[3];
  at::Half* dw3 = wgrad3.data_ptr<at::Half>();
@@ -2129,6 +3574,22 @@ at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std
            dw3,
            dy3,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
}
at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dconv3+drelu2+dscale2
at::Half* conv_in = inputs[13].data_ptr<at::Half>();
at::Half* dy3 = inputs[10].data_ptr<at::Half>();
DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());
  // dgrad
  auto grad_out2 = at::empty(backward_state.outdim2, inputs[0].type(), output_format);
@@ -2178,6 +3639,7 @@ at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std
  //printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));
  // fused dgrad
  //printf("backward_state.outdim1 = {%d,%d,%d,%d}\n",backward_state.outdim1[0],backward_state.outdim1[1],backward_state.outdim1[2],backward_state.outdim1[3]);
  run_dconv_drelu_dscale(backward_state.outdimA1,
                         backward_state.padA1,
                         backward_state.convstrideA,
@@ -2194,6 +3656,88 @@ at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std
  return grad_out1;
}
at::Tensor bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor thresholdTop, at::Tensor thresholdBottom) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
// dgrad
auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
at::Half* dy1 = grad_out1.data_ptr<at::Half>();
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1 = inputs[12].data_ptr<at::Half>();
//printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));
// fused dgrad
run_dconv_drelu_dscale_mask(backward_state.outdimA1,
backward_state.padA1,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2,
backward_state.outdimA2,
backward_state.threshdim,
CUDNN_DATA_HALF,
dy1,
w,
dy2,
z,
relu1,
thresholdTop.data_ptr<int>(),
thresholdBottom.data_ptr<int>(),
2);
return grad_out1;
}
// perform backward data 1x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,1,W,C] with padding=(0,1) to produce output of shape [N,1,W,C]
at::Tensor bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, at::Tensor w1by3, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo, at::Tensor part_grad_out1) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2h = grad_out2_halo.data_ptr<at::Half>();
// dgrad
auto grad_out1_halo = at::empty(backward_state.outdim1hh, inputs[0].type(), output_format);
at::Half* dy1h = grad_out1_halo.data_ptr<at::Half>();
//at::Half* w = inputs[2].data_ptr<at::Half>(); // use w1by3 instead, which is a sliced version of inputs[2]
at::Half* w = w1by3.data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1h = relu1_halo.data_ptr<at::Half>();
at::Half* pdy1h = part_grad_out1.data_ptr<at::Half>();
//printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3));
// fused dgrad
//printf("backward_state.outdimA1h = {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]);
//printf("backward_state.outdimA2h = {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]);
//printf("backward_state.filterdimA2 = {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]);
run_dconv_add_drelu_dscale(backward_state.outdimA1hh,
backward_state.padA2, // 0,1
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2hh, // C,1,3,C
backward_state.outdimA2hh,
CUDNN_DATA_HALF,
dy1h,
w,
dy2h,
z,
relu1h,
pdy1h);
return grad_out1_halo;
}
// perform backward data 3x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,3,W,C] with padding=(1,1) to produce output of shape [N,3,W,C]
at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo) {
@@ -2233,7 +3777,38 @@ at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1
  return grad_out1_halo;
}
void bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor grad_out2) {
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
// dconv2+drelu1+dscale1
at::Half* conv_in = input.data_ptr<at::Half>();
// wgrad
auto wgrad2 = outputs[2];
at::Half* dw2 = wgrad2.data_ptr<at::Half>();
//printf("outdimA1b = (%d,%d,%d,%d)\n",backward_state.outdimA1b[0],backward_state.outdimA1b[1],backward_state.outdimA1b[2],backward_state.outdimA1b[3]);
//printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]);
run_dconv(backward_state.outdimA1b, // conv_in.shape (including H halos)
backward_state.padA2, // 0, 1
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2, // dw2.shape
backward_state.outdimA2, // dy2.shape
CUDNN_DATA_HALF,
conv_in,
dw2,
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
}
void bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
  bool requires_grad = inputs[0].requires_grad();
@@ -2262,8 +3837,7 @@ at::Tensor bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::v
            dw2,
            dy2,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
return wgrad2;
}
// compute halo cells for input volume of dimension [N,1,W,C] with padding=(0,1) to produce output volume of dimension [N,1,W,C]
@@ -2306,7 +3880,30 @@ at::Tensor bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, s
  return wgrad2_halo;
}
void bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out1) {
at::Half* x = inputs[0].data_ptr<at::Half>();
at::Half* dy1 = grad_out1.data_ptr<at::Half>();
// dconv1+add
// wgrad
auto wgrad1 = outputs[1];
at::Half* dw1 = wgrad1.data_ptr<at::Half>();
run_dconv(backward_state.dimA,
backward_state.padA,
backward_state.convstride1X1,
backward_state.dilationA,
backward_state.filterdimA1,
backward_state.outdimA1,
CUDNN_DATA_HALF,
x,
dw1,
dy1,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
}
void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor grad_out1) {
  bool requires_grad = inputs[0].requires_grad();
@@ -2404,22 +4001,6 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
    dx_conv4 = inputs[11].data_ptr<at::Half>();
  }
// dconv1+add
// wgrad
auto wgrad1 = outputs[1];
at::Half* dw1 = wgrad1.data_ptr<at::Half>();
run_dconv(backward_state.dimA,
backward_state.padA,
backward_state.convstride1X1,
backward_state.dilationA,
backward_state.filterdimA1,
backward_state.outdimA1,
CUDNN_DATA_HALF,
x,
dw1,
dy1,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
  // dgrad
  w = inputs[1].data_ptr<at::Half>();
  auto grad_x = outputs[0];
@@ -2460,8 +4041,6 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
  DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item<float>());
  DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item<float>());
DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
  if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {
    DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item<float>());
@@ -2474,13 +4053,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_init", &bottleneck_forward_init, "Bottleneck block init");
  m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward");
  m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward");
m.def("forward_out2_mask", &bottleneck_forward_out2_mask, "Bottleneck block forward");
m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward"); m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward");
m.def("forward_out2_halo_corr", &bottleneck_forward_out2_halo_corr, "Bottleneck block forward");
m.def("forward_out2_pad", &bottleneck_forward_out2_pad, "Bottleneck block forward");
m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward"); m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward");
m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init"); m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init");
m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward"); m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward"); m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward");
m.def("backward_grad_out1_mask", &bottleneck_backward_grad_out1_mask, "Bottleneck block backward");
m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward"); m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward");
m.def("backward_grad_out1_halo_corr", &bottleneck_backward_grad_out1_halo_corr, "Bottleneck block backward");
m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward");
m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward"); m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward");
m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward"); m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward");
m.def("backward_wgrad3", &bottleneck_backward_wgrad3, "Bottleneck block backward");
m.def("backward_wgrad1", &bottleneck_backward_wgrad1, "Bottleneck block backward");
m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward"); m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");
} }
#include <ATen/ATen.h>
#include <ATen/cudnn/Handle.h> // for getcudnnhandle
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>
#include <cudnn_frontend.h>
#include <iostream>
#ifdef DEBUG
#define DEBUG_MSG(str) do { std::cout << str << std::endl; } while( false )
#else
#define DEBUG_MSG(str) do { } while ( false )
#endif
#ifdef DEBUG_CUDNN
#define DEBUG_CUDNN_MSG(buf, str) do { buf << str << std::endl; } while( false )
#else
#define DEBUG_CUDNN_MSG(buf, str) do { } while ( false )
#endif
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(at::MemoryFormat::ChannelsLast), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
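// Note: as written, CHECK_CONTIGUOUS tests channels-last contiguity specifically
// (at::MemoryFormat::ChannelsLast), not the default contiguous layout.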
#define checkCudnnErr(...) \
do { \
int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \
if (err) { \
return; \
} \
} while (0)
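// Note: on failure this macro executes a bare `return;`, so it can only be used
// inside functions returning void, which is consistent with the run_* helpers below.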
int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) {
if (code) {
printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr);
return 1;
}
return 0;
}
void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort = true);
#define checkCUDAError(val) { checkError((val), #val, __FILE__, __LINE__); } // in-line regular function
void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort) {
if (code != cudaSuccess)
{
const char * errorMessage = cudaGetErrorString(code);
fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code, errorMessage);
if (abort){
cudaDeviceReset();
exit(code);
}
}
}
void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) {
// For INT8x4 and INT8x32 we still compute standard strides here to input
// into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref.
if (filterFormat == CUDNN_TENSOR_NCHW) {
strideA[nbDims - 1] = 1;
for (int64_t d = nbDims - 2; d >= 0; d--) {
strideA[d] = strideA[d + 1] * dimA[d + 1];
}
} else {
// Here we assume that the format is CUDNN_TENSOR_NHWC
strideA[1] = 1;
strideA[nbDims - 1] = strideA[1] * dimA[1];
for (int64_t d = nbDims - 2; d >= 2; d--) {
strideA[d] = strideA[d + 1] * dimA[d + 1];
}
strideA[0] = strideA[2] * dimA[2];
}
}
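// Worked example (NHWC): for dimA = {1, 64, 56, 56} the strides come out as
// {56*56*64, 1, 56*64, 64} = {200704, 1, 3584, 64}, i.e. C is the fastest-moving
// dimension, then W, then H, then N.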
int getFwdConvDilatedFilterDim(int filterDim, int dilation) {
return ((filterDim - 1) * dilation) + 1;
}
int getFwdConvPaddedImageDim(int tensorDim, int pad) {
return tensorDim + (2 * pad);
}
int getFwdConvOutputDim(int tensorDim,
int pad,
int filterDim,
int stride,
int dilation) {
int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1;
return (p);
}
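// This is the standard convolution output-size formula,
//   out = (in + 2*pad - ((filter - 1) * dilation + 1)) / stride + 1   (integer division).
// For example, in = 56, pad = 1, filter = 3, stride = 1, dilation = 1 gives
// (56 + 2 - 3) / 1 + 1 = 56, i.e. a 'same'-sized output.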
// create a cache for plan
std::unordered_map<std::string, cudnn_frontend::ExecutionPlan> plan_cache;
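// Note: the cache is keyed by the shape/type string built below plus the operation
// graph tag; as written it is a plain global map with no synchronization, so it
// assumes plans are created and looked up from one thread at a time.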
std::string getConvFusionString(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
cudnnDataType_t dataType,
std::string fusion_string) {
for(int i=0;i<4;i++) {
fusion_string += 'X';
fusion_string += std::to_string(x_dim_padded[i]);
}
for(int i=0;i<4;i++) {
fusion_string += 'W';
fusion_string += std::to_string(w_dim_padded[i]);
}
for(int i=0;i<2;i++) {
fusion_string += 'P';
fusion_string += std::to_string(padA[i]);
}
for(int i=0;i<2;i++) {
fusion_string += 'S';
fusion_string += std::to_string(convstrideA[i]);
}
for(int i=0;i<2;i++) {
fusion_string += 'D';
fusion_string += std::to_string(dilationA[i]);
}
fusion_string += 'T';
fusion_string += std::to_string(dataType);
return fusion_string;
}
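// Illustrative example (hypothetical shapes): for x = (1,64,56,56), w = (64,64,1,1),
// pad = (0,0), stride = (1,1), dilation = (1,1) and CUDNN_DATA_HALF, the key is the
// operation graph tag followed by "X1X64X56X56W64W64W1W1P0P0S1S1D1D1T2".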
cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_,
std::stringstream& log_buf,
cudnn_frontend::OperationGraph& opGraph,
std::string cache_string,
bool use_heuristic = true){
auto it = plan_cache.find(cache_string);
if (it != plan_cache.end()) {
DEBUG_CUDNN_MSG(log_buf, "Found plan in cache");
return it->second;
} else {
if (use_heuristic){
// TODO: confirm which mode to use
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(opGraph)
.setHeurMode(CUDNN_HEUR_MODE_INSTANT)
.build();
// try 3 times for now as WAR for no heuristic training
int max_tries = 3, count = 0;
auto& engine_configs = heuristics.getEngineConfig(max_tries);
while(true) {
try {
plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle_)
.setEngineConfig(engine_configs[count], opGraph.getTag())
.build()));
break;
} catch (cudnn_frontend::cudnnException e) {
if (++count == max_tries) throw e;
}
}
}else{
DEBUG_CUDNN_MSG(log_buf, "No plan in cache");
// How many engines support this operation graph ?
auto total_engines = opGraph.getEngineCount();
DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << " has " << total_engines << " engines.");
// We have to randomly pick one engine from [0, total_engines)
// Selecting "0" by default
auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build();
DEBUG_CUDNN_MSG(log_buf, engine.describe());
auto& knobs = engine.getSupportedKnobs();
for (auto it = std::begin(knobs); it != std::end(knobs); ++it) {
DEBUG_CUDNN_MSG(log_buf, it->describe());
}
if (knobs.begin() != knobs.end()) {
DEBUG_CUDNN_MSG(log_buf, "Updated knob choice");
knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1);
DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe());
}
      // Create the requisite engine config
auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build();
DEBUG_CUDNN_MSG(log_buf, engine_config.describe());
plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build()));
}
return plan_cache.find(cache_string)->second;
}
}
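// Note: returning a reference into plan_cache is safe here because std::unordered_map
// guarantees that references to stored elements remain valid when other keys are inserted.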
void
run_conv_bias(int64_t* x_dim,
int64_t* w_dim,
int64_t* y_dim,
int64_t* conv_pad,
int64_t* convstride,
int64_t* dilation,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrB,
at::Half* devPtrY) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
float alpha = 1.0f;
float beta = 0.0f;
int64_t b_dim[] = {1, y_dim[1], 1, 1};
// Creates the necessary tensor descriptors
int64_t stride[4];
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto xTensor = cudnn_frontend::TensorBuilder()
.setDim(4, x_dim)
.setStrides(4, stride)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto wTensor = cudnn_frontend::TensorBuilder()
.setDim(4, w_dim)
.setStrides(4, stride)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterConvTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('c')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual()
.build();
DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto bTensor = cudnn_frontend::TensorBuilder()
.setDim(4, b_dim)
.setStrides(4, stride)
.setId('b')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, bTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterBiasTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());
// Define the bias operation
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, conv_pad)
.setPostPadding(convDim, conv_pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(xTensor)
.setwDesc(wTensor)
.setyDesc(afterConvTensor)
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create a Bias Node.
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(bTensor)
.setyDesc(afterBiasTensor)
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
    // Create an Operation Graph. In this case it is convolution and bias
std::array<cudnn_frontend::Operation const*, 2> ops = {&conv_op, &bias_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(2, ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim, conv_pad, convstride, dilation, w_dim, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY};
int64_t uids[] = {'x', 'w', 'b', 'y'};
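    // Note: each UID in the variant pack must match the id given to the corresponding
    // TensorBuilder via setId() above; the char literals ('x', 'w', 'b', 'y') are just
    // convenient int64_t values reused for both.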
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(4, data_ptrs)
.setUids(4, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_conv_bias_mask_relu(int64_t* x_dim,
int64_t* w_dim,
int64_t* y_dim,
int64_t* conv_pad,
int64_t* conv_stride,
int64_t* conv_dilation,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrB,
int8_t* devPtrM,
at::Half* devPtrY) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int conv_dim = 2;
float alpha = 1.0f;
float beta = 0.0f;
int64_t b_dim[] = {1, y_dim[1], 1, 1};
// Creates the necessary tensor descriptors
int64_t stride[4];
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto xTensor = cudnn_frontend::TensorBuilder()
.setDim(4, x_dim)
.setStrides(4, stride)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto wTensor = cudnn_frontend::TensorBuilder()
.setDim(4, w_dim)
.setStrides(4, stride)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto mTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('m')
.setAlignment(16)
.setDataType(CUDNN_DATA_INT8)
.build();
    DEBUG_CUDNN_MSG(log_buf, mTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterConvTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('c')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual()
.build();
DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto bTensor = cudnn_frontend::TensorBuilder()
.setDim(4, b_dim)
.setStrides(4, stride)
.setId('b')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, bTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterBiasTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('B')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual()
.build();
DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterMaskTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('M')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual()
.build();
    DEBUG_CUDNN_MSG(log_buf, afterMaskTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterReLUTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(conv_dim)
.setStrides(conv_dim, conv_stride)
.setPrePadding(conv_dim, conv_pad)
.setPostPadding(conv_dim, conv_pad)
.setDilation(conv_dim, conv_dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the bias operation
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// Define the mask operation
auto maskDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, maskDesc.describe());
// Define the activation operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_FWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(xTensor)
.setwDesc(wTensor)
.setyDesc(afterConvTensor)
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create a Bias Node
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(bTensor)
.setyDesc(afterBiasTensor)
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// create a Mask Node
auto mask_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(bias_op.getOutputTensor())
.setbDesc(mTensor)
.setyDesc(afterMaskTensor)
.setpwDesc(maskDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, mask_op.describe());
// Create an Activation Node
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(mask_op.getOutputTensor())
.setyDesc(afterReLUTensor)
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create an Operation Graph. In this case it is convolution bias mask activation
std::array<cudnn_frontend::Operation const*, 4> ops = {&conv_op, &bias_op, &mask_op, &act_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(4, ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrM, devPtrY};
int64_t uids[] = {'x', 'w', 'b', 'm', 'y'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(5, data_ptrs)
.setUids(5, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException &e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
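// Fused forward graph: convolution -> bias add -> ReLU (same structure as above, without the mask).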
void
run_conv_bias_relu(int64_t* x_dim,
int64_t* w_dim,
int64_t* y_dim,
int64_t* conv_pad,
int64_t* conv_stride,
int64_t* conv_dilation,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrB,
at::Half* devPtrY) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int conv_dim = 2;
float alpha = 1.0f;
float beta = 0.0f;
int64_t b_dim[] = {1, y_dim[1], 1, 1};
// Creates the necessary tensor descriptors
int64_t stride[4];
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto xTensor = cudnn_frontend::TensorBuilder()
.setDim(4, x_dim)
.setStrides(4, stride)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto wTensor = cudnn_frontend::TensorBuilder()
.setDim(4, w_dim)
.setStrides(4, stride)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterConvTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('c')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual()
.build();
DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto bTensor = cudnn_frontend::TensorBuilder()
.setDim(4, b_dim)
.setStrides(4, stride)
.setId('b')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, bTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterBiasTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('B')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual()
.build();
DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterReLUTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(conv_dim)
.setStrides(conv_dim, conv_stride)
.setPrePadding(conv_dim, conv_pad)
.setPostPadding(conv_dim, conv_pad)
.setDilation(conv_dim, conv_dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the bias operation
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// Define the activation operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_FWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(xTensor)
.setwDesc(wTensor)
.setyDesc(afterConvTensor)
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create a Bias Node.
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(bTensor)
.setyDesc(afterBiasTensor)
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create an Activation Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(bias_op.getOutputTensor())
.setyDesc(afterReLUTensor)
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create an Operation Graph. In this case it is convolution bias activation
std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &bias_op, &act_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(3, ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY};
int64_t uids[] = {'x', 'w', 'b', 'y'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(4, data_ptrs)
.setUids(4, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException &e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
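// Fused backward graph: ReLU backward (using the saved activation 'r') followed by a
// bias-gradient reduction over N, H and W into a float tensor of shape {1, C, 1, 1}.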
void
run_drelu_dbias(int64_t* dy_dim,
cudnnDataType_t dataType,
at::Half* devPtrDY,
at::Half* devPtrR,
at::Half* devPtrDR,
float* devPtrDB) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
float alpha = 1.0f;
float beta = 0.0f;
int64_t b_dim[] = {1, dy_dim[1], 1, 1};
// Creates the necessary tensor descriptors
int64_t stride[4];
generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto dyTensor = cudnn_frontend::TensorBuilder()
.setDim(4, dy_dim)
.setStrides(4, stride)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, dyTensor.describe());
generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto rTensor = cudnn_frontend::TensorBuilder()
.setDim(4, dy_dim)
.setStrides(4, stride)
.setId('r')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, rTensor.describe());
generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto inActGradTensor = cudnn_frontend::TensorBuilder()
.setDim(4, dy_dim)
.setStrides(4, stride)
.setId('R')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, inActGradTensor.describe());
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto biasGradTensor = cudnn_frontend::TensorBuilder()
.setDim(4, b_dim)
.setStrides(4, stride)
.setId('y')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasGradTensor.describe());
// Define the activation backward operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_BWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the bias backward operation
auto biasDesc = cudnn_frontend::ReductionDescBuilder()
.setMathPrecision(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// Create a ReLU backward Node
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setdyDesc(dyTensor)
.setxDesc(rTensor)
.setdxDesc(inActGradTensor)
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create bias node
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(inActGradTensor)
.setyDesc(biasGradTensor)
.setreductionDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create an Operation Graph. In this case it is ReLU backward followed by the bias-gradient reduction
std::array<cudnn_frontend::Operation const*, 2> ops = {&act_op, &bias_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
// creating unique dummy values
int64_t pad_dummy[] = {20, 20};
int64_t stride_dummy[] = {20, 20};
int64_t dilation_dummy[] = {20, 20};
auto cache_string = getConvFusionString(dy_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrDY, devPtrR, devPtrDR, devPtrDB};
int64_t uids[] = {'x', 'r', 'R', 'y'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(4, data_ptrs)
.setUids(4, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException &e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
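// Fused backward graph: convolution data gradient (dgrad) -> ReLU backward -> bias-gradient
// reduction. The dgrad result ('A') is virtual and feeds the ReLU backward node directly.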
void
run_dconv_drelu_dbias(int64_t* x_dim,
int64_t* w_dim,
int64_t* y_dim,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrR,
at::Half* devPtrRg,
float* devPtrY) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
float alpha = 1.0f;
float beta = 0.0f;
int64_t b_dim[] = {1, x_dim[1], 1, 1};
int64_t stride[4];
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto outConvGradTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, outConvGradTensor.describe());
generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto wTensor = cudnn_frontend::TensorBuilder()
.setDim(4, w_dim)
.setStrides(4, stride)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto inConvGradTensor = cudnn_frontend::TensorBuilder()
.setDim(4, x_dim)
.setStrides(4, stride)
.setId('A')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual()
.build();
DEBUG_CUDNN_MSG(log_buf, inConvGradTensor.describe());
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto rTensor = cudnn_frontend::TensorBuilder()
.setDim(4, x_dim)
.setStrides(4, stride)
.setId('r')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, rTensor.describe());
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto inReLUGradTensor = cudnn_frontend::TensorBuilder()
.setDim(4, x_dim)
.setStrides(4, stride)
.setId('R')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, inReLUGradTensor.describe());
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto inBiasGradTensor = cudnn_frontend::TensorBuilder()
.setDim(4, b_dim)
.setStrides(4, stride)
.setId('y')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, inBiasGradTensor.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the activation backward operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_BWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the bias backward operation
auto biasDesc = cudnn_frontend::ReductionDescBuilder()
.setMathPrecision(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
.setdyDesc(outConvGradTensor)
.setwDesc(wTensor)
.setdxDesc(inConvGradTensor)
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create a ReLU backward Node
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setdyDesc(inConvGradTensor)
.setxDesc(rTensor)
.setdxDesc(inReLUGradTensor)
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create bias node
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(inReLUGradTensor)
.setyDesc(inBiasGradTensor)
.setreductionDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create an Operation Graph. In this case it is dgrad, ReLU backward and the bias-gradient reduction
std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &act_op, &bias_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim, pad, convstride, dilation, w_dim, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrW, devPtrR, devPtrRg, devPtrY};
int64_t uids[] = {'x', 'w', 'r', 'R', 'y'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(5, data_ptrs)
.setUids(5, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException &e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
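// Single-node backward convolution. `mode` selects between data gradient
// (CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) and weight gradient
// (CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR).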
void
run_dconv(int64_t* x_dim,
int64_t* w_dim,
int64_t* y_dim,
int64_t* conv_pad,
int64_t* conv_stride,
int64_t* conv_dilation,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
cudnnBackendDescriptorType_t mode) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int conv_dim = 2;
float alpha = 1.0f;
float beta = 0.0f;
// Define the convolution problem
int64_t stride[4];
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto xTensor = cudnn_frontend::TensorBuilder()
.setDim(4, x_dim)
.setStrides(4, stride)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto wTensor = cudnn_frontend::TensorBuilder()
.setDim(4, w_dim)
.setStrides(4, stride)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto yTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, yTensor.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(conv_dim)
.setStrides(conv_dim, conv_stride)
.setPrePadding(conv_dim, conv_pad)
.setPostPadding(conv_dim, conv_pad)
.setDilation(conv_dim, conv_dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Create a convolution node
// mode should be one of the following:
// CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
// CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
auto conv_op_builder = cudnn_frontend::OperationBuilder(mode);
if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
conv_op_builder.setdxDesc(xTensor)
.setwDesc(wTensor)
.setdyDesc(yTensor)
.setcDesc(convDesc);
}
else {
conv_op_builder.setxDesc(xTensor)
.setdwDesc(wTensor)
.setdyDesc(yTensor)
.setcDesc(convDesc);
}
auto conv_op = conv_op_builder
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create an Operation Graph. In this case it is a single backward convolution node
std::array<cudnn_frontend::Operation const*, 1> ops = {&conv_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrW, devPtrY};
int64_t uids[] = {'x', 'w', 'y'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException &e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
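// Standalone bias gradient: reduces the incoming gradient over N, H and W into a float
// {1, C, 1, 1} tensor. The dummy pad/stride/dilation values exist only to build the cache key.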
void
run_dbias(int64_t* x_dim,
cudnnDataType_t dataType,
at::Half* devPtrX,
float* devPtrY) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
int64_t b_dim[] = {1, x_dim[1], 1, 1};
int64_t stride[4];
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto xTensor = cudnn_frontend::TensorBuilder()
.setDim(4, x_dim)
.setStrides(4, stride)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto yTensor = cudnn_frontend::TensorBuilder()
.setDim(4, b_dim)
.setStrides(4, stride)
.setId('y')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, yTensor.describe());
// Define the bias backward operation
auto biasDesc = cudnn_frontend::ReductionDescBuilder()
.setMathPrecision(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// Create bias node
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(xTensor)
.setyDesc(yTensor)
.setreductionDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create an Operation Graph. In this case it is bias only
std::array<cudnn_frontend::Operation const*, 1> ops = {&bias_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
int64_t pad_dummy[] = {10, 10};
int64_t stride_dummy[] = {10, 10};
int64_t dilation_dummy[] = {10, 10};
auto cache_string = getConvFusionString(x_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY};
int64_t uids[] = {'x', 'y'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(2, data_ptrs)
.setUids(2, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException &e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
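// Python-facing wrapper. Expects channels_last half tensors {x, w, b} plus an int8 mask of the
// output shape; returns the fused conv+bias+mask+ReLU output.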
std::vector<at::Tensor> conv_bias_mask_relu_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
std::cout << std::fixed;
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = at::MemoryFormat::ChannelsLast;
// setup dimensions
int64_t x_dim[] = {0, 0, 0, 0};
int64_t w_dim[] = {0, 0, 0, 0};
// All dim calculations below follow the n,c,h,w order
int axis[] = {0, 1, 2, 3};
for (int dim = 0; dim < 4; dim++) {
x_dim[dim] = inputs[0].size(axis[dim]);
w_dim[dim] = inputs[1].size(axis[dim]);
}
// output dim in n,c,h,w used by backend
int64_t y_dim[] = {0, 0, 0, 0};
// convolution parameters; dilation is fixed at 1
int64_t conv_pad[] = {padding, padding};
int64_t conv_stride[] = {stride, stride};
int64_t conv_dilation[] = {1, 1};
// compute output from pad/stride/dilation
y_dim[0] = x_dim[0];
y_dim[1] = w_dim[0];
for (int dim = 0; dim < 2; dim++) {
y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);
}
// run
at::Half* x = inputs[0].data_ptr<at::Half>();
at::Half* w = inputs[1].data_ptr<at::Half>();
at::Half* b = inputs[2].data_ptr<at::Half>();
int8_t* m = inputs[3].data_ptr<int8_t>();
auto out = at::empty(y_dim, inputs[0].type(), output_format);
at::Half* y = out.data_ptr<at::Half>();
run_conv_bias_mask_relu(x_dim,
w_dim,
y_dim,
conv_pad,
conv_stride,
conv_dilation,
CUDNN_DATA_HALF,
x,
w,
b,
m,
y);
DEBUG_MSG("[DEBUG] conv-bias-mask-relu : " << y.to(at::kFloat).sum().item<float>());
outputs.push_back(out);
return outputs;
}
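// Python-facing wrapper for the conv+bias+ReLU fusion. inputs = {x, w, b}, all half, channels_last.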
std::vector<at::Tensor> conv_bias_relu_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
std::cout << std::fixed;
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = at::MemoryFormat::ChannelsLast;
// setup dimensions
int64_t x_dim[] = {0, 0, 0, 0};
int64_t w_dim[] = {0, 0, 0, 0};
// All dim calculations below follow the n,c,h,w order
int axis[] = {0, 1, 2, 3};
for (int dim = 0; dim < 4; dim++) {
x_dim[dim] = inputs[0].size(axis[dim]);
w_dim[dim] = inputs[1].size(axis[dim]);
}
// output dim in n,c,h,w used by backend
int64_t y_dim[] = {0, 0, 0, 0};
// convolution parameters; dilation is fixed at 1
int64_t conv_pad[] = {padding, padding};
int64_t conv_stride[] = {stride, stride};
int64_t conv_dilation[] = {1, 1};
// compute output from pad/stride/dilation
y_dim[0] = x_dim[0];
y_dim[1] = w_dim[0];
for (int dim = 0; dim < 2; dim++) {
y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);
}
// run
at::Half* x = inputs[0].data_ptr<at::Half>();
at::Half* w = inputs[1].data_ptr<at::Half>();
at::Half* b = inputs[2].data_ptr<at::Half>();
auto out = at::empty(y_dim, inputs[0].type(), output_format);
at::Half* y = out.data_ptr<at::Half>();
run_conv_bias_relu(x_dim,
w_dim,
y_dim,
conv_pad,
conv_stride,
conv_dilation,
CUDNN_DATA_HALF,
x,
w,
b,
y);
DEBUG_MSG("[DEBUG] conv-bias-relu : " << y.to(at::kFloat).sum().item<float>());
outputs.push_back(out);
return outputs;
}
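// Backward wrapper for conv+bias+ReLU. inputs = {x, w, saved activation, grad_output}; the fused
// dReLU/dbias kernel runs first, and its result feeds separate wgrad and dgrad convolutions.
// Returns {dgrad, wgrad, bgrad}.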
std::vector<at::Tensor> conv_bias_relu_backward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
bool requires_grad = inputs[0].requires_grad();
for (int i = 0; i <= 3; i++) {
CHECK_INPUT(inputs[i]);
}
std::cout << std::fixed;
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = at::MemoryFormat::ChannelsLast;
// setup dimensions
int64_t x_dim[] = {0, 0, 0, 0};
int64_t w_dim[] = {0, 0, 0, 0};
int64_t y_dim[] = {0, 0, 0, 0};
// All dim calculations below follow the n,c,h,w order
int axis[] = {0, 1, 2, 3};
for (int dim = 0; dim < 4; dim++) {
x_dim[dim] = inputs[0].size(axis[dim]);
w_dim[dim] = inputs[1].size(axis[dim]);
y_dim[dim] = inputs[3].size(axis[dim]);
}
int64_t b_dim[] = {1, y_dim[1], 1, 1};
int64_t conv_pad[] = {padding, padding};
int64_t conv_stride[] = {stride, stride};
int64_t conv_dilation[] = {1, 1};
// run
// drelu-dbias
at::Half* dy = inputs[3].data_ptr<at::Half>();
at::Half* r = inputs[2].data_ptr<at::Half>();
auto drelu = at::empty_like(inputs[2]);
at::Half* dr = drelu.data_ptr<at::Half>();
auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false);
auto bgrad = at::empty(b_dim, options, output_format);
float* db = bgrad.data_ptr<float>();
run_drelu_dbias(y_dim,
CUDNN_DATA_HALF,
dy,
r,
dr,
db);
// conv wgrad
at::Half* x = inputs[0].data_ptr<at::Half>();
auto wgrad = at::empty_like(inputs[1]);
at::Half* dw = wgrad.data_ptr<at::Half>();
run_dconv(x_dim,
w_dim,
y_dim,
conv_pad,
conv_stride,
conv_dilation,
CUDNN_DATA_HALF,
x,
dw,
dr,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
// conv dgrad
at::Half* w = inputs[1].data_ptr<at::Half>();
auto dgrad = at::empty_like(inputs[0]);
at::Half* dx = dgrad.data_ptr<at::Half>();
run_dconv(x_dim,
w_dim,
y_dim,
conv_pad,
conv_stride,
conv_dilation,
CUDNN_DATA_HALF,
dx,
w,
dr,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
outputs.push_back(dgrad);
outputs.push_back(wgrad);
outputs.push_back(bgrad);
return outputs;
}
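// Python-facing wrapper for the conv+bias fusion (no activation). inputs = {x, w, b}.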
std::vector<at::Tensor> conv_bias_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
std::cout << std::fixed;
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = at::MemoryFormat::ChannelsLast;
// setup dimensions
int64_t x_dim[] = {0, 0, 0, 0};
int64_t w_dim[] = {0, 0, 0, 0};
// All dim calculations below follow the n,c,h,w order
int axis[] = {0, 1, 2, 3};
for (int dim = 0; dim < 4; dim++) {
x_dim[dim] = inputs[0].size(axis[dim]);
w_dim[dim] = inputs[1].size(axis[dim]);
}
// output dim in n,c,h,w used by backend
int64_t y_dim[] = {0, 0, 0, 0};
// convolution parameters; dilation is fixed at 1
int64_t conv_pad[] = {padding, padding};
int64_t conv_stride[] = {stride, stride};
int64_t conv_dilation[] = {1, 1};
// compute output from pad/stride/dilation
y_dim[0] = x_dim[0];
y_dim[1] = w_dim[0];
for (int dim = 0; dim < 2; dim++) {
y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);
}
// run
at::Half* x = inputs[0].data_ptr<at::Half>();
at::Half* w = inputs[1].data_ptr<at::Half>();
at::Half* b = inputs[2].data_ptr<at::Half>();
auto out = at::empty(y_dim, inputs[0].type(), output_format);
at::Half* y = out.data_ptr<at::Half>();
run_conv_bias(x_dim,
w_dim,
y_dim,
conv_pad,
conv_stride,
conv_dilation,
CUDNN_DATA_HALF,
x,
w,
b,
y);
DEBUG_MSG("[DEBUG] conv-bias : " << y.to(at::kFloat).sum().item<float>());
outputs.push_back(out);
return outputs;
}
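// Backward wrapper for conv+bias. inputs = {x, w, grad_output}; returns {dgrad, wgrad, bgrad}.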
std::vector<at::Tensor> conv_bias_backward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
bool requires_grad = inputs[0].requires_grad();
for (int i = 0; i <= 2; i++) {
CHECK_INPUT(inputs[i]);
}
std::cout << std::fixed;
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = at::MemoryFormat::ChannelsLast;
// setup dimensions
int64_t x_dim[] = {0, 0, 0, 0};
int64_t w_dim[] = {0, 0, 0, 0};
int64_t y_dim[] = {0, 0, 0, 0};
// All dim calculations below follow the n,c,h,w order
int axis[] = {0, 1, 2, 3};
for (int dim = 0; dim < 4; dim++) {
x_dim[dim] = inputs[0].size(axis[dim]);
w_dim[dim] = inputs[1].size(axis[dim]);
y_dim[dim] = inputs[2].size(axis[dim]);
}
int64_t b_dim[] = {1, y_dim[1], 1, 1};
int64_t conv_pad[] = {padding, padding};
int64_t conv_stride[] = {stride, stride};
int64_t conv_dilation[] = {1, 1};
// run
// dbias
at::Half* dy = inputs[2].data_ptr<at::Half>();
auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false);
auto bgrad = at::empty(b_dim, options, output_format);
float* db = bgrad.data_ptr<float>();
run_dbias(y_dim,
CUDNN_DATA_HALF,
dy,
db);
// conv wgrad
at::Half* x = inputs[0].data_ptr<at::Half>();
auto wgrad = at::empty_like(inputs[1]);
at::Half* dw = wgrad.data_ptr<at::Half>();
run_dconv(x_dim,
w_dim,
y_dim,
conv_pad,
conv_stride,
conv_dilation,
CUDNN_DATA_HALF,
x,
dw,
dy,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
// conv dgrad
at::Half* w = inputs[1].data_ptr<at::Half>();
auto dgrad = at::empty_like(inputs[0]);
at::Half* dx = dgrad.data_ptr<at::Half>();
run_dconv(x_dim,
w_dim,
y_dim,
conv_pad,
conv_stride,
conv_dilation,
CUDNN_DATA_HALF,
dx,
w,
dy,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
outputs.push_back(dgrad);
outputs.push_back(wgrad);
outputs.push_back(bgrad);
return outputs;
}
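// Bindings exposed to Python. TORCH_EXTENSION_NAME is supplied by the PyTorch extension build.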
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &conv_bias_relu_forward, "Fused Conv-Bias-ReLU forward");
m.def("backward", &conv_bias_relu_backward, "Fused Conv-Bias-ReLU backward");
m.def("forward_no_relu", &conv_bias_forward, "Fused Conv-Bias forward");
m.def("backward_no_relu", &conv_bias_backward, "Fused Conv-Bias backward");
m.def("forward_mask", &conv_bias_mask_relu_forward, "Fused Conv-Bias-Mask-ReLU forward");
}
Subproject commit b4e1ad9613b89199982c9baf6ee91f6f98f5606d Subproject commit fa611998a360cbabaa2dcc7c9859748144114fc0
...@@ -72,7 +72,7 @@ void set_params(Fused_multihead_attention_fprop_params &params, ...@@ -72,7 +72,7 @@ void set_params(Fused_multihead_attention_fprop_params &params,
constexpr float scale_softmax = 1.f; constexpr float scale_softmax = 1.f;
constexpr float scale_bmm2 = 1.f; constexpr float scale_bmm2 = 1.f;
set_alpha(params.scale_bmm1, scale_bmm1, acc_type); set_alpha(params.scale_bmm1, scale_bmm1, data_type);
set_alpha(params.scale_softmax, scale_softmax, acc_type); set_alpha(params.scale_softmax, scale_softmax, acc_type);
set_alpha(params.scale_bmm2, scale_bmm2, data_type); set_alpha(params.scale_bmm2, scale_bmm2, data_type);
...@@ -89,9 +89,15 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \ ...@@ -89,9 +89,15 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
const float p_dropout, const float p_dropout,
const int max_seq_len, const int max_seq_len,
const bool is_training, const bool is_training,
const bool is_nl,
const bool zero_tensors,
c10::optional<at::Generator> gen_) { c10::optional<at::Generator> gen_) {
auto dprops = at::cuda::getCurrentDeviceProperties(); auto dprops = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dprops->major == 8 && dprops->minor == 0); TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
auto stream = at::cuda::getCurrentCUDAStream().stream();
Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_training, is_nl);
int seq_len = 512; int seq_len = 512;
auto launch = &run_fmha_fp16_512_64_sm80; auto launch = &run_fmha_fp16_512_64_sm80;
if( max_seq_len <= 128 ) { if( max_seq_len <= 128 ) {
...@@ -110,18 +116,6 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \ ...@@ -110,18 +116,6 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
TORCH_CHECK(false); TORCH_CHECK(false);
} }
constexpr int warps_m = 1;
constexpr int warps_n = 4; // this leads to an upper bound
const int mmas_m = seq_len / 16 / warps_m;
const int mmas_n = seq_len / 16 / warps_n;
const int elts_per_thread = 8 * mmas_m * mmas_n;
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(qkv.is_cuda()) TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda()) TORCH_CHECK(cu_seqlens.is_cuda())
...@@ -147,12 +141,16 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \ ...@@ -147,12 +141,16 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts); auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
if( zero_tensors ) {
ctx.zero_();
s.zero_();
}
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>( auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator()); gen_, at::cuda::detail::getDefaultCUDAGenerator());
Fused_multihead_attention_fprop_params params;
set_params(params, set_params(launch_params.params,
batch_size, batch_size,
seq_len, seq_len,
num_heads, num_heads,
...@@ -163,29 +161,32 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \ ...@@ -163,29 +161,32 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
s.data_ptr(), s.data_ptr(),
p_dropout); p_dropout);
// number of times random will be generated per thread, to offset philox counter in the random launch(launch_params, /*configure=*/ true);
// number of times random will be generated per thread, to offset philox counter in the random
// state // state
int64_t counter_offset = elts_per_thread; int64_t counter_offset = launch_params.elts_per_thread;
at::PhiloxCudaState rng_engine_inputs; at::PhiloxCudaState rng_engine_inputs;
if( is_training ) { if( is_training ) {
// See Note [Acquire lock when using random generators] // See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_); std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset); launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
} }
launch(params, is_training, stream); launch(launch_params, /*configure=*/ false);
return { ctx, s }; return { ctx, s };
} }
std::vector<at::Tensor> std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &cu_seqlens, // b+1 const at::Tensor &cu_seqlens, // b+1
const float p_dropout, // probability to drop const float p_dropout, // probability to drop
const int max_seq_len // max sequence length to choose the kernel const int max_seq_len, // max sequence length to choose the kernel
const bool zero_tensors
) { ) {
auto dprops = at::cuda::getCurrentDeviceProperties(); auto dprops = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dprops->major == 8 && dprops->minor == 0); TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
...@@ -235,6 +236,10 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size ...@@ -235,6 +236,10 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
auto dqkv = torch::empty_like(qkv); auto dqkv = torch::empty_like(qkv);
if( zero_tensors ) {
dqkv.zero_();
}
Fused_multihead_attention_fprop_params params; Fused_multihead_attention_fprop_params params;
set_params(params, set_params(params,
...@@ -259,92 +264,13 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size ...@@ -259,92 +264,13 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
return { dqkv, softmax }; return { dqkv, softmax };
} }
std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const float p_dropout,
const int max_seq_len,
const bool is_training,
c10::optional<at::Generator> gen_) {
int seq_len = 512;
auto launch = &run_fmha_fp16_512_64_sm80_nl;
TORCH_CHECK(max_seq_len == seq_len);
constexpr int warps_m = 1;
constexpr int warps_n = 4; // this leads to an upper bound
const int mmas_m = seq_len / 16 / warps_m;
const int mmas_n = seq_len / 16 / warps_n;
// static_assert( mmas_m == 32 );
// static_assert( mmas_n == 4 );
const int elts_per_thread = 8 * mmas_m * mmas_n;
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
auto opts = qkv.options();
auto ctx = torch::empty({ total, num_heads, head_size }, opts);
auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
Fused_multihead_attention_fprop_params params;
set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
ctx.data_ptr(),
s.data_ptr(),
p_dropout);
// number of times random will be generated per thread, to offset philox counter in the random
// state
int64_t counter_offset = elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;
if( is_training ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
}
int num_chunks = 3;
if(batch_size == 3) {
num_chunks = 2;
}
launch(params, is_training, num_chunks, stream);
return { ctx, s };
}
std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num_heads, x head_size std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num_heads, x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &cu_seqlens, // b+1 const at::Tensor &cu_seqlens, // b+1
const float p_dropout, // probability to drop const float p_dropout, // probability to drop
const int max_seq_len // max sequence length to choose the kernel const int max_seq_len, // max sequence length to choose the kernel
const bool zero_tensors
) { ) {
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
...@@ -378,6 +304,10 @@ std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num ...@@ -378,6 +304,10 @@ std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num
auto dqkv = torch::empty_like(qkv); auto dqkv = torch::empty_like(qkv);
if( zero_tensors ) {
dqkv.zero_();
}
int num_chunks = 2; int num_chunks = 2;
if( batch_size == 1 ) { if( batch_size == 1 ) {
num_chunks = 4; num_chunks = 4;
...@@ -427,6 +357,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -427,6 +357,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "Fused Multi-head Self-attention for BERT"; m.doc() = "Fused Multi-head Self-attention for BERT";
m.def("fwd", &mha_fwd, "Forward pass"); m.def("fwd", &mha_fwd, "Forward pass");
m.def("bwd", &mha_bwd, "Backward pass"); m.def("bwd", &mha_bwd, "Backward pass");
m.def("fwd_nl", &mha_fwd_nl, "Forward pass (small-batch)");
m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)"); m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
} }
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
#include <cuda.h> #include <cuda.h>
#include <vector> #include <vector>
#if !defined(NEW_GENERATOR_PATH) #ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h> #include <ATen/CUDAGeneratorImpl.h>
#else #else
#include <ATen/cuda/CUDAGeneratorImpl.h> #include <ATen/cuda/CUDAGeneratorImpl.h>
...@@ -50,7 +50,7 @@ constexpr int D_DIM = 3; ...@@ -50,7 +50,7 @@ constexpr int D_DIM = 3;
struct Qkv_params { struct Qkv_params {
// The QKV matrices. // The QKV matrices.
void *qkv_ptr; void * __restrict__ qkv_ptr;
// The stride between rows of the Q, K and V matrices. // The stride between rows of the Q, K and V matrices.
size_t qkv_stride_in_bytes; size_t qkv_stride_in_bytes;
...@@ -64,19 +64,19 @@ struct Qkv_params { ...@@ -64,19 +64,19 @@ struct Qkv_params {
struct Fused_multihead_attention_fprop_params : public Qkv_params { struct Fused_multihead_attention_fprop_params : public Qkv_params {
// The dQKV matrices. // The dQKV matrices.
void *dqkv_ptr; void * __restrict__ dqkv_ptr;
// Temporary for dKV. // Temporary for dKV.
void *dkv_ptr; void * __restrict__ dkv_ptr;
// The O matrix (output). // The O matrix (output).
void *o_ptr; void * __restrict__ o_ptr;
// The stride between rows of O. // The stride between rows of O.
int64_t o_stride_in_bytes; int64_t o_stride_in_bytes;
// The pointer to the S matrix, overwritten by the dP matrix (bwd). // The pointer to the S matrix, overwritten by the dP matrix (bwd).
void *s_ptr; void * __restrict__ s_ptr;
// The stride between rows of the S matrix. // The stride between rows of the S matrix.
int64_t s_stride_in_bytes; int64_t s_stride_in_bytes;
...@@ -87,7 +87,7 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params { ...@@ -87,7 +87,7 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
uint32_t scale_bmm1, scale_softmax, scale_bmm2; uint32_t scale_bmm1, scale_softmax, scale_bmm2;
// array of length b+1 holding starting offset of each sequence. // array of length b+1 holding starting offset of each sequence.
int *cu_seqlens; int * __restrict__ cu_seqlens;
// The dropout probability (probability of keeping an activation). // The dropout probability (probability of keeping an activation).
float p_dropout; float p_dropout;
...@@ -104,10 +104,43 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params { ...@@ -104,10 +104,43 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream); template<typename Kernel_params>
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream); struct Launch_params{
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream); Launch_params(cudaDeviceProp * props_,
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream); cudaStream_t stream_,
bool is_training_,
bool is_nl_)
: elts_per_thread(0)
, props(props_)
, stream(stream_)
, is_training(is_training_)
, is_nl(is_nl_) {
}
size_t elts_per_thread;
cudaDeviceProp * props;
cudaStream_t stream;
bool is_training;
Kernel_params params;
int num_full_heads;
int num_main_groups;
int heads_last_wave;
int main_steps;
int rest_steps;
bool is_nl;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_fmha_fp16_128_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_fp16_256_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_fp16_384_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_fp16_512_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream); void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream); void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
......
...@@ -210,9 +210,6 @@ struct Clear_accumulator<float, WARPS_K> { ...@@ -210,9 +210,6 @@ struct Clear_accumulator<float, WARPS_K> {
} }
}; };
////////////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Acc, typename A, typename B, int M, int N> template<typename Acc, typename A, typename B, int M, int N>
......
...@@ -60,7 +60,7 @@ struct Gmem_tile_qkv { ...@@ -60,7 +60,7 @@ struct Gmem_tile_qkv {
// Ctor. // Ctor.
template< typename Params, typename BInfo > template< typename Params, typename BInfo >
inline __device__ Gmem_tile_qkv(const Params &params, int qkv_offset, const BInfo &binfo, int tidx) inline __device__ Gmem_tile_qkv(const Params &params, const int qkv_offset, const BInfo &binfo, const int tidx)
: params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes) : params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes)
, actual_seqlen(binfo.actual_seqlen) , actual_seqlen(binfo.actual_seqlen)
, qkv_ptr_(reinterpret_cast<char *>(params.qkv_ptr)) { , qkv_ptr_(reinterpret_cast<char *>(params.qkv_ptr)) {
...@@ -125,6 +125,11 @@ struct Gmem_tile_qkv { ...@@ -125,6 +125,11 @@ struct Gmem_tile_qkv {
actual_seqlen -= ROWS; actual_seqlen -= ROWS;
} }
inline __device__ void move(int steps) {
qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps;
actual_seqlen -= ROWS * steps;
}
// The stride between rows for the QKV matrice. // The stride between rows for the QKV matrice.
int64_t params_qkv_stride_in_bytes_; int64_t params_qkv_stride_in_bytes_;
// The pointer. // The pointer.
...@@ -224,6 +229,11 @@ struct Gmem_tile_o { ...@@ -224,6 +229,11 @@ struct Gmem_tile_o {
o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_;
} }
inline __device__ void move(const int steps) {
row_ += ROWS * steps;
o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_ * steps;
}
// The stride between rows for the QKV matrice. // The stride between rows for the QKV matrice.
int64_t params_o_stride_in_bytes_; int64_t params_o_stride_in_bytes_;
// The pointer. // The pointer.
...@@ -270,13 +280,9 @@ struct Gmem_tile_mma_sd { ...@@ -270,13 +280,9 @@ struct Gmem_tile_mma_sd {
// Ctor. // Ctor.
template<typename Params> template<typename Params>
inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int tidx) inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int bidb, const int bidh, const int tidx)
: ptr_(static_cast<char *>(ptr)) { : ptr_(static_cast<char *>(ptr)) {
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The block index. // The block index.
size_t bidx = bidb * params.h + bidh; size_t bidx = bidb * params.h + bidh;
...@@ -300,6 +306,9 @@ struct Gmem_tile_mma_sd { ...@@ -300,6 +306,9 @@ struct Gmem_tile_mma_sd {
inline __device__ void move() { inline __device__ void move() {
ptr_ += LOOP_STRIDE_BYTES; ptr_ += LOOP_STRIDE_BYTES;
} }
inline __device__ void move(const int steps) {
ptr_ += LOOP_STRIDE_BYTES * steps;
}
// The pointer in global memory. // The pointer in global memory.
char *ptr_; char *ptr_;
...@@ -318,9 +327,9 @@ struct Gmem_tile_mma_s : public Base { ...@@ -318,9 +327,9 @@ struct Gmem_tile_mma_s : public Base {
using Type = typename Base::Type; using Type = typename Base::Type;
// Ctor. // Ctor.
template< typename Params > template< typename Params, typename Block_info >
inline __device__ Gmem_tile_mma_s(void *ptr, const Params &params, const int tidx) inline __device__ Gmem_tile_mma_s(const Params &params, const Block_info& binfo, const int tidx)
: Base(ptr, params, tidx) { : Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
} }
// Store to global memory. // Store to global memory.
...@@ -353,6 +362,25 @@ struct Gmem_tile_mma_s : public Base { ...@@ -353,6 +362,25 @@ struct Gmem_tile_mma_s : public Base {
} }
} }
// Store to global memory.
template<typename Mask, typename Fragment>
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 dst;
dst.x = frag[ni][mi].reg(0);
dst.y = frag[ni][mi].reg(2);
dst.z = frag[ni][mi].reg(1);
dst.w = frag[ni][mi].reg(3);
if( mask.any_valid(mi, ni) ) {
Base::store(dst, mi, ni);
}
}
}
}
// Load from global memory. // Load from global memory.
template<typename Mask> template<typename Mask>
inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) { inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {
...@@ -361,7 +389,7 @@ struct Gmem_tile_mma_s : public Base { ...@@ -361,7 +389,7 @@ struct Gmem_tile_mma_s : public Base {
#pragma unroll #pragma unroll
for( int ni = 0; ni < N; ni++ ) { for( int ni = 0; ni < N; ni++ ) {
regs[mi][ni] = make_uint4(0, 0, 0, 0); regs[mi][ni] = make_uint4(0, 0, 0, 0);
if( mask.is_valid(mi, ni, 0, 0) ) { if( mask.any_valid(mi, ni) ) {
Base::load(regs[mi][ni], mi, ni); Base::load(regs[mi][ni], mi, ni);
} }
} }
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x8u> template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u>
struct FMHA_kernel_traits { struct FMHA_kernel_traits {
// The CTA description for the 1st GEMM. // The CTA description for the 1st GEMM.
...@@ -38,7 +38,9 @@ struct FMHA_kernel_traits { ...@@ -38,7 +38,9 @@ struct FMHA_kernel_traits {
using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>; using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;
// Do we use one buffer for K and V. // Do we use one buffer for K and V.
enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u }; enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u };
// Do we keep K in registers.
enum { K_IN_REGS = (FLAGS & 0x10u) == 0u };
// The global memory tile to load Q. // The global memory tile to load Q.
using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>; using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
......