Unverified Commit 96850dfa authored by Jithun Nair, committed by GitHub

Merge pull request #80 from ROCmSoftwarePlatform/IFU-master-2022-07-29

IFU-master-2022-07-29
parents 87fc4125 cc5f83b5
---
name: Bug report
about: Create a report to help us improve apex
title: ''
labels: bug
assignees: ''
---
**Describe the Bug**
**Minimal Steps/Code to Reproduce the Bug**
<!--
Please list the *minimal* steps or provide a code snippet for us to be able to reproduce the bug.
A helpful guide on how to craft a minimal bug report: http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports.
-->
**Expected Behavior**
<!-- A clear and concise description of what you expected to happen. -->
**Environment**
<!-- OS, version of Python, CUDA, PyTorch; collect these via `python -m torch.utils.collect_env` -->
# Introduction
This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch.
Some of the code here will be included in upstream Pytorch eventually.
The intent of Apex is to make up-to-date utilities available to users as quickly as possible.
## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)
...@@ -98,30 +96,22 @@ amp.load_state_dict(checkpoint['amp'])
Note that we recommend restoring the model using the same `opt_level`. Also note that we recommend calling the `load_state_dict` methods after `amp.initialize`.
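As a reference point, a minimal restore following that recommendation might look like the sketch below (it assumes `model` and `optimizer` are already constructed and that the checkpoint was saved with `amp.state_dict()` under the `'amp'` key; the file name is a placeholder):
```python
# Sketch: restore a checkpoint saved together with amp state.
import torch
from apex import amp

# Use the same opt_level that was used when the checkpoint was saved.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

checkpoint = torch.load("amp_checkpoint.pt")  # placeholder path
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
amp.load_state_dict(checkpoint["amp"])
```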
# Installation
## Containers
NVIDIA PyTorch Containers are available on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch.
The containers come with all the custom extensions available at the moment.
See [the NGC documentation](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html) for details such as:
- how to pull a container
- how to run a pulled container
- release notes
## From Source
To install Apex from source, we recommend using the nightly Pytorch obtainable from https://github.com/pytorch/pytorch.
The latest stable release obtainable from https://pytorch.org should also work.
To use the latest Amp API, you may need to `pip uninstall apex` then reinstall Apex using the **Quick Start** commands below.
## On ROCm:
* Python 3.6
* Pytorch 1.5 or newer; the HIP extensions require 1.5 or newer.
* We recommend following the instructions from [ROCm-Pytorch](https://github.com/ROCmSoftwarePlatform/pytorch) to install pytorch on ROCm.
* Note: For pytorch versions < 1.8, building from source is no longer supported; please use the release package [ROCm-Apex v0.3](https://github.com/ROCmSoftwarePlatform/apex/releases/tag/v0.3).
# Quick Start
### Rocm
Apex on ROCm supports both Python-only and extension builds.
...@@ -145,29 +135,27 @@ pip install -v --install-option="--cpp_ext" --install-option="--cuda_ext" 'git+h
```
### Linux
For performance and full functionality, we recommend installing Apex with
CUDA and C++ extensions via
```bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
Apex also supports a Python-only build via
```bash
pip install -v --disable-pip-version-check --no-cache-dir ./
```
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
- Fused kernels required to use `apex.normalization.FusedLayerNorm` and `apex.normalization.FusedRMSNorm`.
- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.
`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.
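If you are unsure which build you ended up with, one quick check is whether the compiled extension modules import; `amp_C` is one of the modules produced by an extension build (a sketch, not an official diagnostic):
```python
# Sketch: detect whether apex was installed with the C++/CUDA extensions.
try:
    import amp_C  # compiled extension module; absent in a Python-only build
    print("Extension build: fused kernels are available.")
except ImportError:
    print("Python-only build: apex will fall back to slower pure-Python paths.")
```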
Pyprof support has been moved to its own [dedicated repository](https://github.com/NVIDIA/PyProf).
The codebase is deprecated in Apex and will be removed soon.
### [Experimental] Windows
`pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .` may work if you were able to build Pytorch from source
on your system. A Python-only build via `pip install -v --no-cache-dir .` is more likely to work.
If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.
import logging
import warnings
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import torch
...@@ -19,14 +20,10 @@ from . import fp16_utils
# load time) the error message is timely and visible.
from . import optimizers
from . import normalization
from . import pyprof
# Common utilities to run tests on ROCm.
from . import testing
from . import transformer
# Logging utilities for apex.transformer module
class RankInfoFormatter(logging.Formatter):
def format(self, record):
...@@ -37,6 +34,18 @@ class RankInfoFormatter(logging.Formatter):
_library_root_logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(RankInfoFormatter("%(asctime)s - PID:%(process)d - rank:%(rank_info)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s", "%y-%m-%d %H:%M:%S"))
_library_root_logger.addHandler(handler)
_library_root_logger.propagate = False
def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
cudnn_available = torch.backends.cudnn.is_available()
cudnn_version = torch.backends.cudnn.version() if cudnn_available else None
if not (cudnn_available and (cudnn_version >= required_cudnn_version)):
warnings.warn(
f"`{global_option}` depends on cuDNN {required_cudnn_version} or later, "
f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}"
)
return False
return True
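# Illustrative usage (a sketch, not part of apex/__init__.py): a fused extension
# can guard its import on a minimum cuDNN version, as apex.contrib's fused
# conv-bias-relu module does below with check_cudnn_version_and_warn(__name__, 8400):
#
#     from apex import check_cudnn_version_and_warn
#     if check_cudnn_version_and_warn("my_fused_module", 8400):  # hypothetical caller name
#         import my_fused_extension  # hypothetical compiled extension
#     else:
#         pass  # fall back to a non-fused code path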
from typing import Optional, Sequence
import torch
def _get_autocast_dtypes() -> Sequence[torch.dtype]:
if torch.cuda.is_bf16_supported():
return [torch.half, torch.bfloat16]
return [torch.half]
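# Illustrative sketch (not part of this file): a test can exercise an op under
# every autocast dtype reported above. Assumes a CUDA device.
#
#     for dtype in _get_autocast_dtypes():
#         with torch.cuda.amp.autocast(dtype=dtype):
#             ...  # run the op or module under test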
def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
if not torch.is_autocast_enabled():
return torch.float or dtype
...
from .bottleneck import Bottleneck, SpatialBottleneck
from .halo_exchangers import HaloExchangerNoComm, HaloExchangerAllGather, HaloExchangerSendRecv, HaloExchangerPeer
import torch
import torch.distributed as dist
from torch import nn
import nccl_p2p_cuda as inc
import peer_memory_cuda as pm
# Communication free halo exchanger.
# NB! This halo exchanger does not exchange halos with neighbors as it should, it merely swaps the inputs
# NB! This is only useful for performance testing.
# NB! Do not use for actual production runs
class HaloExchanger(object):
def __init__(self, ranks, rank_in_group):
self.stream1 = torch.cuda.Stream()
self.stream2 = torch.cuda.Stream()
self.stream3 = torch.cuda.Stream()
self.group_size = len(ranks)
self.ranks = ranks
self.rank_in_group = rank_in_group
self.wrap_around_left_rank_in_group = (rank_in_group + self.group_size - 1) % self.group_size
self.wrap_around_right_rank_in_group = (rank_in_group + 1) % self.group_size
self.left_rank = ranks[rank_in_group-1] if rank_in_group > 0 else -1
self.left_zero = True if rank_in_group == 0 else False
self.right_rank = ranks[rank_in_group+1] if rank_in_group < self.group_size - 1 else -1
self.right_zero = True if rank_in_group == self.group_size - 1 else False
class HaloExchangerNoComm(HaloExchanger):
def __init__(self, ranks, rank_in_group):
super(HaloExchangerNoComm, self).__init__(ranks, rank_in_group)
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
if left_input_halo is None:
return right_output_halo, left_output_halo
else:
left_input_halo.copy_(right_output_halo)
right_input_halo.copy_(left_output_halo)
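# Illustrative sketch (not part of this file): with a single-rank "group",
# HaloExchangerNoComm simply hands each output halo back as the opposite input
# halo, which is why it is only useful for performance testing. Assumes CUDA.
#
#     ex = HaloExchangerNoComm(ranks=[0], rank_in_group=0)
#     left_out = torch.randn(1, 2, 32, 64, device="cuda")
#     right_out = torch.randn(1, 2, 32, 64, device="cuda")
#     left_in, right_in = ex.left_right_halo_exchange(left_out, right_out)
#     # left_in is right_out and right_in is left_out; no communication happens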
class HaloExchangerAllGather(HaloExchanger):
def __init__(self, ranks, rank_in_group, comm):
super(HaloExchangerAllGather, self).__init__(ranks, rank_in_group)
# self.comm must be NCCL process_group created with torch.distributed.new_group(ranks=ranks)
self.comm = comm
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
N,Hh,W,C = list(left_output_halo.shape)
send_halos = torch.empty((N,2*Hh,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
send_halos[:,:Hh,:,:].copy_(left_output_halo)
send_halos[:,Hh:,:,:].copy_(right_output_halo)
all_halos = torch.empty((N,2*Hh*self.group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.group_size)]
torch.distributed.all_gather(all_halos,send_halos,group=self.comm,no_copy=True)
ag_left_input_halo = all_halos[self.wrap_around_left_rank_in_group][:,Hh:,:,:]
ag_right_input_halo = all_halos[self.wrap_around_right_rank_in_group][:,:Hh,:,:]
if left_input_halo is None:
if self.left_zero:
ag_left_input_halo.zero_()
if self.right_zero:
ag_right_input_halo.zero_()
return ag_left_input_halo, ag_right_input_halo
else:
if self.left_zero:
left_input_halo.zero_()
else:
left_input_halo.copy_(ag_left_input_halo)
if self.right_zero:
right_input_halo.zero_()
else:
right_input_halo.copy_(ag_right_input_halo)
class HaloExchangerSendRecv(HaloExchanger):
def __init__(self, ranks, rank_in_group):
super(HaloExchangerSendRecv, self).__init__(ranks, rank_in_group)
nccl_id = inc.get_unique_nccl_id(1).cuda()
torch.distributed.broadcast(nccl_id, 0)
nccl_id = nccl_id.cpu()
print("%d :: nccl_id = %s" % (torch.distributed.get_rank(), str(nccl_id)))
# Create another global nccl communicator in addition to the one created by torch.distributed.init_process_group("nccl")
# This is unavoidable because the underlying NCCL communicator torch.distributed creates is a protected variable, hence
# it cannot be accessed from another class.
# TODO: Figure out a way to avoid creating a second global communicator
assert(torch.distributed.get_rank() == self.ranks[self.rank_in_group]), "ranks[%d](%d) != torch.distributed.get_rank()(%d)" % (self.rank_in_group, self.ranks[self.rank_in_group], torch.distributed.get_rank())
self.handle = inc.init_nccl_comm(nccl_id, torch.distributed.get_rank(), torch.distributed.get_world_size())
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
if left_input_halo is None:
left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_rank, self.right_rank , left_output_halo, right_output_halo)
return left_input_halo, right_input_halo
else:
inc.left_right_halo_exchange_inplace(self.handle, self.left_rank, self.right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo)
class HaloExchangerPeer(HaloExchanger):
def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1):
super(HaloExchangerPeer, self).__init__(ranks, rank_in_group)
self.diagnostics = False
self.explicit_nhwc = explicit_nhwc
self.numSM = numSM
self.peer_pool = peer_pool
self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
self.signals[self.rank_in_group].zero_()
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
inplace = False if left_input_halo is None and right_input_halo is None else True
if not inplace:
left_input_halo = torch.empty_like(right_output_halo)
right_input_halo = torch.empty_like(left_output_halo)
channels_last = left_output_halo.is_contiguous(memory_format=torch.channels_last) and not self.explicit_nhwc
left_tx = self.peer_pool.allocate_peer_tensors(list(left_output_halo.shape), left_output_halo.dtype, channels_last, True)
right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
pm.push_pull_halos_1d(
self.diagnostics, self.explicit_nhwc, self.numSM,
left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
self.signals[self.wrap_around_left_rank_in_group], self.signals[self.wrap_around_right_rank_in_group], self.signals[self.rank_in_group]
)
# TODO: Add to push_pull_halos_1d kernel
if self.left_zero:
left_input_halo.zero_()
if self.right_zero:
right_input_halo.zero_()
if not inplace:
return left_input_halo, right_input_halo
# Class that combines input volume with halos from neighbors (1d).
class HaloPadder:
def __init__(self, halo_ex):
self.halo_ex = halo_ex
self.stream1 = torch.cuda.Stream()
self.stream2 = torch.cuda.Stream()
def __call__(self, y, half_halo, explicit_nhwc, H_split):
channels_last = not explicit_nhwc and y.is_contiguous(memory_format=torch.channels_last)
if explicit_nhwc:
N,H,W,C = list(y.shape)
if H_split:
padded_shape = [N,H+2*half_halo,W,C]
ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
yleft = ypad[:,:half_halo,:,:]
ymid = ypad[:,half_halo:H+half_halo,:,:]
yright = ypad[:,H+half_halo:H+2*half_halo,:,:]
oleft = y[:,:half_halo,:,:]
oright = y[:,H-half_halo:,:,:]
else:
padded_shape = [N,H,W+2*half_halo,C]
ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
yleft = ypad[:,:,:half_halo,:]
ymid = ypad[:,:,half_halo:W+half_halo,:]
yright = ypad[:,:,W+half_halo:W+2*half_halo,:]
oleft = y[:,:,:half_halo,:]
oright = y[:,:,W-half_halo:,:]
else:
N,C,H,W = list(y.shape)
if H_split:
padded_shape = [N,C,H+2*half_halo,W]
ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
yleft = ypad[:,:,:half_halo,:]
ymid = ypad[:,:,half_halo:H+half_halo,:]
yright = ypad[:,:,H+half_halo:H+2*half_halo,:]
oleft = y[:,:,:half_halo,:]
oright = y[:,:,H-half_halo:,:]
else:
padded_shape = [N,C,H,W+2*half_halo]
ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
yleft = ypad[:,:,:,:half_halo]
ymid = ypad[:,:,:,half_halo:W+half_halo]
yright = ypad[:,:,:,W+half_halo:W+2*half_halo]
oleft = y[:,:,:,:half_halo]
oright = y[:,:,:,W-half_halo:]
with torch.cuda.stream(self.stream1):
self.halo_ex(oleft, oright, yleft, yright)
with torch.cuda.stream(self.stream2):
ymid.copy_(y)
return ypad
def wait(self):
current_stream = torch.cuda.current_stream()
current_stream.wait_stream(self.stream1)
current_stream.wait_stream(self.stream2)
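# Illustrative sketch (not part of this file): pad an NCHW activation with halos
# from neighboring ranks along H, then synchronize before using the padded tensor.
# `halo_ex` is any of the halo exchangers defined above.
#
#     padder = HaloPadder(halo_ex)
#     ypad = padder(y, half_halo=1, explicit_nhwc=False, H_split=True)
#     padder.wait()   # joins the two internal side streams with the current stream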
from .clip_grad import clip_grad_norm_
import torch
from torch._six import inf
from typing import Union, Iterable
_kernel_import_succeeded = False
try:
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
_kernel_import_succeeded = True
except:
_kernel_import_succeeded = False
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
def clip_grad_norm_(
parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
error_if_nonfinite: bool = False) -> torch.Tensor:
r"""Clips gradient norm of an iterable of parameters.
The norm is computed over all gradients together, as if they were
concatenated into a single vector. Gradients are modified in-place.
This is identical to torch.nn.utils.clip_grad_norm_, except it
uses a fused CUDA kernel when computing the 2-norm of GPU tensors
in float32 and float16.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
error_if_nonfinite (bool): if True, an error is thrown if the total
norm of the gradients from :attr:`parameters` is ``nan``,
``inf``, or ``-inf``. Default: False (will switch to True in the future)
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]
max_norm = float(max_norm)
norm_type = float(norm_type)
# Trivial case
if len(parameters) == 0:
return torch.tensor(0.)
# Fallback implementation
if not (_kernel_import_succeeded
and norm_type == 2.0
and any(p.is_cuda for p in parameters)):
return torch.nn.utils.clip_grad_norm_(
parameters,
max_norm,
norm_type=norm_type,
error_if_nonfinite = error_if_nonfinite,
)
# Find fp32 and fp16 gradients on GPU
device = next(p.device for p in parameters if p.is_cuda)
grads_fp32, grads_fp16, grads_misc = [], [], []
for p in parameters:
grad = p.grad.detach()
if p.dtype == torch.float32 and p.device == device:
grads_fp32.append(grad)
elif p.dtype == torch.float16 and p.device == device:
grads_fp16.append(grad)
else:
grads_misc.append(grad)
# Compute gradient L2 norms
norms = []
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=device)
if grads_fp32:
norms.append(
multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads_fp32],
False,
)[0]
)
if grads_fp16:
norms.append(
multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads_fp16],
False,
)[0],
)
for g in grads_misc:
norms.append(torch.linalg.norm(g).unsqueeze(0).to(device))
total_norm = torch.linalg.norm(torch.cat(norms))
# Check for non-finite values
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
f'The total norm of order {norm_type} for gradients from '
'`parameters` is non-finite, so it cannot be clipped. To disable '
'this error and scale the gradients by the non-finite norm anyway, '
'set `error_if_nonfinite=False`')
# Scale gradients
clip_coef = max_norm / (total_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
if grads_fp32:
multi_tensor_applier(
amp_C.multi_tensor_scale,
dummy_overflow_buf,
[grads_fp32, grads_fp32],
clip_coef_clamped,
)
if grads_fp16:
multi_tensor_applier(
amp_C.multi_tensor_scale,
dummy_overflow_buf,
[grads_fp16, grads_fp16],
clip_coef_clamped,
)
for g in grads_misc:
g.mul_(clip_coef_clamped.to(g.device))
return total_norm
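# Illustrative usage sketch (not part of this file): drop-in replacement for
# torch.nn.utils.clip_grad_norm_ in a training step (import path assumed from
# the surrounding package layout in this diff).
#
#     from apex.contrib.clip_grad import clip_grad_norm_
#     loss.backward()
#     total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)
#     optimizer.step()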
from .conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU
import pdb
import torch
from torch.autograd import gradcheck
from apex import check_cudnn_version_and_warn
import fused_conv_bias_relu
check_cudnn_version_and_warn(__name__, 8400)
class ConvBiasReLU_(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
def forward(ctx, x, weight, bias, padding, stride):
outputs = fused_conv_bias_relu.forward([x, weight, bias], padding, stride)
ctx.save_for_backward(x, weight, outputs[0])
ctx.padding = padding
ctx.stride = stride
return outputs[0]
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, grad_output):
bwd_args = [*ctx.saved_tensors, grad_output]
padding = ctx.padding
stride = ctx.stride
grads = fused_conv_bias_relu.backward(bwd_args, padding, stride)
return grads[0], grads[1], grads[2], None, None
class ConvBiasMaskReLU_(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
def forward(ctx, x, weight, bias, mask, padding, stride):
outputs = fused_conv_bias_relu.forward_mask([x, weight, bias, mask], padding, stride)
ctx.save_for_backward(x, weight, outputs[0])
ctx.padding = padding
ctx.stride = stride
return outputs[0]
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, grad_output):
bwd_args = [*ctx.saved_tensors, grad_output]
padding = ctx.padding
stride = ctx.stride
grads = fused_conv_bias_relu.backward(bwd_args, padding, stride)
return grads[0], grads[1], grads[2], None, None, None
class ConvBias_(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
def forward(ctx, x, weight, bias, padding, stride):
outputs = fused_conv_bias_relu.forward_no_relu([x, weight, bias], padding, stride)
ctx.save_for_backward(x, weight)
ctx.padding = padding
ctx.stride = stride
return outputs[0]
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, grad_output):
bwd_args = [*ctx.saved_tensors, grad_output]
padding = ctx.padding
stride = ctx.stride
grads = fused_conv_bias_relu.backward_no_relu(bwd_args, padding, stride)
return grads[0], grads[1], grads[2], None, None
ConvBiasReLU = ConvBiasReLU_.apply
ConvBiasMaskReLU = ConvBiasMaskReLU_.apply
ConvBias = ConvBias_.apply
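# Illustrative usage sketch (not part of this file): the functional entry points
# take (input, weight, bias, padding, stride); shapes, dtype, and memory format
# below are assumptions, not requirements documented here.
#
#     x = torch.randn(8, 64, 56, 56, device="cuda", dtype=torch.half, requires_grad=True)
#     w = torch.randn(128, 64, 3, 3, device="cuda", dtype=torch.half, requires_grad=True)
#     b = torch.randn(128, device="cuda", dtype=torch.half, requires_grad=True)
#     y = ConvBiasReLU(x, w, b, 1, 1)   # padding=1, stride=1
#     y.sum().backward()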
Subproject commit fa611998a360cbabaa2dcc7c9859748144114fc0
...@@ -72,7 +72,7 @@ void set_params(Fused_multihead_attention_fprop_params &params,
constexpr float scale_softmax = 1.f;
constexpr float scale_bmm2 = 1.f;
set_alpha(params.scale_bmm1, scale_bmm1, data_type);
set_alpha(params.scale_softmax, scale_softmax, acc_type);
set_alpha(params.scale_bmm2, scale_bmm2, data_type);
...@@ -89,9 +89,15 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
const float p_dropout,
const int max_seq_len,
const bool is_training,
const bool is_nl,
const bool zero_tensors,
c10::optional<at::Generator> gen_) {
auto dprops = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
auto stream = at::cuda::getCurrentCUDAStream().stream();
Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_training, is_nl);
int seq_len = 512;
auto launch = &run_fmha_fp16_512_64_sm80;
if( max_seq_len <= 128 ) {
...@@ -110,18 +116,6 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
TORCH_CHECK(false);
}
constexpr int warps_m = 1;
constexpr int warps_n = 4; // this leads to an upper bound
const int mmas_m = seq_len / 16 / warps_m;
const int mmas_n = seq_len / 16 / warps_n;
const int elts_per_thread = 8 * mmas_m * mmas_n;
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
...@@ -147,12 +141,16 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
if( zero_tensors ) {
ctx.zero_();
s.zero_();
}
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
Fused_multihead_attention_fprop_params params;
set_params(launch_params.params,
batch_size,
seq_len,
num_heads,
...@@ -163,29 +161,32 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
s.data_ptr(),
p_dropout);
launch(launch_params, /*configure=*/ true);
// number of times random will be generated per thread, to offset philox counter in the random
// state
int64_t counter_offset = launch_params.elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;
if( is_training ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
}
launch(launch_params, /*configure=*/ false);
return { ctx, s };
}
std::vector<at::Tensor> std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &cu_seqlens, // b+1
const float p_dropout, // probability to drop
const int max_seq_len, // max sequence length to choose the kernel
const bool zero_tensors
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
...@@ -235,6 +236,10 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
auto dqkv = torch::empty_like(qkv);
if( zero_tensors ) {
dqkv.zero_();
}
Fused_multihead_attention_fprop_params params;
set_params(params,
...@@ -259,92 +264,13 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
return { dqkv, softmax };
}
std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const float p_dropout,
const int max_seq_len,
const bool is_training,
c10::optional<at::Generator> gen_) {
int seq_len = 512;
auto launch = &run_fmha_fp16_512_64_sm80_nl;
TORCH_CHECK(max_seq_len == seq_len);
constexpr int warps_m = 1;
constexpr int warps_n = 4; // this leads to an upper bound
const int mmas_m = seq_len / 16 / warps_m;
const int mmas_n = seq_len / 16 / warps_n;
// static_assert( mmas_m == 32 );
// static_assert( mmas_n == 4 );
const int elts_per_thread = 8 * mmas_m * mmas_n;
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
auto opts = qkv.options();
auto ctx = torch::empty({ total, num_heads, head_size }, opts);
auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
Fused_multihead_attention_fprop_params params;
set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
ctx.data_ptr(),
s.data_ptr(),
p_dropout);
// number of times random will be generated per thread, to offset philox counter in the random
// state
int64_t counter_offset = elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;
if( is_training ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
}
int num_chunks = 3;
if(batch_size == 3) {
num_chunks = 2;
}
launch(params, is_training, num_chunks, stream);
return { ctx, s };
}
std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num_heads, x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &cu_seqlens, // b+1
const float p_dropout, // probability to drop
const int max_seq_len, // max sequence length to choose the kernel
const bool zero_tensors
) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
...@@ -378,6 +304,10 @@ std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num
auto dqkv = torch::empty_like(qkv);
if( zero_tensors ) {
dqkv.zero_();
}
int num_chunks = 2;
if( batch_size == 1 ) {
num_chunks = 4;
...@@ -427,6 +357,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "Fused Multi-head Self-attention for BERT";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("fwd_nl", &mha_fwd_nl, "Forward pass (small-batch)");
m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
}
...@@ -30,7 +30,7 @@
#include <cuda.h>
#include <vector>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
...@@ -50,7 +50,7 @@ constexpr int D_DIM = 3;
struct Qkv_params {
// The QKV matrices.
void * __restrict__ qkv_ptr;
// The stride between rows of the Q, K and V matrices.
size_t qkv_stride_in_bytes;
...@@ -64,19 +64,19 @@ struct Qkv_params {
struct Fused_multihead_attention_fprop_params : public Qkv_params {
// The dQKV matrices.
void * __restrict__ dqkv_ptr;
// Temporary for dKV.
void * __restrict__ dkv_ptr;
// The O matrix (output).
void * __restrict__ o_ptr;
// The stride between rows of O.
int64_t o_stride_in_bytes;
// The pointer to the S matrix, overwritten by the dP matrix (bwd).
void * __restrict__ s_ptr;
// The stride between rows of the S matrix.
int64_t s_stride_in_bytes;
...@@ -87,7 +87,7 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
uint32_t scale_bmm1, scale_softmax, scale_bmm2;
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens;
// The dropout probability (probability of keeping an activation).
float p_dropout;
...@@ -104,10 +104,43 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_params>
struct Launch_params{
Launch_params(cudaDeviceProp * props_,
cudaStream_t stream_,
bool is_training_,
bool is_nl_)
: elts_per_thread(0)
, props(props_)
, stream(stream_)
, is_training(is_training_)
, is_nl(is_nl_) {
}
size_t elts_per_thread;
cudaDeviceProp * props;
cudaStream_t stream;
bool is_training;
Kernel_params params;
int num_full_heads;
int num_main_groups;
int heads_last_wave;
int main_steps;
int rest_steps;
bool is_nl;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_fmha_fp16_128_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_fp16_256_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_fp16_384_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_fp16_512_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
...
...@@ -210,9 +210,6 @@ struct Clear_accumulator<float, WARPS_K> {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Acc, typename A, typename B, int M, int N>
...
...@@ -60,7 +60,7 @@ struct Gmem_tile_qkv {
// Ctor.
template< typename Params, typename BInfo >
inline __device__ Gmem_tile_qkv(const Params &params, const int qkv_offset, const BInfo &binfo, const int tidx)
: params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes)
, actual_seqlen(binfo.actual_seqlen)
, qkv_ptr_(reinterpret_cast<char *>(params.qkv_ptr)) {
...@@ -125,6 +125,11 @@ struct Gmem_tile_qkv {
actual_seqlen -= ROWS;
}
inline __device__ void move(int steps) {
qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps;
actual_seqlen -= ROWS * steps;
}
// The stride between rows for the QKV matrice.
int64_t params_qkv_stride_in_bytes_;
// The pointer.
...@@ -224,6 +229,11 @@ struct Gmem_tile_o {
o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_;
}
inline __device__ void move(const int steps) {
row_ += ROWS * steps;
o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_ * steps;
}
// The stride between rows for the QKV matrice.
int64_t params_o_stride_in_bytes_;
// The pointer.
...@@ -270,13 +280,9 @@ struct Gmem_tile_mma_sd {
// Ctor.
template<typename Params>
inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int bidb, const int bidh, const int tidx)
: ptr_(static_cast<char *>(ptr)) {
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The block index.
size_t bidx = bidb * params.h + bidh;
...@@ -300,6 +306,9 @@ struct Gmem_tile_mma_sd {
inline __device__ void move() {
ptr_ += LOOP_STRIDE_BYTES;
}
inline __device__ void move(const int steps) {
ptr_ += LOOP_STRIDE_BYTES * steps;
}
// The pointer in global memory.
char *ptr_;
...@@ -318,9 +327,9 @@ struct Gmem_tile_mma_s : public Base {
using Type = typename Base::Type;
// Ctor.
template< typename Params, typename Block_info >
inline __device__ Gmem_tile_mma_s(const Params &params, const Block_info& binfo, const int tidx)
: Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
}
// Store to global memory.
...@@ -353,6 +362,25 @@ struct Gmem_tile_mma_s : public Base {
}
}
// Store to global memory.
template<typename Mask, typename Fragment>
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 dst;
dst.x = frag[ni][mi].reg(0);
dst.y = frag[ni][mi].reg(2);
dst.z = frag[ni][mi].reg(1);
dst.w = frag[ni][mi].reg(3);
if( mask.any_valid(mi, ni) ) {
Base::store(dst, mi, ni);
}
}
}
}
// Load from global memory.
template<typename Mask>
inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {
...@@ -361,7 +389,7 @@ struct Gmem_tile_mma_s : public Base {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
regs[mi][ni] = make_uint4(0, 0, 0, 0);
if( mask.any_valid(mi, ni) ) {
Base::load(regs[mi][ni], mi, ni);
}
}
...
...@@ -29,7 +29,7 @@
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u>
struct FMHA_kernel_traits {
// The CTA description for the 1st GEMM.
...@@ -38,7 +38,9 @@ struct FMHA_kernel_traits {
using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;
// Do we use one buffer for K and V.
enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u };
// Do we keep K in registers.
enum { K_IN_REGS = (FLAGS & 0x10u) == 0u };
// The global memory tile to load Q.
using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
...