Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
f79993d9
Commit
f79993d9
authored
Oct 15, 2021
by
hubertlu-tw
Browse files
Merge remote-tracking branch 'upstream/master' into IFU-master-2021-10-15
parents
297ab210
1d5f7e55
Changes
117
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
7748 additions
and
0 deletions
+7748
-0
.gitignore
.gitignore
+141
-0
.gitmodules
.gitmodules
+3
-0
apex/__init__.py
apex/__init__.py
+1
-0
apex/_autocast_utils.py
apex/_autocast_utils.py
+8
-0
apex/contrib/bottleneck/__init__.py
apex/contrib/bottleneck/__init__.py
+1
-0
apex/contrib/bottleneck/bottleneck.py
apex/contrib/bottleneck/bottleneck.py
+512
-0
apex/contrib/bottleneck/bottleneck_module_test.py
apex/contrib/bottleneck/bottleneck_module_test.py
+272
-0
apex/contrib/bottleneck/test.py
apex/contrib/bottleneck/test.py
+71
-0
apex/contrib/csrc/bottleneck/bottleneck.cpp
apex/contrib/csrc/bottleneck/bottleneck.cpp
+2486
-0
apex/contrib/csrc/cudnn-frontend
apex/contrib/csrc/cudnn-frontend
+1
-0
apex/contrib/csrc/fmha/fmha_api.cpp
apex/contrib/csrc/fmha/fmha_api.cpp
+432
-0
apex/contrib/csrc/fmha/src/fmha.h
apex/contrib/csrc/fmha/src/fmha.h
+125
-0
apex/contrib/csrc/fmha/src/fmha/gemm.h
apex/contrib/csrc/fmha/src/fmha/gemm.h
+317
-0
apex/contrib/csrc/fmha/src/fmha/gmem_tile.h
apex/contrib/csrc/fmha/src/fmha/gmem_tile.h
+428
-0
apex/contrib/csrc/fmha/src/fmha/kernel_traits.h
apex/contrib/csrc/fmha/src/fmha/kernel_traits.h
+95
-0
apex/contrib/csrc/fmha/src/fmha/mask.h
apex/contrib/csrc/fmha/src/fmha/mask.h
+76
-0
apex/contrib/csrc/fmha/src/fmha/smem_tile.h
apex/contrib/csrc/fmha/src/fmha/smem_tile.h
+1288
-0
apex/contrib/csrc/fmha/src/fmha/softmax.h
apex/contrib/csrc/fmha/src/fmha/softmax.h
+478
-0
apex/contrib/csrc/fmha/src/fmha/utils.h
apex/contrib/csrc/fmha/src/fmha/utils.h
+953
-0
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu
...ntrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu
+60
-0
No files found.
.gitignore
View file @
f79993d9
...
...
@@ -4,3 +4,144 @@ build
docs/build
*~
__pycache__
.vscode
# Copied from https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
.gitmodules
View file @
f79993d9
...
...
@@ -2,3 +2,6 @@
path = apex/contrib/csrc/multihead_attn/cutlass
url = https://github.com/NVIDIA/cutlass.git
branch = v1.2.0
[submodule "apex/contrib/csrc/cudnn-frontend"]
path = apex/contrib/csrc/cudnn-frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
apex/__init__.py
View file @
f79993d9
...
...
@@ -21,3 +21,4 @@ from . import pyprof
#common utilties to run tests on ROCm.
from
.
import
testing
from
.
import
transformer
apex/_autocast_utils.py
0 → 100644
View file @
f79993d9
import
torch
def
_cast_if_autocast_enabled
(
*
args
):
if
not
torch
.
is_autocast_enabled
():
return
args
else
:
return
torch
.
cuda
.
amp
.
autocast_mode
.
_cast
(
args
,
torch
.
get_autocast_gpu_dtype
())
apex/contrib/bottleneck/__init__.py
0 → 100644
View file @
f79993d9
from
.bottleneck
import
Bottleneck
,
SpatialBottleneck
apex/contrib/bottleneck/bottleneck.py
0 → 100644
View file @
f79993d9
import
torch
import
torch.distributed
as
dist
from
torch
import
nn
import
fast_bottleneck
def
kaiming_uniform_
(
tensor
,
a
=
0
,
mode
=
'fan_in'
,
nonlinearity
=
'leaky_relu'
):
weight_tensor_nchw
=
tensor
nn
.
init
.
kaiming_uniform_
(
weight_tensor_nchw
,
a
=
a
,
mode
=
mode
,
nonlinearity
=
nonlinearity
)
class
FrozenBatchNorm2d
(
torch
.
nn
.
Module
):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed
"""
def
__init__
(
self
,
n
):
super
(
FrozenBatchNorm2d
,
self
).
__init__
()
self
.
register_buffer
(
"weight"
,
torch
.
ones
(
n
))
self
.
register_buffer
(
"bias"
,
torch
.
zeros
(
n
))
self
.
register_buffer
(
"running_mean"
,
torch
.
zeros
(
n
))
self
.
register_buffer
(
"running_var"
,
torch
.
ones
(
n
))
def
get_scale_bias
(
self
,
nhwc
=
False
):
scale
=
self
.
weight
*
self
.
running_var
.
rsqrt
()
bias
=
self
.
bias
-
self
.
running_mean
*
scale
if
nhwc
:
scale
=
scale
.
reshape
(
1
,
1
,
1
,
-
1
)
bias
=
bias
.
reshape
(
1
,
1
,
1
,
-
1
)
else
:
scale
=
scale
.
reshape
(
1
,
-
1
,
1
,
1
)
bias
=
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
return
scale
,
bias
def
forward
(
self
,
x
):
scale
,
bias
=
self
.
get_scale_bias
()
return
x
*
scale
+
bias
@
torch
.
jit
.
script
def
drelu_dscale1
(
grad_o
,
output
,
scale1
):
relu_mask
=
(
output
>
0
).
half
()
dx_relu
=
relu_mask
*
grad_o
g1
=
dx_relu
*
scale1
return
g1
,
dx_relu
@
torch
.
jit
.
script
def
drelu_dscale2
(
grad_o
,
output
,
scale1
,
scale2
):
relu_mask
=
(
output
>
0
).
half
()
dx_relu
=
relu_mask
*
grad_o
g1
=
dx_relu
*
scale1
g2
=
dx_relu
*
scale2
return
g1
,
g2
class
BottleneckFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
nhwc
,
stride_1x1
,
scale
,
bias
,
x
,
*
conv
):
# TODO: clean up order of tensors
args
=
[
x
,
*
conv
[
0
:
3
],
*
scale
[
0
:
3
],
*
bias
[
0
:
3
]]
ctx
.
downsample
=
len
(
conv
)
>
3
if
ctx
.
downsample
:
args
.
append
(
conv
[
3
])
args
.
append
(
scale
[
3
])
args
.
append
(
bias
[
3
])
# weight buffers are always in nhwc while shape can be nhwc or channels_last
# here we pass in flag and let c++ handle it
# alternatively, we can put all sizes into a fixed format and pass it in
outputs
=
fast_bottleneck
.
forward
(
nhwc
,
stride_1x1
,
args
)
ctx
.
save_for_backward
(
*
(
args
+
outputs
))
# save relu outputs for drelu
ctx
.
nhwc
=
nhwc
ctx
.
stride_1x1
=
stride_1x1
return
outputs
[
2
]
# backward relu is not exposed, MUL with mask used now
# only support dgrad
@
staticmethod
def
backward
(
ctx
,
grad_o
):
outputs
=
ctx
.
saved_tensors
[
-
3
:]
if
ctx
.
downsample
:
grad_conv3
,
grad_conv4
=
drelu_dscale2
(
grad_o
,
outputs
[
2
],
ctx
.
saved_tensors
[
6
],
ctx
.
saved_tensors
[
11
])
else
:
grad_conv3
,
grad_conv4
=
drelu_dscale1
(
grad_o
,
outputs
[
2
],
ctx
.
saved_tensors
[
6
])
# create input vector for backward
t_list
=
[
*
ctx
.
saved_tensors
[
0
:
10
]]
t_list
.
append
(
grad_conv3
)
t_list
.
append
(
grad_conv4
)
# outputs used for wgrad and generating drelu mask
t_list
.
append
(
outputs
[
0
])
t_list
.
append
(
outputs
[
1
])
# in case there is downsample
if
ctx
.
downsample
:
t_list
.
append
(
ctx
.
saved_tensors
[
10
])
grads
=
fast_bottleneck
.
backward
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
)
return
(
None
,
None
,
None
,
None
,
*
grads
)
bottleneck_function
=
BottleneckFunction
.
apply
def
conv3x3
(
in_planes
,
out_planes
,
stride
=
1
,
groups
=
1
,
dilation
=
1
):
"""3x3 convolution with padding"""
return
nn
.
Conv2d
(
in_planes
,
out_planes
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
dilation
,
groups
=
groups
,
bias
=
False
,
dilation
=
dilation
)
def
conv1x1
(
in_planes
,
out_planes
,
stride
=
1
):
"""1x1 convolution"""
return
nn
.
Conv2d
(
in_planes
,
out_planes
,
kernel_size
=
1
,
stride
=
stride
,
bias
=
False
)
class
Bottleneck
(
torch
.
nn
.
Module
):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
# here we put it at 1x1
def
__init__
(
self
,
in_channels
,
bottleneck_channels
,
out_channels
,
stride
=
1
,
groups
=
1
,
dilation
=
1
,
norm_func
=
None
,
use_cudnn
=
False
,
explicit_nhwc
=
False
):
super
(
Bottleneck
,
self
).
__init__
()
if
groups
!=
1
:
raise
RuntimeError
(
'Only support groups == 1'
)
if
dilation
!=
1
:
raise
RuntimeError
(
'Only support dilation == 1'
)
if
norm_func
==
None
:
norm_func
=
FrozenBatchNorm2d
else
:
raise
RuntimeError
(
'Only support frozen BN now.'
)
if
stride
!=
1
or
in_channels
!=
out_channels
:
self
.
downsample
=
nn
.
Sequential
(
conv1x1
(
in_channels
,
out_channels
,
stride
),
norm_func
(
out_channels
),
)
else
:
self
.
downsample
=
None
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self
.
conv1
=
conv1x1
(
in_channels
,
bottleneck_channels
,
stride
)
self
.
conv2
=
conv3x3
(
bottleneck_channels
,
bottleneck_channels
)
self
.
conv3
=
conv1x1
(
bottleneck_channels
,
out_channels
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
stride
=
stride
self
.
bn1
=
norm_func
(
bottleneck_channels
)
self
.
bn2
=
norm_func
(
bottleneck_channels
)
self
.
bn3
=
norm_func
(
out_channels
)
self
.
use_cudnn
=
use_cudnn
# setup conv weights
self
.
w_conv
=
[
self
.
conv1
.
weight
,
self
.
conv2
.
weight
,
self
.
conv3
.
weight
]
if
self
.
downsample
is
not
None
:
self
.
w_conv
.
append
(
self
.
downsample
[
0
].
weight
)
# init weight in nchw format before possible transpose
for
w
in
self
.
w_conv
:
kaiming_uniform_
(
w
,
a
=
1
)
# TODO: prevent unsupported case usage
# support cases
# native cudnn
# normal yes no
# channel_last yes yes
# explicit_nhwc no yes
self
.
explicit_nhwc
=
explicit_nhwc
if
self
.
explicit_nhwc
:
for
p
in
self
.
parameters
():
with
torch
.
no_grad
():
p
.
data
=
p
.
data
.
permute
(
0
,
2
,
3
,
1
).
contiguous
()
return
def
forward
(
self
,
x
):
if
self
.
use_cudnn
:
# calculate scale/bias from registered buffers
# TODO: make this better
s1
,
b1
=
self
.
bn1
.
get_scale_bias
(
self
.
explicit_nhwc
)
s2
,
b2
=
self
.
bn2
.
get_scale_bias
(
self
.
explicit_nhwc
)
s3
,
b3
=
self
.
bn3
.
get_scale_bias
(
self
.
explicit_nhwc
)
w_scale
=
[
s1
,
s2
,
s3
]
w_bias
=
[
b1
,
b2
,
b3
]
if
self
.
downsample
is
not
None
:
s4
,
b4
=
self
.
downsample
[
1
].
get_scale_bias
(
self
.
explicit_nhwc
)
w_scale
.
append
(
s4
)
w_bias
.
append
(
b4
)
out
=
bottleneck_function
(
self
.
explicit_nhwc
,
self
.
stride
,
w_scale
,
w_bias
,
x
,
*
self
.
w_conv
)
return
out
if
self
.
explicit_nhwc
:
raise
RuntimeError
(
'explicit nhwc with native ops is not supported.'
)
# fallback to native ops
identity
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv3
(
out
)
out
=
self
.
bn3
(
out
)
if
self
.
downsample
is
not
None
:
identity
=
self
.
downsample
(
x
)
out
+=
identity
out
=
self
.
relu
(
out
)
return
out
class
SpatialBottleneckFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
spatial_group_size
,
local_rank
,
comm
,
stream1
,
nhwc
,
stride_1x1
,
scale
,
bias
,
x
,
*
conv
):
# TODO: clean up order of tensors
args
=
[
x
,
*
conv
[
0
:
3
],
*
scale
[
0
:
3
],
*
bias
[
0
:
3
]]
ctx
.
downsample
=
len
(
conv
)
>
3
if
ctx
.
downsample
:
args
.
append
(
conv
[
3
])
args
.
append
(
scale
[
3
])
args
.
append
(
bias
[
3
])
# weight buffers are always in nhwc while shape can be nhwc or channels_last
# here we pass in flag and let c++ handle it
# alternatively, we can put all sizes into a fixed format and pass it in
outputs
=
fast_bottleneck
.
forward_init
(
nhwc
,
stride_1x1
,
args
)
fast_bottleneck
.
forward_out1
(
nhwc
,
stride_1x1
,
args
,
outputs
)
fast_bottleneck
.
forward_out2
(
nhwc
,
stride_1x1
,
args
,
outputs
)
# do halo exchange for outputs[0] (out1)
# compute halo cells for outputs[1]
if
spatial_group_size
>
1
:
out1
=
outputs
[
0
]
N
,
Hs
,
W
,
C
=
list
(
out1
.
shape
)
stream1
.
wait_stream
(
torch
.
cuda
.
current_stream
())
with
torch
.
cuda
.
stream
(
stream1
):
# copy halos to send buffer
send_halos
=
torch
.
empty
((
N
,
2
,
W
,
C
),
dtype
=
out1
.
dtype
,
device
=
out1
.
device
)
send_halos
[:,:
1
,:,:].
copy_
(
out1
[:,:
1
,:,:])
send_halos
[:,
1
:,:,:].
copy_
(
out1
[:,
Hs
-
1
:,:,:])
all_halos
=
torch
.
empty
((
N
,
2
*
spatial_group_size
,
W
,
C
),
dtype
=
out1
.
dtype
,
device
=
out1
.
device
)
all_halos
=
[
all_halos
[:,
i
*
2
:(
i
+
1
)
*
2
,:,:]
for
i
in
range
(
spatial_group_size
)]
dist
.
all_gather
(
all_halos
,
send_halos
,
group
=
comm
)
fat_halo
=
torch
.
empty
((
N
,
3
,
W
,
C
),
dtype
=
out1
.
dtype
,
device
=
out1
.
device
)
top_out1_halo
=
all_halos
[(
spatial_group_size
+
local_rank
-
1
)
%
spatial_group_size
][:,
1
:,:,:]
if
local_rank
>
0
:
fat_halo
[:,:
1
,:,:].
copy_
(
top_out1_halo
)
fat_halo
[:,
1
:
3
,:,:].
copy_
(
out1
[:,:
2
,:,:])
top_out2
=
fast_bottleneck
.
forward_out2_halo
(
nhwc
,
fat_halo
,
args
)
btm_out1_halo
=
all_halos
[(
local_rank
+
1
)
%
spatial_group_size
][:,:
1
,:,:]
if
local_rank
<
spatial_group_size
-
1
:
fat_halo
[:,
0
:
2
,:,:].
copy_
(
out1
[:,
Hs
-
2
:,:,:])
fat_halo
[:,
2
:,:,:].
copy_
(
btm_out1_halo
)
btm_out2
=
fast_bottleneck
.
forward_out2_halo
(
nhwc
,
fat_halo
,
args
)
torch
.
cuda
.
current_stream
().
wait_stream
(
stream1
)
out2
=
outputs
[
1
]
if
local_rank
>
0
:
out2
[:,:
1
,:,:].
copy_
(
top_out2
)
if
local_rank
<
spatial_group_size
-
1
:
out2
[:,
Hs
-
1
:,:,:].
copy_
(
btm_out2
)
fast_bottleneck
.
forward_rest
(
nhwc
,
stride_1x1
,
args
,
outputs
)
# save halos for backward pass
if
spatial_group_size
>
1
:
ctx
.
save_for_backward
(
*
(
args
+
outputs
+
[
top_out1_halo
,
btm_out1_halo
]))
else
:
ctx
.
save_for_backward
(
*
(
args
+
outputs
))
# save relu outputs for drelu
ctx
.
nhwc
=
nhwc
ctx
.
stride_1x1
=
stride_1x1
ctx
.
spatial_group_size
=
spatial_group_size
ctx
.
local_rank
=
local_rank
ctx
.
comm
=
comm
ctx
.
stream1
=
stream1
return
outputs
[
2
]
# backward relu is not exposed, MUL with mask used now
# only support dgrad
@
staticmethod
def
backward
(
ctx
,
grad_o
):
if
ctx
.
spatial_group_size
>
1
:
top_out1_halo
=
ctx
.
saved_tensors
[
-
2
]
btm_out1_halo
=
ctx
.
saved_tensors
[
-
1
]
outputs
=
ctx
.
saved_tensors
[
-
5
:
-
2
]
else
:
outputs
=
ctx
.
saved_tensors
[
-
3
:]
if
ctx
.
downsample
:
grad_conv3
,
grad_conv4
=
drelu_dscale2
(
grad_o
,
outputs
[
2
],
ctx
.
saved_tensors
[
6
],
ctx
.
saved_tensors
[
11
])
else
:
grad_conv3
,
grad_conv4
=
drelu_dscale1
(
grad_o
,
outputs
[
2
],
ctx
.
saved_tensors
[
6
])
# create input vector for backward
t_list
=
[
*
ctx
.
saved_tensors
[
0
:
10
]]
t_list
.
append
(
grad_conv3
)
t_list
.
append
(
grad_conv4
)
# outputs used for wgrad and generating drelu mask
t_list
.
append
(
outputs
[
0
])
t_list
.
append
(
outputs
[
1
])
# in case there is downsample
if
ctx
.
downsample
:
t_list
.
append
(
ctx
.
saved_tensors
[
10
])
grads
=
fast_bottleneck
.
backward_init
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
)
grad_out2
=
fast_bottleneck
.
backward_grad_out2
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
)
# compute wgrad2 for internal cells
wgrad2
=
fast_bottleneck
.
backward_wgrad2
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2
)
# apply wgrad2 halos
if
ctx
.
spatial_group_size
>
1
:
if
ctx
.
local_rank
>
0
:
top_grad2_halo
=
grad_out2
[:,:
1
,:,:]
top_wgrad2_halo
=
fast_bottleneck
.
backward_wgrad2_halo
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
top_out1_halo
,
top_grad2_halo
)
wgrad2
[:,:
1
,:,:].
add_
(
top_wgrad2_halo
)
if
ctx
.
local_rank
<
ctx
.
spatial_group_size
-
1
:
btm_grad2_halo
=
grad_out2
[:,
-
1
:,:,:]
btm_wgrad2_halo
=
fast_bottleneck
.
backward_wgrad2_halo
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
btm_out1_halo
,
btm_grad2_halo
)
wgrad2
[:,
-
1
:,:,:].
add_
(
btm_wgrad2_halo
)
# do halo exchange of grad_out2 here
# compute halo cells for grad_out1
if
ctx
.
spatial_group_size
>
1
:
N
,
Hs
,
W
,
C
=
list
(
grad_out2
.
shape
)
ctx
.
stream1
.
wait_stream
(
torch
.
cuda
.
current_stream
())
with
torch
.
cuda
.
stream
(
ctx
.
stream1
):
# copy halos to send buffer
send_halos
=
torch
.
empty
((
N
,
2
,
W
,
C
),
dtype
=
grad_out2
.
dtype
,
device
=
grad_out2
.
device
)
send_halos
[:,:
1
,:,:].
copy_
(
grad_out2
[:,:
1
,:,:])
send_halos
[:,
1
:,:,:].
copy_
(
grad_out2
[:,
Hs
-
1
:,:,:])
all_halos
=
torch
.
empty
((
N
,
2
*
ctx
.
spatial_group_size
,
W
,
C
),
dtype
=
grad_out2
.
dtype
,
device
=
grad_out2
.
device
)
all_halos
=
[
all_halos
[:,
i
*
2
:(
i
+
1
)
*
2
,:,:]
for
i
in
range
(
ctx
.
spatial_group_size
)]
dist
.
all_gather
(
all_halos
,
send_halos
,
group
=
ctx
.
comm
)
relu1
=
t_list
[
12
]
fat_halo
=
torch
.
empty
((
N
,
3
,
W
,
C
),
dtype
=
grad_out2
.
dtype
,
device
=
grad_out2
.
device
)
relu_halo
=
torch
.
empty
((
N
,
3
,
W
,
C
),
dtype
=
grad_out2
.
dtype
,
device
=
grad_out2
.
device
)
if
ctx
.
local_rank
>
0
:
top_halo
=
all_halos
[
ctx
.
local_rank
-
1
][:,
1
:,:,:]
fat_halo
[:,:
1
,:,:].
copy_
(
top_halo
)
fat_halo
[:,
1
:,:,:].
copy_
(
grad_out2
[:,:
2
,:,:])
relu_halo
[:,:
1
,:,:].
zero_
()
relu_halo
[:,
1
:,:,:].
copy_
(
relu1
[:,:
2
,:,:])
top_grad_out1_halo
=
fast_bottleneck
.
backward_grad_out1_halo
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
fat_halo
,
relu_halo
)
top_grad_out1_halo
=
top_grad_out1_halo
[:,
1
:
2
,:,:]
if
ctx
.
local_rank
<
ctx
.
spatial_group_size
-
1
:
btm_halo
=
all_halos
[
ctx
.
local_rank
+
1
][:,:
1
,:,:]
fat_halo
[:,:
2
,:,:].
copy_
(
grad_out2
[:,
Hs
-
2
:,:,:])
fat_halo
[:,
2
:,:,:].
copy_
(
btm_halo
)
relu_halo
[:,:
2
,:,:].
copy_
(
relu1
[:,
Hs
-
2
:,:,:])
relu_halo
[:,
2
:,:,:].
zero_
()
btm_grad_out1_halo
=
fast_bottleneck
.
backward_grad_out1_halo
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
fat_halo
,
relu_halo
)
btm_grad_out1_halo
=
btm_grad_out1_halo
[:,
1
:
2
,:,:]
# compute grad_out1 for internal cells
grad_out1
=
fast_bottleneck
.
backward_grad_out1
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2
)
# apply halo cells to grad_out1
if
ctx
.
spatial_group_size
>
1
:
w
=
t_list
[
2
]
z
=
t_list
[
4
]
relu1
=
t_list
[
12
]
#print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
torch
.
cuda
.
current_stream
().
wait_stream
(
ctx
.
stream1
)
if
ctx
.
local_rank
>
0
:
grad_out1
[:,:
1
,:,:].
copy_
(
top_grad_out1_halo
)
#print("ctx.local_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.local_rank, str(list(grad_out1.shape))))
if
ctx
.
local_rank
<
ctx
.
spatial_group_size
-
1
:
grad_out1
[:,
Hs
-
1
:,:,:].
copy_
(
btm_grad_out1_halo
)
#print("ctx.local_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.local_rank, str(list(grad_out1.shape))))
fast_bottleneck
.
backward_rest
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2
,
grad_out1
,
wgrad2
)
return
(
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
*
grads
)
spatial_bottleneck_function
=
SpatialBottleneckFunction
.
apply
class
SpatialBottleneck
(
torch
.
nn
.
Module
):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
# here we put it at 1x1
def
__init__
(
self
,
in_channels
,
bottleneck_channels
,
out_channels
,
stride
=
1
,
groups
=
1
,
dilation
=
1
,
norm_func
=
None
,
use_cudnn
=
False
,
explicit_nhwc
=
False
,
spatial_group_size
=
1
,
communicator
=
None
):
super
(
SpatialBottleneck
,
self
).
__init__
()
if
groups
!=
1
:
raise
RuntimeError
(
'Only support groups == 1'
)
if
dilation
!=
1
:
raise
RuntimeError
(
'Only support dilation == 1'
)
if
norm_func
==
None
:
norm_func
=
FrozenBatchNorm2d
else
:
raise
RuntimeError
(
'Only support frozen BN now.'
)
if
stride
!=
1
or
in_channels
!=
out_channels
:
self
.
downsample
=
nn
.
Sequential
(
conv1x1
(
in_channels
,
out_channels
,
stride
),
norm_func
(
out_channels
),
)
else
:
self
.
downsample
=
None
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self
.
conv1
=
conv1x1
(
in_channels
,
bottleneck_channels
,
stride
)
self
.
conv2
=
conv3x3
(
bottleneck_channels
,
bottleneck_channels
)
self
.
conv3
=
conv1x1
(
bottleneck_channels
,
out_channels
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
stride
=
stride
self
.
bn1
=
norm_func
(
bottleneck_channels
)
self
.
bn2
=
norm_func
(
bottleneck_channels
)
self
.
bn3
=
norm_func
(
out_channels
)
self
.
use_cudnn
=
use_cudnn
# setup conv weights
self
.
w_conv
=
[
self
.
conv1
.
weight
,
self
.
conv2
.
weight
,
self
.
conv3
.
weight
]
if
self
.
downsample
is
not
None
:
self
.
w_conv
.
append
(
self
.
downsample
[
0
].
weight
)
# init weight in nchw format before possible transpose
for
w
in
self
.
w_conv
:
kaiming_uniform_
(
w
,
a
=
1
)
# TODO: prevent unsupported case usage
# support cases
# native cudnn
# normal yes no
# channel_last yes yes
# explicit_nhwc no yes
self
.
explicit_nhwc
=
explicit_nhwc
if
self
.
explicit_nhwc
:
for
p
in
self
.
parameters
():
with
torch
.
no_grad
():
p
.
data
=
p
.
data
.
permute
(
0
,
2
,
3
,
1
).
contiguous
()
# spatial communicator
self
.
spatial_group_size
=
spatial_group_size
if
spatial_group_size
>
1
:
world_size
=
dist
.
get_world_size
()
num_groups
=
world_size
//
spatial_group_size
assert
(
num_groups
*
spatial_group_size
==
world_size
),
"torch.distributed.get_world_size() must be multiple of group_size"
rank
=
dist
.
get_rank
()
self
.
local_rank
=
rank
%
spatial_group_size
if
communicator
is
None
:
for
group
in
range
(
num_groups
):
ranks
=
list
(
range
(
group
*
spatial_group_size
,(
group
+
1
)
*
spatial_group_size
))
comm
=
torch
.
distributed
.
new_group
(
ranks
=
ranks
)
if
rank
in
ranks
:
self
.
communicator
=
comm
else
:
self
.
communicator
=
communicator
self
.
stream1
=
torch
.
cuda
.
Stream
()
self
.
spatial_args
=
self
.
spatial_group_size
,
self
.
local_rank
,
self
.
communicator
,
self
.
stream1
else
:
self
.
spatial_args
=
1
,
0
,
None
,
None
return
def
forward
(
self
,
x
):
if
self
.
use_cudnn
:
# calculate scale/bias from registered buffers
# TODO: make this better
s1
,
b1
=
self
.
bn1
.
get_scale_bias
(
self
.
explicit_nhwc
)
s2
,
b2
=
self
.
bn2
.
get_scale_bias
(
self
.
explicit_nhwc
)
s3
,
b3
=
self
.
bn3
.
get_scale_bias
(
self
.
explicit_nhwc
)
w_scale
=
[
s1
,
s2
,
s3
]
w_bias
=
[
b1
,
b2
,
b3
]
if
self
.
downsample
is
not
None
:
s4
,
b4
=
self
.
downsample
[
1
].
get_scale_bias
(
self
.
explicit_nhwc
)
w_scale
.
append
(
s4
)
w_bias
.
append
(
b4
)
out
=
spatial_bottleneck_function
(
*
self
.
spatial_args
,
self
.
explicit_nhwc
,
self
.
stride
,
w_scale
,
w_bias
,
x
,
*
self
.
w_conv
)
return
out
if
self
.
explicit_nhwc
:
raise
RuntimeError
(
'explicit nhwc with native ops is not supported.'
)
# fallback to native ops
identity
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv3
(
out
)
out
=
self
.
bn3
(
out
)
if
self
.
downsample
is
not
None
:
identity
=
self
.
downsample
(
x
)
out
+=
identity
out
=
self
.
relu
(
out
)
return
out
apex/contrib/bottleneck/bottleneck_module_test.py
0 → 100644
View file @
f79993d9
import
os
import
torch
from
maskrcnn_benchmark.modeling.backbone.resnet
import
Bottleneck
from
maskrcnn_benchmark.layers.nhwc
import
nhwc_to_nchw_transform
,
nchw_to_nhwc_transform
from
maskrcnn_benchmark.layers.nhwc.batch_norm
import
FrozenBatchNorm2d_NHWC
from
apex.contrib.bottleneck
import
Bottleneck
as
FastBottleneck
from
apex.contrib.bottleneck
import
SpatialBottleneck
def
single_module_test
(
ref
,
rank
,
world_size
,
numtype
,
device
,
shape
,
fast
,
spatial_group_size
,
in_channels
,
bottleneck_channels
,
out_channels
,
num_groups
,
stride_in_1x1
,
stride
,
dilation
,
norm_func
,
nhwc
):
# inputs + modules
with
torch
.
no_grad
():
input_shape
=
[
1
,
in_channels
]
+
list
(
shape
)
x
=
torch
.
randn
(
input_shape
,
dtype
=
numtype
,
device
=
device
)
if
nhwc
:
x
=
nchw_to_nhwc_transform
(
x
).
contiguous
()
x
.
requires_grad
=
True
print
(
x
.
shape
,
x
.
stride
())
#if spatial_group_size > 1:
# fast = False # hack so fast bottleneck can be run against distributed bottleneck
#if spatial_group_size == 1:
# fast = False
if
fast
:
if
spatial_group_size
==
1
:
bottleneck
=
FastBottleneck
(
in_channels
=
in_channels
,
bottleneck_channels
=
bottleneck_channels
,
out_channels
=
out_channels
,
stride
=
stride
,
dilation
=
dilation
,
explicit_nhwc
=
nhwc
,
use_cudnn
=
True
)
else
:
bottleneck
=
SpatialBottleneck
(
in_channels
=
in_channels
,
bottleneck_channels
=
bottleneck_channels
,
out_channels
=
out_channels
,
stride
=
stride
,
dilation
=
dilation
,
explicit_nhwc
=
nhwc
,
use_cudnn
=
True
,
spatial_group_size
=
spatial_group_size
)
else
:
bottleneck
=
Bottleneck
(
in_channels
,
bottleneck_channels
,
out_channels
,
num_groups
,
stride_in_1x1
,
stride
,
dilation
,
norm_func
,
nhwc
,
spatial_group_size
)
bottleneck
=
bottleneck
.
to
(
dtype
=
numtype
,
device
=
device
)
weights
=
dict
(
bottleneck
.
named_parameters
())
if
ref
is
not
None
:
ref_x
,
_
,
ref_weights
=
ref
Hs
,
H
=
x
.
shape
[
1
],
ref_x
.
shape
[
1
]
assert
(
Hs
*
spatial_group_size
==
H
),
"Hs not a multiple of H"
ref_x
=
ref_x
[:,
rank
*
Hs
:(
rank
+
1
)
*
Hs
,:,:]
x
.
copy_
(
ref_x
)
assert
(
len
(
weights
)
==
len
(
ref_weights
)),
"Reference weights and weights don't match"
for
k
in
weights
.
keys
():
weights
[
k
].
copy_
(
ref_weights
[
k
])
# forward
out
=
bottleneck
(
x
)
# gradient output
with
torch
.
no_grad
():
grad_out
=
torch
.
randn_like
(
out
)
if
ref
is
not
None
:
_
,
ref_grad_out
,
_
=
ref
Hs
,
H
=
grad_out
.
shape
[
1
],
ref_grad_out
.
shape
[
1
]
assert
(
Hs
*
spatial_group_size
==
H
),
"Hs not a multiple of H"
ref_grad_out
=
ref_grad_out
[:,
rank
*
Hs
:(
rank
+
1
)
*
Hs
,:,:]
grad_out
.
copy_
(
ref_grad_out
)
# backward
out
.
backward
(
grad_out
)
with
torch
.
no_grad
():
dgrad
=
x
.
grad
.
detach
()
wgrad
=
{}
for
n
,
p
in
bottleneck
.
named_parameters
():
wgrad
[
n
]
=
p
.
grad
.
detach
()
if
world_size
>
1
:
if
spatial_group_size
==
1
:
# broadcast x, grad_out and weights from rank 0
with
torch
.
no_grad
():
torch
.
distributed
.
broadcast
(
x
,
0
)
torch
.
distributed
.
broadcast
(
grad_out
,
0
)
for
k
in
weights
.
keys
():
torch
.
distributed
.
broadcast
(
weights
[
k
],
0
)
else
:
# gather dgrad (x.grad), sum wgrad (weights) and out
N
,
Hs
,
W
,
C
=
dgrad
.
shape
H
=
Hs
*
spatial_group_size
dgrad_gathered
=
torch
.
empty
((
N
,
H
,
W
,
C
),
dtype
=
dgrad
.
dtype
,
device
=
dgrad
.
device
)
dgrad_tensors
=
[
dgrad_gathered
[:,
i
*
Hs
:(
i
+
1
)
*
Hs
,:,:]
for
i
in
range
(
spatial_group_size
)]
torch
.
distributed
.
all_gather
(
dgrad_tensors
,
dgrad
)
dgrad
=
dgrad_gathered
N
,
Hs
,
W
,
C
=
list
(
out
.
shape
)
H
=
Hs
*
spatial_group_size
out_gathered
=
torch
.
empty
((
N
,
H
,
W
,
C
),
dtype
=
dgrad
.
dtype
,
device
=
dgrad
.
device
)
out_tensors
=
[
out_gathered
[:,
i
*
Hs
:(
i
+
1
)
*
Hs
,:,:]
for
i
in
range
(
spatial_group_size
)]
torch
.
distributed
.
all_gather
(
out_tensors
,
out
)
out
=
out_gathered
for
k
in
wgrad
.
keys
():
w
=
wgrad
[
k
].
to
(
dtype
=
torch
.
float64
)
torch
.
distributed
.
all_reduce
(
w
)
wgrad
[
k
].
copy_
(
w
.
to
(
dtype
=
wgrad
[
k
].
dtype
))
#torch.distributed.all_reduce(wgrad[k])
return
x
,
out
,
grad_out
,
weights
,
dgrad
,
wgrad
def
module_tests
(
rank
,
world_size
,
numtype
,
device
,
fast
,
spatial_group_sizes
,
init_args
):
r
=
[]
for
ia
in
init_args
:
shape
=
ia
[
0
:
4
]
args
=
ia
[
4
:]
rr
=
[]
ref
=
None
for
spatial_group_size
in
spatial_group_sizes
:
N
,
H
,
W
,
C
=
shape
H
=
H
//
spatial_group_size
x
,
out
,
grad_out
,
weights
,
dgrad
,
wgrad
=
single_module_test
(
ref
,
rank
,
world_size
,
numtype
,
device
,
[
H
,
W
],
fast
,
spatial_group_size
,
*
args
)
if
ref
is
None
:
assert
(
spatial_group_size
==
1
),
"Wrong reference weights"
ref
=
x
,
grad_out
,
weights
if
rank
==
0
:
rr
.
append
(
(
out
,
dgrad
,
wgrad
)
)
if
world_size
>
1
:
torch
.
distributed
.
barrier
()
r
.
append
(
rr
)
return
r
def
main
():
total_num_gpus
=
int
(
os
.
environ
[
"WORLD_SIZE"
])
if
"WORLD_SIZE"
in
os
.
environ
else
1
distributed
=
total_num_gpus
>
1
ngpus
=
torch
.
cuda
.
device_count
()
if
distributed
:
torch
.
distributed
.
init_process_group
(
"nccl"
)
rank
,
world_size
=
torch
.
distributed
.
get_rank
(),
torch
.
distributed
.
get_world_size
()
is_master
=
True
if
rank
==
0
else
False
local_rank
=
rank
%
ngpus
torch
.
cuda
.
set_device
(
local_rank
)
spatial_group_size
=
total_num_gpus
else
:
rank
,
local_rank
,
is_master
,
world_size
,
spatial_group_size
=
0
,
0
,
True
,
1
,
1
torch
.
use_deterministic_algorithms
(
True
)
torch
.
backends
.
cudnn
.
benchmark
=
False
torch
.
backends
.
cudnn
.
deterministic
=
True
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
backends
.
cudnn
.
allow_tf32
=
False
norm_func
=
FrozenBatchNorm2d_NHWC
init_args
=
[
(
1
,
200
,
336
,
64
,
64
,
64
,
256
,
1
,
True
,
1
,
1
,
norm_func
,
True
),
(
1
,
200
,
336
,
256
,
256
,
64
,
256
,
1
,
True
,
1
,
1
,
norm_func
,
True
),
(
1
,
200
,
336
,
256
,
256
,
128
,
512
,
1
,
True
,
2
,
1
,
norm_func
,
True
),
(
1
,
100
,
168
,
512
,
512
,
128
,
512
,
1
,
True
,
1
,
1
,
norm_func
,
True
),
(
1
,
100
,
168
,
512
,
512
,
256
,
1024
,
1
,
True
,
2
,
1
,
norm_func
,
True
),
(
1
,
50
,
84
,
1024
,
1024
,
256
,
1024
,
1
,
True
,
1
,
1
,
norm_func
,
True
),
(
1
,
50
,
84
,
1024
,
1024
,
512
,
2048
,
1
,
True
,
2
,
1
,
norm_func
,
True
),
(
1
,
25
,
42
,
2048
,
2048
,
512
,
2048
,
1
,
True
,
1
,
1
,
norm_func
,
True
),
(
1
,
336
,
200
,
64
,
64
,
64
,
256
,
1
,
True
,
1
,
1
,
norm_func
,
True
),
(
1
,
336
,
200
,
256
,
256
,
64
,
256
,
1
,
True
,
1
,
1
,
norm_func
,
True
),
(
1
,
336
,
200
,
256
,
256
,
128
,
512
,
1
,
True
,
2
,
1
,
norm_func
,
True
),
(
1
,
168
,
100
,
512
,
512
,
128
,
512
,
1
,
True
,
1
,
1
,
norm_func
,
True
),
(
1
,
168
,
100
,
512
,
512
,
256
,
1024
,
1
,
True
,
2
,
1
,
norm_func
,
True
),
(
1
,
84
,
50
,
1024
,
1024
,
256
,
1024
,
1
,
True
,
1
,
1
,
norm_func
,
True
),
(
1
,
84
,
50
,
1024
,
1024
,
512
,
2048
,
1
,
True
,
2
,
1
,
norm_func
,
True
),
(
1
,
42
,
25
,
2048
,
2048
,
512
,
2048
,
1
,
True
,
1
,
1
,
norm_func
,
True
),
]
init_args
=
init_args
[
0
:
1
]
# pad H to account for spatial distribution
padded_init_args
=
[]
for
ia
in
init_args
:
N
,
H
,
W
,
C
=
ia
[
0
:
4
]
m
=
spatial_group_size
*
H
//
(
25
if
H
<
W
else
42
)
H
=
((
H
+
m
-
1
)
//
m
)
*
m
args
=
tuple
(
[
N
,
H
,
W
,
C
]
+
list
(
ia
[
4
:])
)
padded_init_args
.
append
(
args
)
init_args
=
padded_init_args
if
rank
==
0
:
for
ia
in
init_args
:
print
(
ia
)
spatial_group_sizes
=
[
1
]
if
spatial_group_size
>
1
:
spatial_group_sizes
.
append
(
spatial_group_size
)
numtype
,
device
,
fast
=
torch
.
float16
,
'cuda'
,
True
r
=
module_tests
(
rank
,
world_size
,
numtype
,
device
,
fast
,
spatial_group_sizes
,
init_args
)
if
world_size
>
1
:
torch
.
distributed
.
barrier
()
if
rank
==
0
:
for
rr
in
r
:
print
(
"***"
)
for
out
,
dgrad
,
wgrad
in
rr
:
gr
=
[(
"out"
,
out
.
norm
(
p
=
2
,
dtype
=
torch
.
float64
).
item
())]
gr
=
gr
+
[(
"dgrad"
,
dgrad
.
norm
(
p
=
2
,
dtype
=
torch
.
float64
).
item
())]
gr
=
gr
+
[(
k
+
".wgrad"
,
wgrad
[
k
].
norm
(
p
=
2
,
dtype
=
torch
.
float64
).
item
())
for
k
in
wgrad
.
keys
()]
print
(
gr
)
if
len
(
rr
)
==
2
:
out1
,
dgrad1
,
wgrad1
=
rr
[
0
]
out2
,
dgrad2
,
wgrad2
=
rr
[
1
]
rtol
=
1e-1
out_atol
=
out1
.
abs
().
max
().
item
()
*
rtol
dgrad_atol
=
dgrad1
.
abs
().
max
().
item
()
*
rtol
wgrad_atol
=
{}
for
k
in
wgrad1
.
keys
():
wgrad_atol
[
k
]
=
wgrad1
[
k
].
abs
().
max
().
item
()
*
rtol
gr
=
[(
"out"
,
torch
.
allclose
(
out1
,
out2
,
rtol
,
out_atol
,
equal_nan
=
True
))]
gr
=
gr
+
[(
"dgrad"
,
torch
.
allclose
(
dgrad1
,
dgrad2
,
rtol
,
dgrad_atol
,
equal_nan
=
True
))]
gr
=
gr
+
[(
k
+
".wgrad"
,
torch
.
allclose
(
wgrad1
[
k
],
wgrad2
[
k
],
rtol
,
wgrad_atol
[
k
],
equal_nan
=
True
))
for
k
in
wgrad1
.
keys
()]
print
(
gr
)
gr
=
[(
"out"
,(
out1
-
out2
).
norm
(
p
=
2
,
dtype
=
torch
.
float64
).
item
())]
gr
=
gr
+
[(
"dgrad"
,(
dgrad1
-
dgrad2
).
norm
(
p
=
2
,
dtype
=
torch
.
float64
).
item
())]
gr
=
gr
+
[(
k
+
".wgrad"
,(
wgrad1
[
k
]
-
wgrad2
[
k
]).
norm
(
p
=
2
,
dtype
=
torch
.
float64
).
item
())
for
k
in
wgrad1
.
keys
()]
print
(
gr
)
N
,
H
,
W
,
C
=
out1
.
shape
Hs
=
H
//
spatial_group_size
Ht
=
Hs
-
2
print
(
"out1@%d:%d=%s"
%
(
Ht
,
H
,
str
(
out1
[
0
,
Ht
,:
8
,:
5
])))
print
(
"out2@%d:%d=%s"
%
(
Ht
,
H
,
str
(
out2
[
0
,
Ht
,:
8
,:
5
])))
Ht
=
Hs
-
1
print
(
"out1@%d:%d=%s"
%
(
Ht
,
H
,
str
(
out1
[
0
,
Ht
,:
8
,:
5
])))
print
(
"out2@%d:%d=%s"
%
(
Ht
,
H
,
str
(
out2
[
0
,
Ht
,:
8
,:
5
])))
Ht
=
Hs
print
(
"out1@%d:%d=%s"
%
(
Ht
,
H
,
str
(
out1
[
0
,
Ht
,:
8
,:
5
])))
print
(
"out2@%d:%d=%s"
%
(
Ht
,
H
,
str
(
out2
[
0
,
Ht
,:
8
,:
5
])))
Ht
=
Hs
+
1
print
(
"out1@%d:%d=%s"
%
(
Ht
,
H
,
str
(
out1
[
0
,
Ht
,:
8
,:
5
])))
print
(
"out2@%d:%d=%s"
%
(
Ht
,
H
,
str
(
out2
[
0
,
Ht
,:
8
,:
5
])))
N
,
H
,
W
,
C
=
dgrad1
.
shape
Hs
=
H
//
spatial_group_size
Ht
=
Hs
-
2
print
(
"dgrad1@%d:%d=%s"
%
(
Ht
,
H
,
str
(
dgrad1
[
0
,
Ht
,:
8
,:
5
])))
print
(
"dgrad2@%d:%d=%s"
%
(
Ht
,
H
,
str
(
dgrad2
[
0
,
Ht
,:
8
,:
5
])))
Ht
=
Hs
-
1
print
(
"dgrad1@%d:%d=%s"
%
(
Ht
,
H
,
str
(
dgrad1
[
0
,
Ht
,:
8
,:
5
])))
print
(
"dgrad2@%d:%d=%s"
%
(
Ht
,
H
,
str
(
dgrad2
[
0
,
Ht
,:
8
,:
5
])))
Ht
=
Hs
print
(
"dgrad1@%d:%d=%s"
%
(
Ht
,
H
,
str
(
dgrad1
[
0
,
Ht
,:
8
,:
5
])))
print
(
"dgrad2@%d:%d=%s"
%
(
Ht
,
H
,
str
(
dgrad2
[
0
,
Ht
,:
8
,:
5
])))
Ht
=
Hs
+
1
print
(
"dgrad1@%d:%d=%s"
%
(
Ht
,
H
,
str
(
dgrad1
[
0
,
Ht
,:
8
,:
5
])))
print
(
"dgrad2@%d:%d=%s"
%
(
Ht
,
H
,
str
(
dgrad2
[
0
,
Ht
,:
8
,:
5
])))
if
world_size
>
1
:
torch
.
distributed
.
barrier
()
# Script entry point: run the spatial-bottleneck module tests defined above.
if __name__ == "__main__":
    main()
apex/contrib/bottleneck/test.py
0 → 100644
View file @
f79993d9
# Numerical test for the custom Bottleneck block: compares three execution
# paths (reference channels_last, cuDNN channels_last, and explicit-NHWC
# cuDNN) on forward output, input gradient (dgrad) and weight gradients
# (wgrad) for several stride / output-channel configurations.
import torch
from bottleneck import Bottleneck

torch.manual_seed(23337)

# use True to print layerwise sum for all outputs in reference code path
DEBUG = False #True

for stride, o_channel in [(1,32), (1,128), (2,32)]:
    print("testing stride ==", stride, ", in_channel == 32 , out_channel ==", o_channel)
    # Fixed-size random input; a_ stays on CPU so it can be re-laid-out for
    # the explicit-NHWC model below.
    a_ = torch.randn(17,32,28,28)

    a = a_.cuda().half().to(memory_format=torch.channels_last).requires_grad_()
    model = Bottleneck(32,8,o_channel,stride=stride).cuda().half().to(memory_format=torch.channels_last)

    # test model
    # Reference pass (native PyTorch path): capture dgrad and wgrads.
    b = model(a)
    b.mean().backward()
    d_grad = a.grad.float()
    a.grad = None
    torch.cuda.synchronize()

    if DEBUG:
        print("[DEBUG] ref dx :", d_grad.sum().item())
        # print wgrad. we don't need to reset since later cpp print before accumulation
        for i, w in enumerate(model.w_conv):
            print("[DEBUG] ref wgrad{} :".format(i+1), w.grad.sum().item())

    # Snapshot reference weight gradients before switching implementations.
    wgrads = []
    for w in model.w_conv:
        wgrads.append(w.grad.float())

    # Second pass: same model/weights, but routed through the fused cuDNN path.
    model.use_cudnn = True
    model.zero_grad()
    c = model(a)
    c.mean().backward()
    torch.cuda.synchronize()

    print("comparing native and channels_last:")
    print("max error fprop:", (b-c).abs().max().item(), "max elem:", b.abs().max().item())
    print("max error dgrad:", (d_grad-a.grad.float()).abs().max().item(), "max elem:", d_grad.abs().max().item())
    for i, (w, wgrad) in enumerate(zip(model.w_conv, wgrads)):
        print("max error wgrad{}:".format(i+1), (wgrad - w.grad.float()).abs().max().item(), "max elem:", wgrad.abs().max().item())

    # Third pass: explicit-NHWC model (tensors physically permuted to NHWC
    # rather than using channels_last strides).
    nhwc_a = a_.permute(0,2,3,1).contiguous().cuda().half().requires_grad_()
    nhwc_model = Bottleneck(32,8,o_channel,stride=stride,explicit_nhwc=True, use_cudnn=True).cuda().half()
    for p,q in zip(model.parameters(), nhwc_model.parameters()):
        # model's storage is already in nhwc, we clone and assign to explicit nhwc model
        q.data.copy_(p.data.permute(0,2,3,1).contiguous())
    for p,q in zip(model.buffers(), nhwc_model.buffers()):
        q.data.copy_(p.data)

    d = nhwc_model(nhwc_a)
    d.mean().backward()
    torch.cuda.synchronize()

    # reset reference to cudnn channels_last permute
    #c_s = c.storage().tolist()
    #d_s = d.storage().tolist()
    #print(max([x-y for x,y in zip(c_s,d_s)]))
    # Re-express the channels_last reference results in NHWC layout so they
    # are directly comparable with the explicit-NHWC model's outputs.
    c = c.contiguous(memory_format=torch.contiguous_format).permute(0,2,3,1).contiguous()
    d_grad = a.grad.float().permute(0,2,3,1).contiguous()
    wgrads = []
    for w in model.w_conv:
        wgrads.append(w.grad.float().permute(0,2,3,1).contiguous())

    torch.cuda.synchronize()
    print("comparing nhwc and channels_last:")
    print("max error fprop:", (d-c).abs().max().item(), "max elem:", c.abs().max().item())
    print("max error dgrad:", (d_grad-nhwc_a.grad.float()).abs().max().item(), "max elem:", d_grad.abs().max().item())
    for i, (w, wgrad) in enumerate(zip(nhwc_model.w_conv, wgrads)):
        print("max error wgrad{}:".format(i+1), (wgrad - w.grad.float()).abs().max().item(), "max elem:", wgrad.abs().max().item())
apex/contrib/csrc/bottleneck/bottleneck.cpp
0 → 100644
View file @
f79993d9
#include <ATen/ATen.h>
#include <ATen/cudnn/Handle.h> // for getcudnnhandle
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>
#include <cudnn_frontend.h>
#include <iostream>
#ifdef DEBUG
#define DEBUG_MSG(str) do { std::cout << str << std::endl; } while( false )
#else
#define DEBUG_MSG(str) do { } while ( false )
#endif
#ifdef DEBUG_CUDNN
#define DEBUG_CUDNN_MSG(buf, str) do { buf << str << std::endl; } while( false )
#else
#define DEBUG_CUDNN_MSG(buf, str) do { } while ( false )
#endif
#define checkCudnnErr(...) \
do { \
int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \
if (err) { \
return; \
} \
} while (0)
// Log a cuDNN status failure to stdout, including the failing expression and
// call site. Returns 1 when `code` signals an error and 0 on success; the
// checkCudnnErr macro above uses the return value to bail out of the caller.
int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) {
    if (!code) {
        return 0;  // CUDNN_STATUS_SUCCESS (0): nothing to report
    }
    printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n",
           file,
           line,
           (int)code,
           cudnnGetErrorString(code),
           expr);
    return 1;
}
// Forward declaration (with the `abort` default) so the checkCUDAError macro
// below can be used before the definition appears.
void checkError(cudaError_t code, char const* func, const char* file, const int line, bool abort = true);
#define checkCUDAError(val) { checkError((val), #val, __FILE__, __LINE__); }    // in-line regular function

// Report a CUDA runtime error to stderr with the originating expression and
// call site. When `abort` is true (the default), the device is reset and the
// process exits with the CUDA error code.
void checkError(cudaError_t code, char const* func, const char* file, const int line, bool abort)
{
    if (code != cudaSuccess)
    {
        const char* errorMessage = cudaGetErrorString(code);
        fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code, errorMessage);
        if (abort){
            // Fatal path: tear down the CUDA context before exiting.
            cudaDeviceReset();
            exit(code);
        }
    }
}
// Compute packed strides for an `nbDims`-dimensional tensor with extents
// `dimA`, in either NCHW or NHWC layout, writing the result into `strideA`.
// For INT8x4 and INT8x32 we still compute standard strides here to input
// into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref.
void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) {
    if (filterFormat == CUDNN_TENSOR_NCHW) {
        // Row-major: innermost dim is contiguous; accumulate a running
        // product of extents while walking outward.
        int64_t running = 1;
        for (int64_t d = nbDims - 1; d >= 0; d--) {
            strideA[d] = running;
            running *= dimA[d];
        }
    } else {
        // Here we assume that the format is CUDNN_TENSOR_NHWC: channels
        // (dim 1) are innermost, then W, then H, with batch outermost.
        strideA[1] = 1;
        strideA[nbDims - 1] = strideA[1] * dimA[1];
        for (int64_t d = nbDims - 2; d >= 2; d--) {
            strideA[d] = strideA[d + 1] * dimA[d + 1];
        }
        strideA[0] = strideA[2] * dimA[2];
    }
}
// Effective spatial extent of a filter of size `filterDim` once dilated:
// dilation inserts (dilation - 1) gaps between adjacent taps.
int getFwdConvDilatedFilterDim(int filterDim, int dilation) {
    return dilation * (filterDim - 1) + 1;
}
// Spatial extent of the input after symmetric zero padding of `pad` elements
// is applied on both sides.
int getFwdConvPaddedImageDim(int tensorDim, int pad) {
    return 2 * pad + tensorDim;
}
// Standard forward-convolution output-size formula:
// floor((padded_input - dilated_filter) / stride) + 1.
int getFwdConvOutputDim(int tensorDim, int pad, int filterDim, int stride, int dilation) {
    const int span =
        getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation);
    return span / stride + 1;
}
// Indices for std::get<> into the descriptor tuples below. The first three
// (X/Y/W) also index common_conv_descriptors; the full list indexes the
// 10-element common_convbias_descriptors used by the fused forward graphs.
enum {
    X_TENSOR,
    Y_TENSOR,
    W_TENSOR,
    Z_TENSOR,
    B_TENSOR,
    AFTERADD_TENSOR,
    AFTERBIAS_TENSOR,
    AFTERCONV_TENSOR,
    OPTIONAL,
    AFTEROPT_TENSOR,
};

// (x, y, w) tensor descriptors plus the convolution descriptor for a plain
// forward convolution.
using common_conv_descriptors =
    std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::ConvDesc>;
// Build the (x, y, w, conv) descriptor tuple for a plain 2-D forward
// convolution. All tensors are described as 4-D NHWC with 16-byte alignment;
// the convolution descriptor always accumulates in FP32 (CUDNN_DATA_FLOAT)
// regardless of the I/O `dataType`.
common_conv_descriptors
create_common_descriptors(int64_t* x_dim_padded,
                          int64_t* padA,
                          int64_t* convstrideA,
                          int64_t* dilationA,
                          int64_t* w_dim_padded,
                          int64_t* y_dim_padded,
                          cudnnDataType_t dataType,
                          cudnnConvolutionMode_t mode) {
    const int convDim = 2;  // 2-D (H, W) convolution

    int64_t strideA_padded[4];
    int64_t outstrideA_padded[4];
    int64_t filterstrideA_padded[4];

    // NHWC strides for filter, input and output respectively.
    generateStrides(w_dim_padded, filterstrideA_padded, 4, CUDNN_TENSOR_NHWC);
    generateStrides(x_dim_padded, strideA_padded, 4, CUDNN_TENSOR_NHWC);
    generateStrides(y_dim_padded, outstrideA_padded, 4, CUDNN_TENSOR_NHWC);

    return common_conv_descriptors(cudnn_frontend::TensorBuilder()
                                       .setDim(4, x_dim_padded)
                                       .setStrides(4, strideA_padded)
                                       .setId('x')
                                       .setAlignment(16)
                                       .setDataType(dataType)
                                       .build(),
                                   cudnn_frontend::TensorBuilder()
                                       .setDim(4, y_dim_padded)
                                       .setStrides(4, outstrideA_padded)
                                       .setId('y')
                                       .setAlignment(16)
                                       .setDataType(dataType)
                                       .build(),
                                   cudnn_frontend::TensorBuilder()
                                       .setDim(4, w_dim_padded)
                                       .setStrides(4, filterstrideA_padded)
                                       .setId('w')
                                       .setAlignment(16)
                                       .setDataType(dataType)
                                       .build(),
                                   cudnn_frontend::ConvDescBuilder()
                                       .setDataType(CUDNN_DATA_FLOAT)
                                       .setMathMode(mode)
                                       .setNDims(convDim)
                                       .setStrides(convDim, convstrideA)
                                       .setPrePadding(convDim, padA)
                                       .setPostPadding(convDim, padA)
                                       .setDilation(convDim, dilationA)
                                       .build());
}
// 10-tuple of tensor descriptors for the fused conv -> scale -> bias
// [-> residual add] -> ReLU forward graph; indexed by the
// X_TENSOR..AFTEROPT_TENSOR enum above.
using common_convbias_descriptors = std::tuple<cudnn_frontend::Tensor,
                                               cudnn_frontend::Tensor,
                                               cudnn_frontend::Tensor,
                                               cudnn_frontend::Tensor,
                                               cudnn_frontend::Tensor,
                                               cudnn_frontend::Tensor,
                                               cudnn_frontend::Tensor,
                                               cudnn_frontend::Tensor,
                                               cudnn_frontend::Tensor,
                                               cudnn_frontend::Tensor>;

// Build the descriptors for the fused forward graph. Intermediate tensors
// ('A' after scale, 'B' after bias, 'C' after conv, 'D' after the optional
// residual add) are marked virtual, i.e. never materialized in memory.
common_convbias_descriptors
create_conv_bias_add_act_descriptors(int64_t* x_dim_padded,
                                     int64_t* padA,
                                     int64_t* convstrideA,
                                     int64_t* dilationA,
                                     int64_t* w_dim_padded,
                                     int64_t* y_dim_padded,
                                     cudnnDataType_t dataType) {
    const int convDim = 2;  // NOTE(review): unused here; kept for symmetry with create_common_descriptors

    // Per-channel (1, C_out, 1, 1) shape shared by the scale ('z') and
    // bias ('b') tensors; C_out is taken from the output dims.
    int64_t b_dim_padded[4];
    b_dim_padded[0] = 1;
    b_dim_padded[1] = y_dim_padded[1];
    b_dim_padded[2] = 1;
    b_dim_padded[3] = 1;

    int64_t x_stride_padded[4];
    int64_t y_stride_padded[4];
    int64_t w_stride_padded[4];
    int64_t b_stride_padded[4];

    generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
    generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
    generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
    generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);

    return common_convbias_descriptors(cudnn_frontend::TensorBuilder()
                                           .setDim(4, x_dim_padded)
                                           .setStrides(4, x_stride_padded)
                                           .setId('x')
                                           .setAlignment(16)
                                           .setDataType(dataType)
                                           .build(),
                                       cudnn_frontend::TensorBuilder()
                                           .setDim(4, y_dim_padded)
                                           .setStrides(4, y_stride_padded)
                                           .setId('y')
                                           .setAlignment(16)
                                           .setDataType(dataType)
                                           .build(),
                                       cudnn_frontend::TensorBuilder()
                                           .setDim(4, w_dim_padded)
                                           .setStrides(4, w_stride_padded)
                                           .setId('w')
                                           .setAlignment(16)
                                           .setDataType(dataType)
                                           .build(),
                                       cudnn_frontend::TensorBuilder()
                                           .setDim(4, b_dim_padded)
                                           .setStrides(4, b_stride_padded)
                                           .setId('z')
                                           .setAlignment(16)
                                           .setDataType(dataType)
                                           .build(),
                                       cudnn_frontend::TensorBuilder()
                                           .setDim(4, b_dim_padded)
                                           .setStrides(4, b_stride_padded)
                                           .setId('b')
                                           .setAlignment(16)
                                           .setDataType(dataType)
                                           .build(),
                                       cudnn_frontend::TensorBuilder()
                                           .setDim(4, y_dim_padded)
                                           .setStrides(4, y_stride_padded)
                                           .setVirtual()
                                           .setId('A')  // after add
                                           .setAlignment(16)
                                           .setDataType(dataType)
                                           .build(),
                                       cudnn_frontend::TensorBuilder()
                                           .setDim(4, y_dim_padded)
                                           .setStrides(4, y_stride_padded)
                                           .setVirtual()
                                           .setId('B')  // after bias
                                           .setAlignment(16)
                                           .setDataType(dataType)
                                           .build(),
                                       cudnn_frontend::TensorBuilder()
                                           .setDim(4, y_dim_padded)
                                           .setStrides(4, y_stride_padded)
                                           .setId('C')  // after conv
                                           .setAlignment(16)
                                           .setVirtual()
                                           .setDataType(dataType)
                                           .build(),
                                       cudnn_frontend::TensorBuilder()
                                           .setDim(4, y_dim_padded)
                                           .setStrides(4, y_stride_padded)
                                           .setId('i')  // optional residual input (same shape as y)
                                           .setAlignment(16)
                                           .setDataType(dataType)
                                           .build(),
                                       cudnn_frontend::TensorBuilder()
                                           .setDim(4, y_dim_padded)
                                           .setStrides(4, y_stride_padded)
                                           .setId('D')  // after optional add
                                           .setAlignment(16)
                                           .setVirtual()
                                           .setDataType(dataType)
                                           .build());
}
// tensor descriptors used for dgrad
// Indices for std::get<> into the dconv_descriptors tuple below.
enum {
    X_OR_DX_TENSOR,
    DY_TENSOR,
    W_OR_DW_TENSOR,
    SCALE_TENSOR,
    RELU_TENSOR,
    AFTER_DCONV_TENSOR,
    AFTER_DRELU_TENSOR,
};

// 7-tuple of tensor descriptors for the fused dgrad -> dReLU -> dScale
// backward graph.
using dconv_descriptors = std::tuple<cudnn_frontend::Tensor,
                                     cudnn_frontend::Tensor,
                                     cudnn_frontend::Tensor,
                                     cudnn_frontend::Tensor,
                                     cudnn_frontend::Tensor,
                                     cudnn_frontend::Tensor,
                                     cudnn_frontend::Tensor>;

// Build descriptors for the backward (data-gradient) fusion. Intermediates
// ('A' after dconv, 'B' after drelu) are virtual. The per-channel scale
// tensor 's' is shaped (1, C_in, 1, 1) from the x dims; 'r' is the forward
// ReLU input needed by the ReLU backward op.
dconv_descriptors
create_dconv_descriptors(int64_t* x_dim_padded,
                         int64_t* padA,
                         int64_t* convstrideA,
                         int64_t* dilationA,
                         int64_t* w_dim_padded,
                         int64_t* y_dim_padded,
                         cudnnDataType_t dataType) {
    const int convDim = 2;  // NOTE(review): unused here; kept for symmetry with create_common_descriptors

    int64_t b_dim_padded[4];
    b_dim_padded[0] = 1;
    b_dim_padded[1] = x_dim_padded[1];
    b_dim_padded[2] = 1;
    b_dim_padded[3] = 1;

    int64_t x_stride_padded[4];
    int64_t y_stride_padded[4];
    int64_t w_stride_padded[4];
    int64_t b_stride_padded[4];

    generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
    generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
    generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
    generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);

    return dconv_descriptors(cudnn_frontend::TensorBuilder()
                                 .setDim(4, x_dim_padded)
                                 .setStrides(4, x_stride_padded)
                                 .setId('x')
                                 .setAlignment(16)
                                 .setDataType(dataType)
                                 .build(),
                             cudnn_frontend::TensorBuilder()
                                 .setDim(4, y_dim_padded)
                                 .setStrides(4, y_stride_padded)
                                 .setId('y')
                                 .setAlignment(16)
                                 .setDataType(dataType)
                                 .build(),
                             cudnn_frontend::TensorBuilder()
                                 .setDim(4, w_dim_padded)
                                 .setStrides(4, w_stride_padded)
                                 .setId('w')
                                 .setAlignment(16)
                                 .setDataType(dataType)
                                 .build(),
                             cudnn_frontend::TensorBuilder()
                                 .setDim(4, b_dim_padded)
                                 .setStrides(4, b_stride_padded)
                                 .setId('s')
                                 .setAlignment(16)
                                 .setDataType(dataType)
                                 .build(),
                             cudnn_frontend::TensorBuilder()
                                 .setDim(4, x_dim_padded)
                                 .setStrides(4, x_stride_padded)
                                 .setId('r')
                                 .setAlignment(16)
                                 .setDataType(dataType)
                                 .build(),
                             cudnn_frontend::TensorBuilder()
                                 .setDim(4, x_dim_padded)
                                 .setStrides(4, x_stride_padded)
                                 .setVirtual()
                                 .setId('A')  // after dconv
                                 .setAlignment(16)
                                 .setDataType(dataType)
                                 .build(),
                             cudnn_frontend::TensorBuilder()
                                 .setDim(4, x_dim_padded)
                                 .setStrides(4, x_stride_padded)
                                 .setVirtual()
                                 .setId('B')  // after drelu
                                 .setAlignment(16)
                                 .setDataType(dataType)
                                 .build());
}
// create a cache for plan
// Process-wide cache mapping a serialized problem description (see
// getConvFusionString) to its built cudnn_frontend execution plan, so each
// unique fused-conv configuration is planned only once.
// NOTE(review): access is not synchronized — assumes single-threaded use or
// external locking; confirm against callers.
std::unordered_map<std::string, cudnn_frontend::ExecutionPlan> plan_cache;
// TODO: better name
// Serialize a convolution problem (dims, padding, stride, dilation, dtype)
// into a string key for plan_cache. `fusion_string` arrives pre-seeded with
// the operation-graph tag so distinct fusions never collide.
std::string
getConvFusionString(int64_t* x_dim_padded,
                    int64_t* padA,
                    int64_t* convstrideA,
                    int64_t* dilationA,
                    int64_t* w_dim_padded,
                    cudnnDataType_t dataType,
                    std::string fusion_string) {
    // Append one "<tag><value>" pair per array entry, e.g. "X64".
    auto append_tagged = [&fusion_string](char tag, const int64_t* vals, int count) {
        for (int i = 0; i < count; i++) {
            fusion_string += tag;
            fusion_string += std::to_string(vals[i]);
        }
    };
    append_tagged('X', x_dim_padded, 4);   // input dims
    append_tagged('W', w_dim_padded, 4);   // filter dims
    append_tagged('P', padA, 2);           // padding
    append_tagged('S', convstrideA, 2);    // conv strides
    append_tagged('D', dilationA, 2);      // dilation
    fusion_string += 'T';
    fusion_string += std::to_string(dataType);
    return fusion_string;
}
// Return a cached execution plan for `opGraph`, building and caching one on
// miss. With use_heuristic (default) the cuDNN heuristics suggest engine
// configs and up to 3 candidates are tried (WAR for configs that fail to
// build); otherwise global engine 0 is used with a bumped first knob.
// The returned reference points into plan_cache and stays valid while the
// cache entry exists. Throws cudnn_frontend::cudnnException if no candidate
// engine config builds.
cudnn_frontend::ExecutionPlan&
getOrCreatePlan(cudnnHandle_t handle_,
                std::stringstream& log_buf,
                cudnn_frontend::OperationGraph& opGraph,
                std::string cache_string,
                bool use_heuristic = true) {
    auto it = plan_cache.find(cache_string);
    if (it != plan_cache.end()) {
        DEBUG_CUDNN_MSG(log_buf, "Found plan in cache");
        return it->second;
    } else {
        if (use_heuristic) {
            // TODO: confirm which mode to use
            auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
                                  .setOperationGraph(opGraph)
                                  .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
                                  .build();
            // try 3 times for now as WAR for no heuristic training
            int max_tries = 3, count = 0;
            auto& engine_configs = heuristics.getEngineConfig(max_tries);
            while (true) {
                try {
                    plan_cache.emplace(cache_string,
                                       std::move(cudnn_frontend::ExecutionPlanBuilder()
                                                     .setHandle(handle_)
                                                     .setEngineConfig(engine_configs[count], opGraph.getTag())
                                                     .build()));
                    break;
                } catch (const cudnn_frontend::cudnnException& e) {
                    // Fix: catch by const reference (was by value, which copies
                    // and can slice the exception) and rethrow the original
                    // object with a bare `throw;` once all candidates failed.
                    if (++count == max_tries) throw;
                }
            }
        } else {
            DEBUG_CUDNN_MSG(log_buf, "No plan in cache");
            // How many engines support this operation graph ?
            auto total_engines = opGraph.getEngineCount();
            DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << " has " << total_engines << " engines.");
            // We have to randomly pick one engine from [0, total_engines)
            // Selecting "0" by default
            auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build();
            DEBUG_CUDNN_MSG(log_buf, engine.describe());
            auto& knobs = engine.getSupportedKnobs();
            // Renamed from `it` to avoid shadowing the cache iterator above.
            for (auto knob_it = std::begin(knobs); knob_it != std::end(knobs); ++knob_it) {
                DEBUG_CUDNN_MSG(log_buf, knob_it->describe());
            }
            if (knobs.begin() != knobs.end()) {
                DEBUG_CUDNN_MSG(log_buf, "Updated knob choice");
                knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1);
                DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe());
            }
            // Create the requisite engine config
            auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build();
            DEBUG_CUDNN_MSG(log_buf, engine_config.describe());
            plan_cache.emplace(cache_string,
                               std::move(cudnn_frontend::ExecutionPlanBuilder()
                                             .setHandle(handle_)
                                             .setEngineConfig(engine_config)
                                             .build()));
        }
        return plan_cache.find(cache_string)->second;
    }
}
// Execute the fused forward graph
//     y = ReLU(conv(x, w) * z + b [+ i])
// on NHWC half-precision buffers via the cuDNN backend. `devPtrZ`/`devPtrB`
// are per-channel scale/bias; `devPtrI` is an optional residual input — when
// null, the residual-add node is dropped from the graph and only 5 variant-
// pack pointers are bound. Errors are logged to stdout; the function does
// not rethrow (the catch below swallows cudnnException after printing).
void
run_conv_scale_bias_add_activation(int64_t* x_dim_padded,
                                   int64_t* pad,
                                   int64_t* convstride,
                                   int64_t* dilation,
                                   int64_t* w_dim_padded,
                                   int64_t* y_dim_padded,
                                   cudnnDataType_t dataType,
                                   at::Half* devPtrX,
                                   at::Half* devPtrW,
                                   at::Half* devPtrY,
                                   at::Half* devPtrZ,
                                   at::Half* devPtrB,
                                   at::Half* devPtrI) {
    cudnnHandle_t handle_ = torch::native::getCudnnHandle();
    std::stringstream log_buf;
    try {
        int convDim = 2;

        // Creates the necessary tensor descriptors
        common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(
            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
        DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
        DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
        DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
        DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
        DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
        DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());

        // Define the add operation (per-channel scale, pointwise multiply)
        auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
                             .setMode(CUDNN_POINTWISE_MUL)
                             .setMathPrecision(CUDNN_DATA_FLOAT)
                             .build();
        DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());

        // Define the bias operation
        auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
                            .setMode(CUDNN_POINTWISE_ADD)
                            .setMathPrecision(CUDNN_DATA_FLOAT)
                            .build();
        DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());

        // optional add (residual connection)
        auto addDesc = cudnn_frontend::PointWiseDescBuilder()
                           .setMode(CUDNN_POINTWISE_ADD)
                           .setMathPrecision(CUDNN_DATA_FLOAT)
                           .build();
        DEBUG_CUDNN_MSG(log_buf, addDesc.describe());

        // Define the activation operation
        auto actDesc = cudnn_frontend::PointWiseDescBuilder()
                           .setMode(CUDNN_POINTWISE_RELU_FWD)
                           .setMathPrecision(CUDNN_DATA_FLOAT)
                           .build();
        DEBUG_CUDNN_MSG(log_buf, actDesc.describe());

        // Define the convolution problem
        auto convDesc = cudnn_frontend::ConvDescBuilder()
                            .setDataType(CUDNN_DATA_FLOAT)
                            .setMathMode(CUDNN_CROSS_CORRELATION)
                            .setNDims(convDim)
                            .setStrides(convDim, convstride)
                            .setPrePadding(convDim, pad)
                            .setPostPadding(convDim, pad)
                            .setDilation(convDim, dilation)
                            .build();
        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());

        float alpha = 1.0f;
        float beta = 0.0f;

        // Create a convolution Node
        auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
                           .setxDesc(std::get<X_TENSOR>(tensors))
                           .setwDesc(std::get<W_TENSOR>(tensors))
                           .setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
                           .setcDesc(convDesc)
                           .setAlpha(alpha)
                           .setBeta(beta)
                           .build();
        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());

        // Create a Add Node with scaling parameters.
        auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                            .setxDesc(conv_op.getOutputTensor())
                            .setbDesc(std::get<Z_TENSOR>(tensors))
                            .setyDesc(std::get<AFTERADD_TENSOR>(tensors))
                            .setpwDesc(scaleDesc)
                            .build();
        DEBUG_CUDNN_MSG(log_buf, scale_op.describe());

        // Create a Bias Node.
        auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                           .setxDesc(scale_op.getOutputTensor())
                           .setbDesc(std::get<B_TENSOR>(tensors))
                           .setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))
                           .setpwDesc(biasDesc)
                           .build();
        DEBUG_CUDNN_MSG(log_buf, bias_op.describe());

        // Create a optional add Node.
        auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                          .setxDesc(bias_op.getOutputTensor())
                          .setbDesc(std::get<OPTIONAL>(tensors))
                          .setyDesc(std::get<AFTEROPT_TENSOR>(tensors))
                          .setpwDesc(addDesc)
                          .build();
        DEBUG_CUDNN_MSG(log_buf, add_op.describe());

        // Create an Activation Node. Its input is the residual-add output
        // when devPtrI is provided, otherwise the bias output.
        auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                          .setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor())
                          .setyDesc(std::get<Y_TENSOR>(tensors))
                          .setpwDesc(actDesc)
                          .build();
        DEBUG_CUDNN_MSG(log_buf, act_op.describe());

        // Create an Operation Graph. In this case it is convolution add bias activation
        // Without devPtrI only the first 4 slots are used (add_op slot holds
        // act_op but the count passed below is 4, so it is ignored).
        std::array<cudnn_frontend::Operation const*, 5> ops = {&conv_op,
                                                               &scale_op,
                                                               &bias_op,
                                                               devPtrI ? &add_op : &act_op,
                                                               &act_op};

        auto opGraph = cudnn_frontend::OperationGraphBuilder()
                           .setHandle(handle_)
                           .setOperationGraph(devPtrI ? ops.size() : 4, ops.data())
                           .build();

        // Create string encoding for plan caching
        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
        DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);

        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
        DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());

        auto workspace_size = plan.getWorkspaceSize();
        DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);

        // Workspace is allocated as float elements, hence (bytes + 3) / 4.
        void* workspace_ptr = nullptr;
        auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
        if (workspace_size > 0) {
            workspace_ptr = workspace_tensor.data_ptr<float>();
        }
        // Bind device pointers to tensor UIDs; the trailing (i) pair is only
        // consumed when devPtrI is non-null (count 6 vs 5 below).
        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI};
        int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'};
        auto variantPack = cudnn_frontend::VariantPackBuilder()
                               .setWorkspacePointer(workspace_ptr)
                               .setDataPointers(devPtrI ? 6 : 5, data_ptrs)
                               .setUids(devPtrI ? 6 : 5, uids)
                               .build();
        DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
        checkCudnnErr(status);
        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error");
    } catch (cudnn_frontend::cudnnException e) {
        // NOTE(review): catches by value and swallows after logging — callers
        // see no failure signal.
        std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
    }
}
// Execute the fused forward graph
//     y = conv(x, w) * z + b
// (no activation, no residual) on NHWC half-precision buffers via the cuDNN
// backend. `devPtrZ`/`devPtrB` are per-channel scale/bias. Errors are logged
// to stdout; the function does not rethrow (the catch below swallows
// cudnnException after printing).
void
run_conv_scale_bias(int64_t* x_dim_padded,
                    int64_t* pad,
                    int64_t* convstride,
                    int64_t* dilation,
                    int64_t* w_dim_padded,
                    int64_t* y_dim_padded,
                    cudnnDataType_t dataType,
                    at::Half* devPtrX,
                    at::Half* devPtrW,
                    at::Half* devPtrY,
                    at::Half* devPtrZ,
                    at::Half* devPtrB) {
    cudnnHandle_t handle_ = torch::native::getCudnnHandle();
    std::stringstream log_buf;
    try {
        int convDim = 2;

        // Creates the necessary tensor descriptors
        common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(
            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
        DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
        DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
        DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
        DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
        DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
        DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());

        // Define the add operation (per-channel scale, pointwise multiply)
        auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
                             .setMode(CUDNN_POINTWISE_MUL)
                             .setMathPrecision(CUDNN_DATA_FLOAT)
                             .build();
        DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());

        // Define the bias operation
        auto addDesc = cudnn_frontend::PointWiseDescBuilder()
                           .setMode(CUDNN_POINTWISE_ADD)
                           .setMathPrecision(CUDNN_DATA_FLOAT)
                           .build();
        DEBUG_CUDNN_MSG(log_buf, addDesc.describe());

        // Define the convolution problem
        auto convDesc = cudnn_frontend::ConvDescBuilder()
                            .setDataType(CUDNN_DATA_FLOAT)
                            .setMathMode(CUDNN_CROSS_CORRELATION)
                            .setNDims(convDim)
                            .setStrides(convDim, convstride)
                            .setPrePadding(convDim, pad)
                            .setPostPadding(convDim, pad)
                            .setDilation(convDim, dilation)
                            .build();
        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());

        float alpha = 1.0f;
        float beta = 0.0f;

        // Create a convolution Node
        auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
                           .setxDesc(std::get<X_TENSOR>(tensors))
                           .setwDesc(std::get<W_TENSOR>(tensors))
                           .setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
                           .setcDesc(convDesc)
                           .setAlpha(alpha)
                           .setBeta(beta)
                           .build();
        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());

        // Create a Add Node with scaling parameters.
        auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                            .setxDesc(conv_op.getOutputTensor())
                            .setbDesc(std::get<Z_TENSOR>(tensors))
                            .setyDesc(std::get<AFTERADD_TENSOR>(tensors)) // TODO: change enum to aftermul
                            .setpwDesc(scaleDesc)
                            .build();
        DEBUG_CUDNN_MSG(log_buf, scale_op.describe());

        // Create a Bias Node. Writes directly to the (non-virtual) output y.
        auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                          .setxDesc(scale_op.getOutputTensor())
                          .setbDesc(std::get<B_TENSOR>(tensors))
                          .setyDesc(std::get<Y_TENSOR>(tensors))
                          .setpwDesc(addDesc)
                          .build();
        DEBUG_CUDNN_MSG(log_buf, add_op.describe());

        // Create an Operation Graph. In this case it is convolution add bias activation
        std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &scale_op, &add_op};

        auto opGraph = cudnn_frontend::OperationGraphBuilder()
                           .setHandle(handle_)
                           .setOperationGraph(ops.size(), ops.data())
                           .build();

        // Create string encoding for plan caching
        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
        DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);

        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
        DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());

        auto workspace_size = plan.getWorkspaceSize();
        DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);

        // Workspace is allocated as float elements, hence (bytes + 3) / 4.
        void* workspace_ptr = nullptr;
        auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
        if (workspace_size > 0) {
            workspace_ptr = workspace_tensor.data_ptr<float>();
        }
        // Bind device pointers to tensor UIDs.
        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB};
        int64_t uids[] = {'x', 'y', 'w', 'z', 'b'};
        auto variantPack = cudnn_frontend::VariantPackBuilder()
                               .setWorkspacePointer(workspace_ptr)
                               .setDataPointers(5, data_ptrs)
                               .setUids(5, uids)
                               .build();
        DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
        checkCudnnErr(status);
        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error");
    } catch (cudnn_frontend::cudnnException e) {
        // NOTE(review): catches by value and swallows after logging — callers
        // see no failure signal.
        std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
    }
}
// Fused backward pass for a conv+scale+ReLU block:
//   dX = dReLU(conv_dgrad(dY, W); relu_input) * scale
// built as a single cuDNN-frontend operation graph with three nodes:
//   1. convolution backward-data: dY, W -> AFTER_DCONV
//   2. pointwise ReLU backward (uses forward ReLU input R): AFTER_DCONV -> AFTER_DRELU
//   3. pointwise multiply by scale Z: AFTER_DRELU -> dX
//
// Parameters:
//   x_dim_padded / w_dim_padded / y_dim_padded : NCHW dims of dX, W and dY
//   pad, convstride, dilation                  : 2D convolution geometry
//   dataType                                   : cuDNN element type (callers pass CUDNN_DATA_HALF)
//   devPtrX  : output, receives the fused data gradient (uid 'x')
//   devPtrW  : filter used in the forward conv (uid 'w')
//   devPtrY  : incoming gradient dY (uid 'y')
//   devPtrZ  : per-channel scale tensor (uid 's')
//   devPtrR  : forward-pass ReLU input, needed by CUDNN_POINTWISE_RELU_BWD (uid 'r')
//
// Errors from the cudnn_frontend builders surface as cudnnException and are
// logged (with the accumulated debug buffer) rather than propagated.
void run_dconv_drelu_dscale(int64_t* x_dim_padded,
                            int64_t* pad,
                            int64_t* convstride,
                            int64_t* dilation,
                            int64_t* w_dim_padded,
                            int64_t* y_dim_padded,
                            cudnnDataType_t dataType,
                            at::Half* devPtrX,
                            at::Half* devPtrW,
                            at::Half* devPtrY,
                            at::Half* devPtrZ,
                            at::Half* devPtrR) {
  cudnnHandle_t handle_ = torch::native::getCudnnHandle();
  std::stringstream log_buf;
  try {
    int convDim = 2;

    // Creates the necessary tensor descriptors
    dconv_descriptors tensors = create_dconv_descriptors(
        x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
    DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());

    // Define the convolution problem (accumulation in FP32 even for half data)
    auto convDesc = cudnn_frontend::ConvDescBuilder()
                        .setDataType(CUDNN_DATA_FLOAT)
                        .setMathMode(CUDNN_CROSS_CORRELATION)
                        .setNDims(convDim)
                        .setStrides(convDim, convstride)
                        .setPrePadding(convDim, pad)
                        .setPostPadding(convDim, pad)
                        .setDilation(convDim, dilation)
                        .build();
    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());

    // Define the activation backward operation
    auto actDesc = cudnn_frontend::PointWiseDescBuilder()
                       .setMode(CUDNN_POINTWISE_RELU_BWD)
                       .setMathPrecision(CUDNN_DATA_FLOAT)
                       .build();
    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());

    // Define the scale backward operation
    auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
                         .setMode(CUDNN_POINTWISE_MUL)
                         .setMathPrecision(CUDNN_DATA_FLOAT)
                         .build();
    DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());

    float alpha = 1.0f;
    float beta = 0.0f;

    // Create a convolution backward-data Node
    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
                       .setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
                       .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
                       .setdyDesc(std::get<DY_TENSOR>(tensors))
                       .setcDesc(convDesc)
                       .setAlpha(alpha)
                       .setBeta(beta)
                       .build();
    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());

    // TODO: do we need getOutputTensor(), and what it returns in backward case?
    // Create an relu backward Node.
    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                      .setdyDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
                      .setxDesc(std::get<RELU_TENSOR>(tensors))
                      .setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
                      .setpwDesc(actDesc)
                      .build();
    DEBUG_CUDNN_MSG(log_buf, act_op.describe());

    // Create a Scale Node (elementwise multiply by the scale tensor).
    auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                        .setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
                        .setbDesc(std::get<SCALE_TENSOR>(tensors))
                        .setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
                        .setpwDesc(scaleDesc)
                        .build();
    DEBUG_CUDNN_MSG(log_buf, scale_op.describe());

    // Create an Operation Graph: dgrad -> drelu -> dscale
    std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &act_op, &scale_op};

    auto opGraph = cudnn_frontend::OperationGraphBuilder()
                       .setHandle(handle_)
                       .setOperationGraph(ops.size(), ops.data())
                       .build();

    // Create string encoding for plan caching
    auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation,
                                            w_dim_padded, dataType, opGraph.getTag());
    DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);

    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
    DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());

    auto workspace_size = plan.getWorkspaceSize();
    DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);

    // Workspace is allocated through ATen (as float words, rounded up) so it is
    // tracked by the caching allocator and freed when the tensor dies.
    void* workspace_ptr = nullptr;
    auto workspace_tensor = at::empty({(workspace_size + 3) / 4},
                                      at::TensorOptions(at::kCUDA).dtype(at::kFloat));
    if (workspace_size > 0) {
      workspace_ptr = workspace_tensor.data_ptr<float>();
    }

    // Bind device pointers to the uids used by create_dconv_descriptors.
    void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR};
    int64_t uids[] = {'x', 'y', 'w', 's', 'r'};
    auto variantPack = cudnn_frontend::VariantPackBuilder()
                           .setWorkspacePointer(workspace_ptr)
                           .setDataPointers(5, data_ptrs)
                           .setUids(5, uids)
                           .build();
    DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
    checkCudnnErr(status);
    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error");
  } catch (cudnn_frontend::cudnnException& e) {
    // Catch by reference (was by value): avoids copying/slicing the exception.
    std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
  }
}
void
run_dconv
(
int64_t
*
x_dim_padded
,
int64_t
*
pad
,
int64_t
*
convstride
,
int64_t
*
dilation
,
int64_t
*
w_dim_padded
,
int64_t
*
y_dim_padded
,
cudnnDataType_t
dataType
,
at
::
Half
*
devPtrX
,
at
::
Half
*
devPtrW
,
at
::
Half
*
devPtrY
,
cudnnBackendDescriptorType_t
mode
)
{
cudnnHandle_t
handle_
=
torch
::
native
::
getCudnnHandle
();
std
::
stringstream
log_buf
;
try
{
int
convDim
=
2
;
// Creates the necessary tensor descriptors
dconv_descriptors
tensors
=
create_dconv_descriptors
(
x_dim_padded
,
pad
,
convstride
,
dilation
,
w_dim_padded
,
y_dim_padded
,
dataType
);
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
X_OR_DX_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
DY_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
W_OR_DW_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
SCALE_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
RELU_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
AFTER_DCONV_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
AFTER_DRELU_TENSOR
>
(
tensors
).
describe
());
// Define the convolution problem
auto
convDesc
=
cudnn_frontend
::
ConvDescBuilder
()
.
setDataType
(
CUDNN_DATA_FLOAT
)
.
setMathMode
(
CUDNN_CROSS_CORRELATION
)
.
setNDims
(
convDim
)
.
setStrides
(
convDim
,
convstride
)
.
setPrePadding
(
convDim
,
pad
)
.
setPostPadding
(
convDim
,
pad
)
.
setDilation
(
convDim
,
dilation
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
convDesc
.
describe
());
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
// Create a convolution Node
// mode should be one of following
// CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
// CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
auto
conv_op_builder
=
cudnn_frontend
::
OperationBuilder
(
mode
);
if
(
mode
==
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
)
{
conv_op_builder
.
setdxDesc
(
std
::
get
<
X_OR_DX_TENSOR
>
(
tensors
))
.
setwDesc
(
std
::
get
<
W_OR_DW_TENSOR
>
(
tensors
))
.
setdyDesc
(
std
::
get
<
DY_TENSOR
>
(
tensors
))
.
setcDesc
(
convDesc
)
.
setAlpha
(
alpha
)
.
setBeta
(
beta
);
}
else
{
conv_op_builder
.
setxDesc
(
std
::
get
<
X_OR_DX_TENSOR
>
(
tensors
))
.
setdwDesc
(
std
::
get
<
W_OR_DW_TENSOR
>
(
tensors
))
.
setdyDesc
(
std
::
get
<
DY_TENSOR
>
(
tensors
))
.
setcDesc
(
convDesc
)
.
setAlpha
(
alpha
)
.
setBeta
(
beta
);
}
auto
conv_op
=
conv_op_builder
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
conv_op
.
describe
());
// Create an Operation Graph. In this case it is convolution add bias activation
std
::
array
<
cudnn_frontend
::
Operation
const
*
,
1
>
ops
=
{
&
conv_op
};
auto
opGraph
=
cudnn_frontend
::
OperationGraphBuilder
()
.
setHandle
(
handle_
)
.
setOperationGraph
(
ops
.
size
(),
ops
.
data
())
.
build
();
// Create string encoding for plan caching
auto
cache_string
=
getConvFusionString
(
x_dim_padded
,
pad
,
convstride
,
dilation
,
w_dim_padded
,
dataType
,
opGraph
.
getTag
());
DEBUG_CUDNN_MSG
(
log_buf
,
"[convstring] "
<<
cache_string
);
auto
&
plan
=
getOrCreatePlan
(
handle_
,
log_buf
,
opGraph
,
cache_string
);
DEBUG_CUDNN_MSG
(
log_buf
,
"Plan tag: "
<<
plan
.
getTag
());
auto
workspace_size
=
plan
.
getWorkspaceSize
();
DEBUG_CUDNN_MSG
(
log_buf
,
plan
.
describe
()
<<
" requires workspace "
<<
workspace_size
);
void
*
workspace_ptr
=
nullptr
;
auto
workspace_tensor
=
at
::
empty
({(
workspace_size
+
3
)
/
4
},
at
::
TensorOptions
(
at
::
kCUDA
).
dtype
(
at
::
kFloat
));
if
(
workspace_size
>
0
)
{
workspace_ptr
=
workspace_tensor
.
data_ptr
<
float
>
();
}
void
*
data_ptrs
[]
=
{
devPtrX
,
devPtrY
,
devPtrW
};
int64_t
uids
[]
=
{
'x'
,
'y'
,
'w'
};
auto
variantPack
=
cudnn_frontend
::
VariantPackBuilder
()
.
setWorkspacePointer
(
workspace_ptr
)
.
setDataPointers
(
3
,
data_ptrs
)
.
setUids
(
3
,
uids
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
"variantPack "
<<
variantPack
.
describe
());
cudnnStatus_t
status
=
cudnnBackendExecute
(
handle_
,
plan
.
get_raw_desc
(),
variantPack
.
get_raw_desc
());
checkCudnnErr
(
status
);
cudnn_frontend
::
throw_if
([
status
]()
{
return
(
status
!=
CUDNN_STATUS_SUCCESS
);
},
"Plan execute error"
);
}
catch
(
cudnn_frontend
::
cudnnException
e
)
{
std
::
cout
<<
log_buf
.
str
()
<<
"[ERROR] Exception "
<<
e
.
what
()
<<
std
::
endl
;
}
}
// Fused backward convolution + residual add:
//   dX = conv_dgrad(dY, W) + R
// Two-node cuDNN-frontend graph:
//   1. convolution backward-data: dY, W -> AFTER_DCONV
//   2. pointwise ADD with R:      AFTER_DCONV + R -> dX
// Used when the residual branch gradient (R) can be folded into the dgrad
// (stride == 1 case in bottleneck_backward).
//
// Parameters:
//   devPtrX : output data gradient (uid 'x')
//   devPtrW : filter (uid 'w')
//   devPtrY : incoming gradient dY (uid 'y')
//   devPtrR : tensor added elementwise to the dgrad result (uid 'r')
//
// cudnn_frontend errors are logged (with the debug buffer) and swallowed.
void run_dconv_add(int64_t* x_dim_padded,
                   int64_t* pad,
                   int64_t* convstride,
                   int64_t* dilation,
                   int64_t* w_dim_padded,
                   int64_t* y_dim_padded,
                   cudnnDataType_t dataType,
                   at::Half* devPtrX,
                   at::Half* devPtrW,
                   at::Half* devPtrY,
                   at::Half* devPtrR) {
  cudnnHandle_t handle_ = torch::native::getCudnnHandle();
  std::stringstream log_buf;
  try {
    int convDim = 2;

    // Creates the necessary tensor descriptors
    dconv_descriptors tensors = create_dconv_descriptors(
        x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
    DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());

    // Define the convolution problem (FP32 accumulation)
    auto convDesc = cudnn_frontend::ConvDescBuilder()
                        .setDataType(CUDNN_DATA_FLOAT)
                        .setMathMode(CUDNN_CROSS_CORRELATION)
                        .setNDims(convDim)
                        .setStrides(convDim, convstride)
                        .setPrePadding(convDim, pad)
                        .setPostPadding(convDim, pad)
                        .setDilation(convDim, dilation)
                        .build();
    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());

    // Define the add backward operation
    auto addDesc = cudnn_frontend::PointWiseDescBuilder()
                       .setMode(CUDNN_POINTWISE_ADD)
                       .setMathPrecision(CUDNN_DATA_FLOAT)
                       .build();
    DEBUG_CUDNN_MSG(log_buf, addDesc.describe());

    float alpha = 1.0f;
    float beta = 0.0f;

    // Create a convolution backward-data Node
    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
                       .setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
                       .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
                       .setdyDesc(std::get<DY_TENSOR>(tensors))
                       .setcDesc(convDesc)
                       .setAlpha(alpha)
                       .setBeta(beta)
                       .build();
    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());

    // TODO: do we need getOutputTensor(), and what it returns in backward case?
    // Create add Node: AFTER_DCONV + R -> dX. Note the add input is bound to
    // the RELU_TENSOR descriptor slot (same shape), fed with devPtrR.
    auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                      .setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
                      .setbDesc(std::get<RELU_TENSOR>(tensors))
                      .setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
                      .setpwDesc(addDesc)
                      .build();
    DEBUG_CUDNN_MSG(log_buf, add_op.describe());

    // Create an Operation Graph: dgrad -> add
    std::array<cudnn_frontend::Operation const*, 2> ops = {&conv_op, &add_op};

    auto opGraph = cudnn_frontend::OperationGraphBuilder()
                       .setHandle(handle_)
                       .setOperationGraph(ops.size(), ops.data())
                       .build();

    // Create string encoding for plan caching
    auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation,
                                            w_dim_padded, dataType, opGraph.getTag());
    DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);

    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
    DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());

    auto workspace_size = plan.getWorkspaceSize();
    DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);

    // Workspace allocated via ATen (float words, rounded up).
    void* workspace_ptr = nullptr;
    auto workspace_tensor = at::empty({(workspace_size + 3) / 4},
                                      at::TensorOptions(at::kCUDA).dtype(at::kFloat));
    if (workspace_size > 0) {
      workspace_ptr = workspace_tensor.data_ptr<float>();
    }

    // Bind device pointers to descriptor uids.
    void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrR};
    int64_t uids[] = {'x', 'y', 'w', 'r'};
    auto variantPack = cudnn_frontend::VariantPackBuilder()
                           .setWorkspacePointer(workspace_ptr)
                           .setDataPointers(4, data_ptrs)
                           .setUids(4, uids)
                           .build();
    DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
    checkCudnnErr(status);
    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error");
  } catch (cudnn_frontend::cudnnException& e) {
    // Catch by reference (was by value): avoids copying/slicing the exception.
    std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
  }
}
// inputs contains x,w,z,b,(i)
//
// Forward pass of a ResNet bottleneck block (3 fused conv+scale+bias+ReLU
// stages, plus an optional downsample conv on the identity branch).
//
// Input layout (as used below):
//   inputs[0]        : x (activation input)
//   inputs[1..3]     : weights of conv1/conv2/conv3
//   inputs[4..6]     : per-stage scale tensors z
//   inputs[7..9]     : per-stage bias tensors b
//   inputs[10..12]   : downsample conv weight/scale/bias — only read when the
//                      downsample path is active (stride != 1 or channel change)
// Returns {out1, out2, out3}, the post-ReLU outputs of the three stages.
//
// explicit_nhwc: when true, tensors are physically NHWC while logically NCHW,
// so dims are gathered through a permuted axis[] map.
std::vector<at::Tensor> bottleneck_forward(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
  std::cout << std::fixed;

  // create output vector
  std::vector<at::Tensor> outputs;
  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // setup dimensions (all in backend n,c,h,w order)
  int64_t dimA[] = {0, 0, 0, 0};
  int64_t filterdimA1[] = {0, 0, 0, 0};
  int64_t filterdimA2[] = {0, 0, 0, 0};
  int64_t filterdimA3[] = {0, 0, 0, 0};
  int64_t filterdimA4[] = {0, 0, 0, 0};

  // All dim calculation after this order of n,c,h,w
  // axis[] maps backend (n,c,h,w) index -> tensor dimension index.
  int axis[] {0, 1, 2, 3};
  if (explicit_nhwc) {
    axis[0] = 0;
    axis[1] = 3;
    axis[2] = 1;
    axis[3] = 2;
  }
  for (int dim = 0; dim < 4; dim++) {
    dimA[dim] = inputs[0].size(axis[dim]);
    filterdimA1[dim] = inputs[1].size(axis[dim]);
    filterdimA2[dim] = inputs[2].size(axis[dim]);
    filterdimA3[dim] = inputs[3].size(axis[dim]);
  }
  // Downsample conv exists when spatial stride != 1 or the block changes
  // channel count (conv3 output channels != input channels).
  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
    for (int dim = 0; dim < 4; dim++) {
      filterdimA4[dim] = inputs[10].size(axis[dim]);
    }
  }

  // output dim in n,c,h,w used by backend
  int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below
  int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below
  int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below

  // use these fixed value for test run
  int64_t padA[] = {0, 0};
  int64_t padA1[] = {1, 1};        // 3x3 conv2 uses pad 1
  int64_t dilationA[] = {1, 1};
  int64_t convstrideA[] = {1, 1};
  int64_t convstride1X1[] = {stride_1X1, stride_1X1};

  // compute output from pad/stride/dilation
  outdimA1[0] = dimA[0];
  outdimA1[1] = filterdimA1[0];
  for (int dim = 0; dim < 2; dim++) {
    outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
  }

  outdimA2[0] = outdimA1[0];
  outdimA2[1] = filterdimA2[0];
  for (int dim = 0; dim < 2; dim++) {
    outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
  }

  outdimA3[0] = outdimA2[0];
  outdimA3[1] = filterdimA3[0];
  for (int dim = 0; dim < 2; dim++) {
    outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
  }

  // Create output tensor in the correct shape in pytorch's view
  int64_t outdim1[] = {0, 0, 0, 0};
  int64_t outdim2[] = {0, 0, 0, 0};
  int64_t outdim3[] = {0, 0, 0, 0};
  // For explicit NHWC, re-purpose axis[] as the inverse map: backend
  // (n,c,h,w) -> pytorch (n,h,w,c).
  if (explicit_nhwc) {
    axis[0] = 0;
    axis[1] = 2;
    axis[2] = 3;
    axis[3] = 1;
  }
  for (int dim = 0; dim < 4; dim++) {
    outdim1[dim] = outdimA1[axis[dim]];
    outdim2[dim] = outdimA2[axis[dim]];
    outdim3[dim] = outdimA3[axis[dim]];
  }

  // run
  // Stage 1: conv1 (1x1, stride stride_1X1) + scale + bias + ReLU
  at::Half* x = inputs[0].data_ptr<at::Half>();
  at::Half* w = inputs[1].data_ptr<at::Half>();
  at::Half* z = inputs[4].data_ptr<at::Half>();
  at::Half* b = inputs[7].data_ptr<at::Half>();
  auto out1 = at::empty(outdim1, inputs[0].type(), output_format);
  at::Half* y1 = out1.data_ptr<at::Half>();

  run_conv_scale_bias_add_activation(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, x, w, y1, z, b, nullptr);

  DEBUG_MSG("[DEBUG] new relu1 : " << out1.to(at::kFloat).sum().item<float>());

  // Stage 2: conv2 (3x3, pad 1) + scale + bias + ReLU
  w = inputs[2].data_ptr<at::Half>();
  z = inputs[5].data_ptr<at::Half>();
  b = inputs[8].data_ptr<at::Half>();
  auto out2 = at::empty(outdim2, inputs[0].type(), output_format);
  at::Half* y2 = out2.data_ptr<at::Half>();

  run_conv_scale_bias_add_activation(outdimA1, padA1, convstrideA, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr);
  DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());

  // create output of conv3
  auto out3 = at::empty(outdim3, inputs[0].type(), output_format);
  at::Half* y3 = out3.data_ptr<at::Half>();

  // create output of conv4 that may exist
  auto identity = at::empty_like(out3);
  at::Half* yi = identity.data_ptr<at::Half>();

  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]){
    // Downsample path: conv4 + scale + bias on x produces the residual input.
    w = inputs[10].data_ptr<at::Half>();
    z = inputs[11].data_ptr<at::Half>();
    b = inputs[12].data_ptr<at::Half>();
    run_conv_scale_bias(dimA, padA, convstride1X1, dilationA, filterdimA4, outdimA3, CUDNN_DATA_HALF, x, w, yi, z, b);
    DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item<float>());
  }
  else {
    // No downsample: the residual is x itself (identity tensor stays unused).
    yi = x;
  }

  // Stage 3: conv3 (1x1) + scale + bias + residual add (yi) + ReLU
  w = inputs[3].data_ptr<at::Half>();
  z = inputs[6].data_ptr<at::Half>();
  b = inputs[9].data_ptr<at::Half>();

  run_conv_scale_bias_add_activation(outdimA2, padA, convstrideA, dilationA, filterdimA3, outdimA3, CUDNN_DATA_HALF, y2, w, y3, z, b, yi);
  DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item<float>());

  outputs.push_back(out1);
  outputs.push_back(out2);
  outputs.push_back(out3);

  return outputs;
}
// Backward pass of the bottleneck block, mirroring bottleneck_forward.
//
// Input layout (as used below):
//   inputs[0..3]   : forward x and conv1/conv2/conv3 weights
//   inputs[4..5]   : scale tensors read by the fused drelu/dscale kernels
//   inputs[10]     : incoming gradient dy3 (grad of the block output)
//   inputs[11]     : gradient fed to the downsample branch (or the fork of
//                    drelu3 when there is no downsample)
//   inputs[12..13] : saved forward activations relu1 / relu2
//   inputs[14]     : downsample conv weight — only read when that path exists
// Returns {grad_x, wgrad1, wgrad2, wgrad3[, wgrad4]}.
//
// NOTE(review): the exact meaning of inputs[11] in each branch is inferred
// from how the two branches below consume it — confirm against the Python
// caller in apex/contrib/bottleneck/bottleneck.py.
std::vector<at::Tensor> bottleneck_backward(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
  // dgrad of x is only computed when the caller's input requires grad.
  bool requires_grad = inputs[0].requires_grad();

  std::cout << std::fixed;

  // create output vector
  std::vector<at::Tensor> outputs;
  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // setup dimensions (backend n,c,h,w order, same scheme as forward)
  int64_t dimA[] = {0, 0, 0, 0};
  int64_t filterdimA1[] = {0, 0, 0, 0};
  int64_t filterdimA2[] = {0, 0, 0, 0};
  int64_t filterdimA3[] = {0, 0, 0, 0};
  int64_t filterdimA4[] = {0, 0, 0, 0};

  // All dim calculation after this order of n,c,h,w
  int axis[] {0, 1, 2, 3};
  if (explicit_nhwc) {
    axis[0] = 0;
    axis[1] = 3;
    axis[2] = 1;
    axis[3] = 2;
  }
  for (int dim = 0; dim < 4; dim++) {
    dimA[dim] = inputs[0].size(axis[dim]);
    filterdimA1[dim] = inputs[1].size(axis[dim]);
    filterdimA2[dim] = inputs[2].size(axis[dim]);
    filterdimA3[dim] = inputs[3].size(axis[dim]);
  }
  // Downsample conv present when stride != 1 or channel count changes.
  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
    for (int dim = 0; dim < 4; dim++) {
      filterdimA4[dim] = inputs[14].size(axis[dim]);
    }
  }

  // output dim in n,c,h,w used by backend
  int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below
  int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below
  int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below

  // use these fixed value for test run
  int64_t padA[] = {0, 0};
  int64_t padA1[] = {1, 1};
  int64_t dilationA[] = {1, 1};
  int64_t convstrideA[] = {1, 1};
  int64_t convstride1X1[] = {stride_1X1, stride_1X1};

  // compute output from pad/stride/dilation
  outdimA1[0] = dimA[0];
  outdimA1[1] = filterdimA1[0];
  for (int dim = 0; dim < 2; dim++) {
    outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
  }

  outdimA2[0] = outdimA1[0];
  outdimA2[1] = filterdimA2[0];
  for (int dim = 0; dim < 2; dim++) {
    outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
  }

  outdimA3[0] = outdimA2[0];
  outdimA3[1] = filterdimA3[0];
  for (int dim = 0; dim < 2; dim++) {
    outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
  }

  // Create output tensor in the correct shape in pytorch's view
  int64_t outdim1[] = {0, 0, 0, 0};
  int64_t outdim2[] = {0, 0, 0, 0};
  int64_t outdim3[] = {0, 0, 0, 0};
  if (explicit_nhwc) {
    axis[0] = 0;
    axis[1] = 2;
    axis[2] = 3;
    axis[3] = 1;
  }
  for (int dim = 0; dim < 4; dim++) {
    outdim1[dim] = outdimA1[axis[dim]];
    outdim2[dim] = outdimA2[axis[dim]];
    outdim3[dim] = outdimA3[axis[dim]];
  }

  // dconv3+drelu2+dscale2
  at::Half* conv_in = inputs[13].data_ptr<at::Half>();
  at::Half* dy3 = inputs[10].data_ptr<at::Half>();

  DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());

  // wgrad for conv3: relu2 (forward input), dy3 -> dw3
  auto wgrad3 = at::empty_like(inputs[3]);
  at::Half* dw3 = wgrad3.data_ptr<at::Half>();
  run_dconv(outdimA2, padA, convstrideA, dilationA, filterdimA3, outdimA3, CUDNN_DATA_HALF, conv_in, dw3, dy3, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);

  // dgrad for conv3, fused with drelu2/dscale2 -> dy2
  auto grad_out2 = at::empty(outdim2, inputs[0].type(), output_format);
  at::Half* dy2 = grad_out2.data_ptr<at::Half>();
  at::Half* w = inputs[3].data_ptr<at::Half>();
  at::Half* z = inputs[5].data_ptr<at::Half>();
  at::Half* relu2 = inputs[13].data_ptr<at::Half>();

  run_dconv_drelu_dscale(outdimA2, padA, convstrideA, dilationA, filterdimA3, outdimA3, CUDNN_DATA_HALF, dy2, w, dy3, z, relu2);

  DEBUG_MSG("[DEBUG] new dconv2 : " << grad_out2.to(at::kFloat).sum().item<float>());

  // dconv2+drelu1+dscale1
  conv_in = inputs[12].data_ptr<at::Half>();

  // wgrad for conv2: relu1, dy2 -> dw2
  auto wgrad2 = at::empty_like(inputs[2]);
  at::Half* dw2 = wgrad2.data_ptr<at::Half>();
  run_dconv(outdimA1, padA1, convstrideA, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, conv_in, dw2, dy2, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);

  // dgrad for conv2, fused with drelu1/dscale1 -> dy1
  auto grad_out1 = at::empty(outdim1, inputs[0].type(), output_format);
  at::Half* dy1 = grad_out1.data_ptr<at::Half>();
  w = inputs[2].data_ptr<at::Half>();
  z = inputs[4].data_ptr<at::Half>();
  at::Half* relu1 = inputs[12].data_ptr<at::Half>();

  // fused dgrad
  run_dconv_drelu_dscale(outdimA1, padA1, convstrideA, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, dy1, w, dy2, z, relu1);

  /*
  // backward strided conv cannot be fused
  // if stride == 1 but channel changes, we can fuse here
  if (stride_1X1 != 1){
    // dgrad
    run_dconv(outdimA1,
              padA1,
              convstride1X1,
              dilationA,
              filterdimA2,
              outdimA2,
              CUDNN_DATA_HALF,
              dy1,
              w,
              dy2,
              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);

    // mul fused mask
    grad_out1.mul_(inputs[15]);
  }
  else {
    at::Half* relu1 = inputs[12].data_ptr<at::Half>();
    // fused dgrad
    run_dconv_drelu_dscale(outdimA1,
                           padA1,
                           convstride1X1,
                           dilationA,
                           filterdimA2,
                           outdimA2,
                           CUDNN_DATA_HALF,
                           dy1,
                           w,
                           dy2,
                           z,
                           relu1);
  }
  */

  DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item<float>());

  // create grads of conv4 that may exist
  auto grad_x_conv4 = at::empty_like(inputs[0]);
  at::Half* dx_conv4 = grad_x_conv4.data_ptr<at::Half>();
  at::Tensor wgrad4;

  // x used for dconv1 and dconv4 wgrad
  at::Half* x = inputs[0].data_ptr<at::Half>();

  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]){
    // Downsample path: dgrad (optional) and wgrad of conv4.
    w = inputs[14].data_ptr<at::Half>();
    at::Half* dy_conv4 = inputs[11].data_ptr<at::Half>();
    if (requires_grad) {
      run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA4, outdimA3, CUDNN_DATA_HALF, dx_conv4, w, dy_conv4, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
      // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx
      // DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item<float>());
    }
    // wgrad
    wgrad4 = at::empty_like(inputs[14]);
    at::Half* dw4 = wgrad4.data_ptr<at::Half>();
    run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA4, outdimA3, CUDNN_DATA_HALF, x, dw4, dy_conv4, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
  }
  else {
    // if there is no downsample, dx_conv4 is fork of drelu3
    dx_conv4 = inputs[11].data_ptr<at::Half>();
  }

  // dconv1+add
  // wgrad for conv1: x, dy1 -> dw1
  auto wgrad1 = at::empty_like(inputs[1]);
  at::Half* dw1 = wgrad1.data_ptr<at::Half>();
  run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, x, dw1, dy1, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);

  // dgrad for conv1
  w = inputs[1].data_ptr<at::Half>();
  auto grad_x = at::empty_like(inputs[0]);
  at::Half* dx = grad_x.data_ptr<at::Half>();

  // backward strided conv cannot be fused
  // if stride == 1 but channel changes, we can fuse here
  if (requires_grad){
    if (stride_1X1 != 1){
      // Strided: plain dgrad, then add the residual-branch gradient.
      run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, dx, w, dy1, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
      // add 2 together
      grad_x.add_(grad_x_conv4);
    }
    else {
      // Unstrided: dgrad fused with the residual add.
      run_dconv_add(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, dx, w, dy1, dx_conv4);
    }
  }

  DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item<float>());
  DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item<float>());
  DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
  DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
  outputs.push_back(grad_x);
  outputs.push_back(wgrad1);
  outputs.push_back(wgrad2);
  outputs.push_back(wgrad3);

  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
    DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item<float>());
    outputs.push_back(wgrad4);
  }

  return outputs;
}
namespace {

// Global state shared by the split (init / out1 / out2 / halo / out3) forward
// entry points: all dimension arrays are computed once in init() and reused.
// NB: a single global instance means these entry points are NOT thread safe
// and NOT reentrant across different bottleneck shapes.
struct bottleneck_forward_status {

  // Backend-order (n,c,h,w) dims of the input and the four conv filters.
  int64_t dimA[4];
  int64_t filterdimA1[4];
  int64_t filterdimA2[4];
  int64_t filterdimA3[4];
  int64_t filterdimA4[4];

  // Map from backend (n,c,h,w) index to tensor dimension index.
  int axis[4];

  // Per-stage output dims in backend order. outdimA0/outdimA4 are the
  // halo-exchange variants (H fixed to 3 and 1 respectively, see init()).
  int64_t outdimA0[4];
  int64_t outdimA1[4];
  int64_t outdimA2[4];
  int64_t outdimA3[4];
  int64_t outdimA4[4];

  int64_t padA[2];
  int64_t padA1[2];
  int64_t padA2[2];          // halo padding
  int64_t dilationA[2];
  int64_t convstrideA[2];
  int64_t convstride1X1[2];

  // Same shapes permuted into pytorch's view order.
  int64_t outdim0[4];        // halo input shape
  int64_t outdim1[4];
  int64_t outdim2[4];
  int64_t outdim3[4];
  int64_t outdim4[4];        // halo output shape

  // Compute every dimension/pad/stride array from the block configuration.
  // Mirrors the dimension setup at the top of bottleneck_forward().
  void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {

    // Zero all shape arrays before filling them.
    dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0;
    filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0;
    filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;
    filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
    filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;

    // All dim calculation after this order of n,c,h,w
    if (explicit_nhwc) {
      axis[0] = 0;
      axis[1] = 3;
      axis[2] = 1;
      axis[3] = 2;
    } else {
      axis[0] = 0;
      axis[1] = 1;
      axis[2] = 2;
      axis[3] = 3;
    }
    for (int dim = 0; dim < 4; dim++) {
      dimA[dim] = inputs[0].size(axis[dim]);
      filterdimA1[dim] = inputs[1].size(axis[dim]);
      filterdimA2[dim] = inputs[2].size(axis[dim]);
      filterdimA3[dim] = inputs[3].size(axis[dim]);
    }
    // Downsample conv present when stride != 1 or channel count changes.
    if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
      for (int dim = 0; dim < 4; dim++) {
        filterdimA4[dim] = inputs[10].size(axis[dim]);
      }
    }

    // output dim in n,c,h,w used by backend
    outdimA0[0] = outdimA0[1] = outdimA0[2] = outdimA0[3] = 0;
    outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
    outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
    outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
    outdimA4[0] = outdimA4[1] = outdimA4[2] = outdimA4[3] = 0;

    // use these fixed value for test run
    padA[0] = 0; padA[1] = 0;
    padA1[0] = 1; padA1[1] = 1;
    padA2[0] = 0; padA2[1] = 1;   // halo conv: no pad in H, pad 1 in W
    dilationA[0] = 1; dilationA[1] = 1;
    convstrideA[0] = 1; convstrideA[1] = 1;
    convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1;

    // compute output from pad/stride/dilation
    outdimA1[0] = dimA[0];
    outdimA1[1] = filterdimA1[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
    }

    outdimA2[0] = outdimA1[0];
    outdimA2[1] = filterdimA2[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
    }

    // Halo shapes: like stage-1/stage-2 outputs but with H forced to 3 (fat
    // halo input) and 1 (halo output) respectively.
    for (int dim = 0; dim < 4; dim++) {
      if (dim == 2) {
        outdimA0[dim] = 3;
        outdimA4[dim] = 1;
      } else {
        outdimA0[dim] = outdimA1[dim];
        outdimA4[dim] = outdimA2[dim];
      }
    }

    outdimA3[0] = outdimA2[0];
    outdimA3[1] = filterdimA3[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
    }

    // Create output tensor in the correct shape in pytorch's view
    outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
    outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
    outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
    // Re-purpose axis[] as the inverse map for explicit NHWC output shapes.
    if (explicit_nhwc) {
      axis[0] = 0;
      axis[1] = 2;
      axis[2] = 3;
      axis[3] = 1;
    }
    for (int dim = 0; dim < 4; dim++) {
      outdim0[dim] = outdimA0[axis[dim]];
      outdim1[dim] = outdimA1[axis[dim]];
      outdim2[dim] = outdimA2[axis[dim]];
      outdim3[dim] = outdimA3[axis[dim]];
      outdim4[dim] = outdimA4[axis[dim]];
    }
  }
};

// Single global instance holding the state between the split forward calls.
bottleneck_forward_status forward_state;

}
// end of anonymous namespace
// Initializes the global forward_state from the block configuration and
// allocates the three stage-output tensors in the requested memory format.
// Returns {out1, out2, out3} (uninitialized storage, filled by later calls).
std::vector<at::Tensor> bottleneck_forward_init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
  // NB! Bottleneck_forward and bottleneck_backward are NOT thread safe method.
  // NB! We use a global object to store state.
  forward_state.init(explicit_nhwc, stride_1X1, inputs);

  const auto layout = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
  //printf("outdim1 = (%d,%d,%d,%d)\n",forward_state.outdim1[0],forward_state.outdim1[1],forward_state.outdim1[2],forward_state.outdim1[3]);

  // create output vector
  std::vector<at::Tensor> outputs;
  outputs.push_back(at::empty(forward_state.outdim1, inputs[0].type(), layout));
  outputs.push_back(at::empty(forward_state.outdim2, inputs[0].type(), layout));
  outputs.push_back(at::empty(forward_state.outdim3, inputs[0].type(), layout));
  return outputs;
}
// inputs contains x,w,z,b,(i)
void
bottleneck_forward_out1
(
bool
explicit_nhwc
,
int
stride_1X1
,
std
::
vector
<
at
::
Tensor
>
inputs
,
std
::
vector
<
at
::
Tensor
>
outputs
)
{
std
::
cout
<<
std
::
fixed
;
// run
at
::
Half
*
x
=
inputs
[
0
].
data_ptr
<
at
::
Half
>
();
at
::
Half
*
w
=
inputs
[
1
].
data_ptr
<
at
::
Half
>
();
at
::
Half
*
z
=
inputs
[
4
].
data_ptr
<
at
::
Half
>
();
at
::
Half
*
b
=
inputs
[
7
].
data_ptr
<
at
::
Half
>
();
auto
out1
=
outputs
[
0
];
at
::
Half
*
y1
=
out1
.
data_ptr
<
at
::
Half
>
();
run_conv_scale_bias_add_activation
(
forward_state
.
dimA
,
forward_state
.
padA
,
forward_state
.
convstride1X1
,
forward_state
.
dilationA
,
forward_state
.
filterdimA1
,
forward_state
.
outdimA1
,
CUDNN_DATA_HALF
,
x
,
w
,
y1
,
z
,
b
,
nullptr
);
DEBUG_MSG
(
"[DEBUG] new relu1 : "
<<
out1
.
to
(
at
::
kFloat
).
sum
().
item
<
float
>
());
}
// computes halo (top or bottom) from fat halo input.
// fat halo input is 3 pixels wide in H.
at::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_y1, std::vector<at::Tensor> inputs) {
  const auto mem_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // conv2 weight / scale / bias plus the 3-pixel-high halo activation.
  at::Half* w2_ptr    = inputs[2].data_ptr<at::Half>();
  at::Half* scale_ptr = inputs[5].data_ptr<at::Half>();
  at::Half* bias_ptr  = inputs[8].data_ptr<at::Half>();
  at::Half* halo_in   = fat_halo_y1.data_ptr<at::Half>();

  // Result is 1 pixel high in H (outdimA4[2] == 1, set by forward init).
  auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), mem_format);
  at::Half* halo_out = halo_y2.data_ptr<at::Half>();

  run_conv_scale_bias_add_activation(forward_state.outdimA0,
                                     forward_state.padA2,
                                     forward_state.convstrideA,
                                     forward_state.dilationA,
                                     forward_state.filterdimA2,
                                     forward_state.outdimA4,
                                     CUDNN_DATA_HALF,
                                     halo_in,
                                     w2_ptr,
                                     halo_out,
                                     scale_ptr,
                                     bias_ptr,
                                     nullptr);

  return halo_y2;
}
// Runs the second (spatial) convolution of the block fused with scale, bias
// and activation: reads out1 from outputs[0] and writes out2 into outputs[1].
// Uses w2=inputs[2], scale2=inputs[5], bias2=inputs[8].
// Fix: removed the unused local `x = inputs[0].data_ptr<at::Half>()` that was
// copied over from the _out1 method but never read here.
void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
  std::cout << std::fixed;

  // out1 produced by bottleneck_forward_out1 is the input of this conv.
  auto out1 = outputs[0];
  at::Half* y1 = out1.data_ptr<at::Half>();

  // run
  at::Half* w = inputs[2].data_ptr<at::Half>();
  at::Half* z = inputs[5].data_ptr<at::Half>();
  at::Half* b = inputs[8].data_ptr<at::Half>();

  auto out2 = outputs[1];
  at::Half* y2 = out2.data_ptr<at::Half>();

  run_conv_scale_bias_add_activation(forward_state.outdimA1,
                                     forward_state.padA1,
                                     forward_state.convstrideA,
                                     forward_state.dilationA,
                                     forward_state.filterdimA2,
                                     forward_state.outdimA2,
                                     CUDNN_DATA_HALF,
                                     y1,
                                     w,
                                     y2,
                                     z,
                                     b,
                                     nullptr);

  DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}
// Finishes the bottleneck forward pass:
//  1. runs the optional conv4 (downsample) branch producing the residual, and
//  2. runs conv3 fused with scale, bias, residual add and activation.
// inputs:  x=inputs[0]; conv3 params w3=inputs[3], scale3=inputs[6],
//          bias3=inputs[9]; downsample params (when the branch exists)
//          w4=inputs[10], scale4=inputs[11], bias4=inputs[12].
// outputs: outputs[1] (out2) is read, outputs[2] (out3) is written.
void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
  std::cout << std::fixed;

  // from _out1 method
  at::Half* x = inputs[0].data_ptr<at::Half>();

  // create output of conv3
  auto out3 = outputs[2];
  at::Half* y3 = out3.data_ptr<at::Half>();

  // create output of conv4 that may exist
  // NOTE: `identity` is allocated unconditionally but only written (and used)
  // when the downsample branch below runs; otherwise yi is re-pointed at x.
  auto identity = at::empty_like(out3);
  at::Half* yi = identity.data_ptr<at::Half>();

  at::Half *w, *z, *b;

  // Downsample branch is needed when conv1 is strided or the residual channel
  // count (filterdimA3[0]) differs from the input channel count (dimA[1]).
  if (stride_1X1 != 1 || forward_state.filterdimA3[0] != forward_state.dimA[1]){

    w = inputs[10].data_ptr<at::Half>();
    z = inputs[11].data_ptr<at::Half>();
    b = inputs[12].data_ptr<at::Half>();

    // conv4: x -> identity, fused with scale + bias (no activation).
    run_conv_scale_bias(forward_state.dimA,
                        forward_state.padA,
                        forward_state.convstride1X1,
                        forward_state.dilationA,
                        forward_state.filterdimA4,
                        forward_state.outdimA3,
                        CUDNN_DATA_HALF,
                        x,
                        w,
                        yi,
                        z,
                        b);

    DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item<float>());
  }
  else {
    // No downsample: the residual added below is the raw input itself.
    yi = x;
  }

  auto out2 = outputs[1];
  at::Half* y2 = out2.data_ptr<at::Half>();

  w = inputs[3].data_ptr<at::Half>();
  z = inputs[6].data_ptr<at::Half>();
  b = inputs[9].data_ptr<at::Half>();

  // conv3 fused with scale + bias + residual add (yi) + activation.
  run_conv_scale_bias_add_activation(forward_state.outdimA2,
                                     forward_state.padA,
                                     forward_state.convstrideA,
                                     forward_state.dilationA,
                                     forward_state.filterdimA3,
                                     forward_state.outdimA3,
                                     CUDNN_DATA_HALF,
                                     y2,
                                     w,
                                     y3,
                                     z,
                                     b,
                                     yi);

  DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item<float>());
}
namespace {

// Cached dimension/geometry state shared by all bottleneck_backward_* entry
// points. Filled once per block by init() into the single global instance
// `backward_state` below, so the backward API is NOT thread safe.
struct bottleneck_backward_state {

  int64_t dimA[4];           // input x dims, in n,c,h,w order
  int64_t filterdimA1[4];    // conv1 weight dims (from inputs[1])
  int64_t filterdimA2[4];    // conv2 weight dims (from inputs[2])
  int64_t filterdimA3[4];    // conv3 weight dims (from inputs[3])
  int64_t filterdimA4[4];    // conv4 (downsample) weight dims; only set when that path exists
  int64_t filterdimA2hh[4];  // Cin,Cout,1,3

  int axis[4];               // permutation mapping n,c,h,w -> the tensor's actual axes

  // output dims in n,c,h,w used by the cuDNN backend
  int64_t outdimA1[4];       // grad_out1
  int64_t outdimA2[4];       // grad_out2
  int64_t outdimA3[4];
  int64_t outdimA1h[4];      // output: grad_out1 halo (H=3)
  int64_t outdimA2h[4];      // input : grad_out2 halo cells (H=3)
  int64_t outdimA1hh[4];     // input: grad_out2 halo (H=1)
  int64_t outdimA2hh[4];     // input: out1 halo (H=1)

  int64_t padA[2];           // 0,0 -- used with the 1x1 convs
  int64_t padA1[2];          // 1,1 -- used with conv2
  int64_t padA2[2];          // 0,1 -- used for the halo wgrad
  int64_t dilationA[2];
  int64_t convstrideA[2];
  int64_t convstride1X1[2];  // stride of conv1 (also used for conv4)

  // same dims permuted back into pytorch's view of the tensors
  int64_t filterdim2hh[4];   // Cin,1,3,Cout
  int64_t outdim1[4];
  int64_t outdim2[4];
  int64_t outdim3[4];
  int64_t outdim1h[4];

  // Derives every dimension above from the input/weight tensor shapes.
  // inputs[0]=x, inputs[1..3]=w1..w3, inputs[14]=w4 (only read when the
  // downsample path exists).
  void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {

    // setup dimensions
    dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0;
    filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0;
    filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;
    filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
    filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;
    filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0;

    // All dim calculation after this order of n,c,h,w
    if (explicit_nhwc) {
      // tensors are stored NHWC, so c is axis 3.
      axis[0] = 0;
      axis[1] = 3;
      axis[2] = 1;
      axis[3] = 2;
    }
    else {
      axis[0] = 0;
      axis[1] = 1;
      axis[2] = 2;
      axis[3] = 3;
    }

    for (int dim = 0; dim < 4; dim++) {
      dimA[dim] = inputs[0].size(axis[dim]);
      filterdimA1[dim] = inputs[1].size(axis[dim]);
      filterdimA2[dim] = inputs[2].size(axis[dim]);
      filterdimA3[dim] = inputs[3].size(axis[dim]);
    }

    // conv4 weights exist only when conv1 is strided or channels change.
    if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
      for (int dim = 0; dim < 4; dim++) {
        filterdimA4[dim] = inputs[14].size(axis[dim]);
      }
    }

    // filterdimA2hh is conv2's filter restricted to a single row in H
    // (used for the halo wgrad convolution).
    for (int dim = 0; dim < 4; dim++) {
      if (dim == 2) {
        filterdimA2hh[dim] = 1;
      }
      else {
        filterdimA2hh[dim] = filterdimA2[dim];
      }
    }

    // output dim in n,c,h,w used by backend
    outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
    outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
    outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
    outdimA1h[0] = outdimA1h[1] = outdimA1h[2] = outdimA1h[3] = 0;
    outdimA2h[0] = outdimA2h[1] = outdimA2h[2] = outdimA2h[3] = 0;
    outdimA1hh[0] = outdimA1hh[1] = outdimA1hh[2] = outdimA1hh[3] = 0;
    outdimA2hh[0] = outdimA2hh[1] = outdimA2hh[2] = outdimA2hh[3] = 0;

    // use these fixed value for test run
    padA[0] = 0;
    padA[1] = 0;
    padA1[0] = 1;
    padA1[1] = 1;
    padA2[0] = 0;
    padA2[1] = 1;
    dilationA[0] = 1;
    dilationA[1] = 1;
    convstrideA[0] = 1;
    convstrideA[1] = 1;
    convstride1X1[0] = stride_1X1;
    convstride1X1[1] = stride_1X1;

    // compute output from pad/stride/dilation
    outdimA1[0] = dimA[0];
    outdimA1[1] = filterdimA1[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
    }

    outdimA2[0] = outdimA1[0];
    outdimA2[1] = filterdimA2[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
    }

    outdimA3[0] = outdimA2[0];
    outdimA3[1] = filterdimA3[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
    }

    // Halo variants reuse the full output dims but clamp H: the "h" buffers
    // are 3 pixels high, the "hh" buffers a single pixel.
    for (int dim = 0; dim < 4; dim++) {
      if (dim == 2) {
        outdimA1h[dim] = 3;
        outdimA2h[dim] = 3;
        outdimA1hh[dim] = 1;
        outdimA2hh[dim] = 1;
      }
      else {
        outdimA1h[dim] = outdimA1[dim];
        outdimA2h[dim] = outdimA2[dim];
        outdimA1hh[dim] = outdimA1[dim];
        outdimA2hh[dim] = outdimA2[dim];
      }
    }

    // Create output tensor in the correct shape in pytorch's view
    outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
    outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
    outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
    outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0;
    filterdim2hh[0] = filterdim2hh[1] = filterdim2hh[2] = filterdim2hh[3] = 0;

    // Re-purpose `axis` as the inverse permutation: n,c,h,w -> pytorch order.
    if (explicit_nhwc) {
      axis[0] = 0;
      axis[1] = 2;
      axis[2] = 3;
      axis[3] = 1;
    }
    for (int dim = 0; dim < 4; dim++) {
      outdim1[dim] = outdimA1[axis[dim]];
      outdim2[dim] = outdimA2[axis[dim]];
      outdim3[dim] = outdimA3[axis[dim]];
      outdim1h[dim] = outdimA1h[axis[dim]];
      filterdim2hh[dim] = filterdimA2hh[axis[dim]];
    }
  }
};

// Single global instance used by every bottleneck_backward_* call.
bottleneck_backward_state backward_state;

}
// Initializes the global backward state from the forward tensors and
// pre-allocates the gradient outputs.
// Returns {grad_x, wgrad1, wgrad2, wgrad3[, wgrad4]}; wgrad4 is only present
// when the block has a downsample path (strided conv1 or channel change).
// Fix: removed the unused local `output_format` (at::empty_like already
// inherits each tensor's layout, so no explicit memory format is needed).
std::vector<at::Tensor> bottleneck_backward_init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
  std::cout << std::fixed;

  backward_state.init(explicit_nhwc, stride_1X1, inputs);

  // create output vector
  std::vector<at::Tensor> outputs;

  auto grad_x = at::empty_like(inputs[0]);
  auto wgrad1 = at::empty_like(inputs[1]);
  auto wgrad2 = at::empty_like(inputs[2]);
  auto wgrad3 = at::empty_like(inputs[3]);

  outputs.push_back(grad_x);
  outputs.push_back(wgrad1);
  outputs.push_back(wgrad2);
  outputs.push_back(wgrad3);

  // conv4 (downsample) weights only exist when conv1 is strided or the
  // residual channel count differs; inputs[14] is w4 in that case.
  if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {
    auto wgrad4 = at::empty_like(inputs[14]);
    outputs.push_back(wgrad4);
  }

  return outputs;
}
// First backward stage: computes wgrad3 (into outputs[3]) and the fused
// dconv3 + drelu2 + dscale2 data gradient, returning grad_out2.
// inputs: relu2=inputs[13] (conv3's forward input), dy3=inputs[10],
//         w3=inputs[3], scale2=inputs[5].
// Fix: removed the unused local `requires_grad`.
at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2_unused_placeholder = at::Tensor());
at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
  std::cout << std::fixed;

  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // dconv3+drelu2+dscale2
  at::Half* conv_in = inputs[13].data_ptr<at::Half>();
  at::Half* dy3 = inputs[10].data_ptr<at::Half>();

  DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());

  // wgrad
  auto wgrad3 = outputs[3];
  at::Half* dw3 = wgrad3.data_ptr<at::Half>();
  run_dconv(backward_state.outdimA2,
            backward_state.padA,
            backward_state.convstrideA,
            backward_state.dilationA,
            backward_state.filterdimA3,
            backward_state.outdimA3,
            CUDNN_DATA_HALF,
            conv_in,
            dw3,
            dy3,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);

  // dgrad
  auto grad_out2 = at::empty(backward_state.outdim2, inputs[0].type(), output_format);
  at::Half* dy2 = grad_out2.data_ptr<at::Half>();
  at::Half* w = inputs[3].data_ptr<at::Half>();
  at::Half* z = inputs[5].data_ptr<at::Half>();

  at::Half* relu2 = inputs[13].data_ptr<at::Half>();

  run_dconv_drelu_dscale(backward_state.outdimA2,
                         backward_state.padA,
                         backward_state.convstrideA,
                         backward_state.dilationA,
                         backward_state.filterdimA3,
                         backward_state.outdimA3,
                         CUDNN_DATA_HALF,
                         dy2,
                         w,
                         dy3,
                         z,
                         relu2);

  // do halo exchange of dy2 here

  DEBUG_MSG("[DEBUG] new dconv2 : " << grad_out2.to(at::kFloat).sum().item<float>());

  return grad_out2;
}
// Second backward stage: fused dconv2 + drelu1 + dscale1 data gradient,
// returning grad_out1. Uses w2=inputs[2], scale1=inputs[4], relu1=inputs[12].
// Fix: removed the unused local `requires_grad`.
at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
  std::cout << std::fixed;

  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // dgrad
  at::Half* dy2 = grad_out2.data_ptr<at::Half>();

  // dgrad
  auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
  at::Half* dy1 = grad_out1.data_ptr<at::Half>();
  at::Half* w = inputs[2].data_ptr<at::Half>();
  at::Half* z = inputs[4].data_ptr<at::Half>();

  at::Half* relu1 = inputs[12].data_ptr<at::Half>();

  // fused dgrad
  run_dconv_drelu_dscale(backward_state.outdimA1,
                         backward_state.padA1,
                         backward_state.convstrideA,
                         backward_state.dilationA,
                         backward_state.filterdimA2,
                         backward_state.outdimA2,
                         CUDNN_DATA_HALF,
                         dy1,
                         w,
                         dy2,
                         z,
                         relu1);

  return grad_out1;
}
// perform backward data 3x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,3,W,C] with padding=(1,1) to produce output of shape [N,3,W,C]
// Halo variant of bottleneck_backward_grad_out1: same fused
// dconv2 + drelu1 + dscale1, but on the 3-pixel-high halo slices
// (outdimA1h/outdimA2h) instead of the full activations.
// Fix: removed the unused local `requires_grad`.
at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo) {
  std::cout << std::fixed;

  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // dgrad
  at::Half* dy2h = grad_out2_halo.data_ptr<at::Half>();

  // dgrad
  auto grad_out1_halo = at::empty(backward_state.outdim1h, inputs[0].type(), output_format);
  at::Half* dy1h = grad_out1_halo.data_ptr<at::Half>();
  at::Half* w = inputs[2].data_ptr<at::Half>();
  at::Half* z = inputs[4].data_ptr<at::Half>();

  at::Half* relu1h = relu1_halo.data_ptr<at::Half>();

  // fused dgrad
  run_dconv_drelu_dscale(backward_state.outdimA1h,
                         backward_state.padA1,
                         backward_state.convstrideA,
                         backward_state.dilationA,
                         backward_state.filterdimA2,
                         backward_state.outdimA2h,
                         CUDNN_DATA_HALF,
                         dy1h,
                         w,
                         dy2h,
                         z,
                         relu1h);

  return grad_out1_halo;
}
// Computes conv2's weight gradient into outputs[2] from relu1 (inputs[12])
// and grad_out2, and returns it.
// Fix: removed the unused locals `requires_grad` and `output_format`.
at::Tensor bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
  std::cout << std::fixed;

  // dgrad
  at::Half* dy2 = grad_out2.data_ptr<at::Half>();

  // dconv2+drelu1+dscale1
  at::Half* conv_in = inputs[12].data_ptr<at::Half>();

  // wgrad
  auto wgrad2 = outputs[2];
  at::Half* dw2 = wgrad2.data_ptr<at::Half>();
  run_dconv(backward_state.outdimA1,
            backward_state.padA1,
            backward_state.convstrideA,
            backward_state.dilationA,
            backward_state.filterdimA2,
            backward_state.outdimA2,
            CUDNN_DATA_HALF,
            conv_in,
            dw2,
            dy2,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);

  return wgrad2;
}
// compute halo cells for input volume of dimension [N,1,W,C] with padding=(0,1) to produce output volume of dimension [N,1,W,C]
// input and grad_out2_halo tensors are all of same shape
// output tensor is of shape [Cin,1,3,Cout] (regular filter dims are [Cin,3,3,Cout]
// Fix: removed the unused local `requires_grad`.
at::Tensor bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor grad_out2_halo) {
  std::cout << std::fixed;

  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // dgrad
  at::Half* dy2 = grad_out2_halo.data_ptr<at::Half>();

  // dconv2+drelu1+dscale1
  at::Half* conv_in = input.data_ptr<at::Half>();

  // wgrad: partial filter gradient for the single halo row.
  auto wgrad2_halo = at::empty(backward_state.filterdim2hh, input.type(), output_format);
  at::Half* dw2 = wgrad2_halo.data_ptr<at::Half>();

  run_dconv(backward_state.outdimA1hh,   // N,C,1,W
            backward_state.padA2,        // 0, 1
            backward_state.convstrideA,
            backward_state.dilationA,
            backward_state.filterdimA2hh, // Cin,Cout,1,3
            backward_state.outdimA2hh,   // N,C,1,W
            CUDNN_DATA_HALF,
            conv_in,
            dw2,
            dy2,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);

  return wgrad2_halo;
}
// Finishes the bottleneck backward pass:
//  - dgrad and wgrad of the optional conv4 (downsample) branch,
//  - wgrad of conv1,
//  - dgrad of conv1, fused with the residual-gradient add when stride == 1.
// grad_out2/grad_out1/wgrad2 are produced by the earlier backward stages.
// Fixes: removed the unused locals `output_format` and `dy2` (the only use of
// dy2 was inside the commented-out fusion block below); bound `wgrad3` from
// outputs[3] so its DEBUG_MSG no longer references an undeclared identifier
// when the DEBUG print macro is enabled.
void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor grad_out1, at::Tensor wgrad2) {
  bool requires_grad = inputs[0].requires_grad();

  std::cout << std::fixed;

  // dgrad
  at::Half* dy1 = grad_out1.data_ptr<at::Half>();

  /*
  // backward strided conv cannot be fused
  // if stride == 1 but channel changes, we can fuse here
  if (stride_1X1 != 1){
    // dgrad
    run_dconv(outdimA1,
              padA1,
              convstride1X1,
              dilationA,
              filterdimA2,
              outdimA2,
              CUDNN_DATA_HALF,
              dy1,
              w,
              dy2,
              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);

    // mul fused mask
    grad_out1.mul_(inputs[15]);
  }
  else {
    at::Half* relu1 = inputs[12].data_ptr<at::Half>();
    // fused dgrad
    run_dconv_drelu_dscale(outdimA1,
                           padA1,
                           convstride1X1,
                           dilationA,
                           filterdimA2,
                           outdimA2,
                           CUDNN_DATA_HALF,
                           dy1,
                           w,
                           dy2,
                           z,
                           relu1);
  }
  */
  DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item<float>());

  // create grads of conv4 that may exist
  auto grad_x_conv4 = at::empty_like(inputs[0]);
  at::Half* dx_conv4 = grad_x_conv4.data_ptr<at::Half>();

  at::Tensor wgrad4;

  // x used for dconv1 and dconv4 wgrad
  at::Half* x = inputs[0].data_ptr<at::Half>();
  at::Half* w = NULL;

  // Downsample branch exists when conv1 is strided or channels change.
  if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]){
    w = inputs[14].data_ptr<at::Half>();
    at::Half* dy_conv4 = inputs[11].data_ptr<at::Half>();

    if (requires_grad) {
      run_dconv(backward_state.dimA,
                backward_state.padA,
                backward_state.convstride1X1,
                backward_state.dilationA,
                backward_state.filterdimA4,
                backward_state.outdimA3,
                CUDNN_DATA_HALF,
                dx_conv4,
                w,
                dy_conv4,
                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
      // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx
      // DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item<float>());
    }

    // wgrad
    wgrad4 = outputs[4];
    at::Half* dw4 = wgrad4.data_ptr<at::Half>();
    run_dconv(backward_state.dimA,
              backward_state.padA,
              backward_state.convstride1X1,
              backward_state.dilationA,
              backward_state.filterdimA4,
              backward_state.outdimA3,
              CUDNN_DATA_HALF,
              x,
              dw4,
              dy_conv4,
              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
  }
  else {
    // if there is no downsample, dx_conv4 is fork of drelu3
    dx_conv4 = inputs[11].data_ptr<at::Half>();
  }

  // dconv1+add
  // wgrad
  auto wgrad1 = outputs[1];
  at::Half* dw1 = wgrad1.data_ptr<at::Half>();
  run_dconv(backward_state.dimA,
            backward_state.padA,
            backward_state.convstride1X1,
            backward_state.dilationA,
            backward_state.filterdimA1,
            backward_state.outdimA1,
            CUDNN_DATA_HALF,
            x,
            dw1,
            dy1,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);

  // dgrad
  w = inputs[1].data_ptr<at::Half>();
  auto grad_x = outputs[0];
  at::Half* dx = grad_x.data_ptr<at::Half>();

  // backward strided conv cannot be fused
  // if stride == 1 but channel changes, we can fuse here
  if (requires_grad){
    if (stride_1X1 != 1){
      run_dconv(backward_state.dimA,
                backward_state.padA,
                backward_state.convstride1X1,
                backward_state.dilationA,
                backward_state.filterdimA1,
                backward_state.outdimA1,
                CUDNN_DATA_HALF,
                dx,
                w,
                dy1,
                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
      // add 2 together
      grad_x.add_(grad_x_conv4);
    }
    else {
      run_dconv_add(backward_state.dimA,
                    backward_state.padA,
                    backward_state.convstride1X1,
                    backward_state.dilationA,
                    backward_state.filterdimA1,
                    backward_state.outdimA1,
                    CUDNN_DATA_HALF,
                    dx,
                    w,
                    dy1,
                    dx_conv4);
    }
  }

  DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item<float>());
  DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item<float>());
  DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
  // wgrad3 was written into outputs[3] by bottleneck_backward_grad_out2.
  auto wgrad3 = outputs[3];
  DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
  if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {
    DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item<float>());
  }
}
// Python bindings.
// "forward"/"backward" run the whole bottleneck block in one call.
// The *_init functions compute dimensions into the global (NOT thread safe)
// state objects and allocate outputs. The staged entry points
// (forward_out1/out2/rest, backward_grad_out2/grad_out1/wgrad2/rest) run the
// block piecewise so callers can interleave other work — e.g. the halo
// exchange of dy2 noted in backward_grad_out2 — and the *_halo variants
// operate on the exchanged 1- or 3-pixel-high halo slices only.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &bottleneck_forward, "Bottleneck block forward");
  m.def("backward", &bottleneck_backward, "Bottleneck block backward");
  m.def("forward_init", &bottleneck_forward_init, "Bottleneck block init");
  m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward");
  m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward");
  m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward");
  m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward");
  m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init");
  m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
  m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward");
  m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward");
  m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward");
  m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward");
  m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");
}
cudnn-frontend
@
b4e1ad96
Subproject commit b4e1ad9613b89199982c9baf6ee91f6f98f5606d
apex/contrib/csrc/fmha/fmha_api.cpp
0 → 100644
View file @
f79993d9
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "fmha.h"
// Fills the fused-MHA forward kernel parameter struct.
// b, s, h, d: batch size, (padded) sequence length, number of heads, head size.
// qkv_packed_d / cu_seqlens_d / o_packed_d / s_d: device pointers for the
// packed QKV input, cumulative sequence lengths, output, and softmax buffer.
// p_dropout is the probability of *dropping* an element; it is converted to a
// keep-probability below.
// Fix: validate p_dropout < 1 BEFORE computing rp_dropout = 1/(1-p_dropout),
// so the check fires before the (previously already-performed) division by 0.
void set_params(Fused_multihead_attention_fprop_params &params,
                // sizes
                const size_t b,
                const size_t s,
                const size_t h,
                const size_t d,
                // device pointers
                void *qkv_packed_d,
                void *cu_seqlens_d,
                void *o_packed_d,
                void *s_d,
                float p_dropout) {

    Data_type acc_type = DATA_TYPE_FP32;
    Data_type data_type = DATA_TYPE_FP16;

    // Reset the parameters
    memset(&params, 0, sizeof(params));

    // Set the pointers and strides.
    params.qkv_ptr = qkv_packed_d;
    params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type);
    params.o_ptr = o_packed_d;
    params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);

    params.cu_seqlens = static_cast<int *>(cu_seqlens_d);

    // S = softmax(P)
    params.s_ptr = s_d;
    params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type);

    // Set the dimensions.
    params.b = b;
    params.h = h;
    params.s = s;
    params.d = d;

    // Set the different scale values (1/sqrt(d) for the first GEMM).
    const float scale_bmm1 = 1.f / sqrtf(d);
    constexpr float scale_softmax = 1.f;
    constexpr float scale_bmm2 = 1.f;

    set_alpha(params.scale_bmm1, scale_bmm1, acc_type);
    set_alpha(params.scale_softmax, scale_softmax, acc_type);
    set_alpha(params.scale_bmm2, scale_bmm2, data_type);

    // Keep-probability must be non-zero before we take its reciprocal.
    TORCH_CHECK(p_dropout < 1.f);

    // Set this to probability of keeping an element to simplify things.
    params.p_dropout = 1.f - p_dropout;
    params.rp_dropout = 1.f / params.p_dropout;
    set_alpha(params.scale_dropout, params.rp_dropout, data_type);
}
// Fused multi-head attention forward (sm80 / A100 only, fp16, head size 64).
// qkv:        total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
// cu_seqlens: b+1 cumulative sequence lengths (int32)
// Returns {ctx, s}: the attention output and the softmax/dropout buffer the
// backward pass consumes.
// Fixes: removed the unused local `at::PhiloxCudaState rng_engine_inputs;`
// and added the missing semicolons on four TORCH_CHECK statements.
std::vector<at::Tensor>
mha_fwd(const at::Tensor &qkv,         // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
        const at::Tensor &cu_seqlens,  // b+1
        const float p_dropout,
        const int max_seq_len,
        const bool is_training,
        c10::optional<at::Generator> gen_) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);

    // Pick the smallest compiled kernel that covers max_seq_len.
    int seq_len = 512;
    auto launch = &run_fmha_fp16_512_64_sm80;
    if( max_seq_len <= 128 ) {
        seq_len = 128;
        launch = &run_fmha_fp16_128_64_sm80;
    } else if( max_seq_len <= 256 ) {
        seq_len = 256;
        launch = &run_fmha_fp16_256_64_sm80;
    } else if( max_seq_len <= 384 ) {
        seq_len = 384;
        launch = &run_fmha_fp16_384_64_sm80;
    } else if( max_seq_len <= 512 ) {
        seq_len = 512;
        launch = &run_fmha_fp16_512_64_sm80;
    } else {
        TORCH_CHECK(false);
    }

    constexpr int warps_m = 1;
    constexpr int warps_n = 4;  // this leads to an upper bound
    const int mmas_m = seq_len / 16 / warps_m;
    const int mmas_n = seq_len / 16 / warps_n;
    const int elts_per_thread = 8 * mmas_m * mmas_n;

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    TORCH_CHECK(qkv.dtype() == torch::kFloat16);
    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
    TORCH_CHECK(qkv.is_cuda());
    TORCH_CHECK(cu_seqlens.is_cuda());
    TORCH_CHECK(qkv.is_contiguous());
    TORCH_CHECK(cu_seqlens.is_contiguous());

    TORCH_CHECK(cu_seqlens.dim() == 1);
    TORCH_CHECK(qkv.dim() == 4);

    const auto sizes = qkv.sizes();
    TORCH_CHECK(sizes[THREE_DIM] == 3);

    const int batch_size = cu_seqlens.numel() - 1;
    const int total = sizes[TOTAL_DIM];
    const int num_heads = sizes[H_DIM];
    const int head_size = sizes[D_DIM];
    TORCH_CHECK(batch_size > 0);
    TORCH_CHECK(head_size == 64);

    auto opts = qkv.options();
    auto ctx = torch::empty({total, num_heads, head_size}, opts);
    auto s = torch::empty({batch_size, num_heads, seq_len, seq_len}, opts);

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());

    Fused_multihead_attention_fprop_params params;
    set_params(params,
               batch_size,
               seq_len,
               num_heads,
               head_size,
               qkv.data_ptr(),
               cu_seqlens.data_ptr(),
               ctx.data_ptr(),
               s.data_ptr(),
               p_dropout);

    // number of times random will be generated per thread, to offset philox counter in thc random
    // state
    int64_t counter_offset = elts_per_thread;
    if( is_training ) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
    }

    launch(params, is_training, stream);

    return { ctx, s };
}
// Backward pass of the fused multi-head attention.
//
//   dout       : dO, total x num_heads x head_size (fp16).
//   qkv        : packed QKV input of the forward pass (fp16, contiguous).
//   softmax    : b x h x s x s softmax/dmask from the forward pass; this
//                tensor is overwritten in place with dP by the kernel.
//   cu_seqlens : int32 tensor of b+1 cumulative sequence offsets.
//   p_dropout  : dropout probability used in the forward pass.
//   max_seq_len: selects which compiled kernel variant to launch.
//
// Returns {dqkv, softmax} where dqkv is the gradient w.r.t. the packed QKV.
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout,
        const at::Tensor &qkv,
        at::Tensor &softmax,
        const at::Tensor &cu_seqlens,
        const float p_dropout,
        const int max_seq_len) {
    // Only SM80 (compute capability 8.0) kernels are compiled in.
    auto device_props = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(device_props->major == 8 && device_props->minor == 0);

    // Pick the smallest kernel whose tile covers max_seq_len; anything
    // beyond 512 is unsupported.
    int seq_len = 512;
    void (*kernel)(const Fused_multihead_attention_fprop_params &, cudaStream_t) =
        &run_fmha_dgrad_fp16_512_64_sm80;
    if( max_seq_len <= 128 ) {
        seq_len = 128;
        kernel = &run_fmha_dgrad_fp16_128_64_sm80;
    } else if( max_seq_len <= 256 ) {
        seq_len = 256;
        kernel = &run_fmha_dgrad_fp16_256_64_sm80;
    } else if( max_seq_len <= 384 ) {
        seq_len = 384;
        kernel = &run_fmha_dgrad_fp16_384_64_sm80;
    } else if( max_seq_len <= 512 ) {
        seq_len = 512;
        kernel = &run_fmha_dgrad_fp16_512_64_sm80;
    } else {
        TORCH_CHECK(false);
    }

    auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();

    // Validate dtypes, device, layout and ranks.
    TORCH_CHECK(qkv.dtype() == torch::kFloat16);
    TORCH_CHECK(dout.dtype() == torch::kFloat16);
    TORCH_CHECK(softmax.dtype() == torch::kFloat16);
    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);

    TORCH_CHECK(qkv.is_cuda());
    TORCH_CHECK(cu_seqlens.is_cuda());

    TORCH_CHECK(qkv.is_contiguous());
    TORCH_CHECK(cu_seqlens.is_contiguous());

    TORCH_CHECK(cu_seqlens.dim() == 1);
    TORCH_CHECK(qkv.dim() == 4);

    const auto qkv_sizes = qkv.sizes();
    TORCH_CHECK(qkv_sizes[THREE_DIM] == 3);

    const int batch_size = cu_seqlens.numel() - 1;
    const int num_heads = qkv_sizes[H_DIM];
    const int head_size = qkv_sizes[D_DIM];
    TORCH_CHECK(batch_size > 0);
    // Only head dimension 64 is compiled in.
    TORCH_CHECK(head_size == 64);

    auto dqkv = torch::empty_like(qkv);

    Fused_multihead_attention_fprop_params params;
    set_params(params,
               batch_size,
               seq_len,
               num_heads,
               head_size,
               qkv.data_ptr(),
               cu_seqlens.data_ptr(),
               dout.data_ptr(),     // o_ptr is repurposed to carry dO.
               softmax.data_ptr(),  // softmax gets overwritten by dP!
               p_dropout);

    // We're re-using the fprop scaling factors for the dgrad kernel.
    Data_type accum_dtype = DATA_TYPE_FP32;
    set_alpha(params.scale_bmm1, 1.f, accum_dtype);
    set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), accum_dtype);
    set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);
    params.dqkv_ptr = dqkv.data_ptr();

    kernel(params, cuda_stream);

    return {dqkv, softmax};
}
// Forward pass of the fused multi-head attention, "no-loop" (nl) variant used
// for small batches.
//
//   qkv        : total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
//   cu_seqlens : b+1 cumulative sequence offsets (int32).
//   p_dropout  : probability to drop.
//   max_seq_len: must be exactly 512 -- the only nl kernel compiled in.
//   is_training: when true, dropout RNG state is initialized.
//   gen_       : optional CUDA generator for the dropout Philox state.
//
// Returns {ctx, s}: the attention output and the softmax/dmask tensor that
// the backward pass consumes.
std::vector<at::Tensor>
mha_fwd_nl(const at::Tensor &qkv,
           const at::Tensor &cu_seqlens,
           const float p_dropout,
           const int max_seq_len,
           const bool is_training,
           c10::optional<at::Generator> gen_) {
    // The no-loop kernel only exists for sequence length 512.
    int seq_len = 512;
    auto launch = &run_fmha_fp16_512_64_sm80_nl;
    TORCH_CHECK(max_seq_len == seq_len);

    constexpr int warps_m = 1;
    constexpr int warps_n = 4;  // this leads to an upper bound
    const int mmas_m = seq_len / 16 / warps_m;
    const int mmas_n = seq_len / 16 / warps_n;
    // static_assert( mmas_m == 32 );
    // static_assert( mmas_n == 4 );
    // Upper bound on random values consumed per thread by dropout; used to
    // offset the Philox counter below.
    const int elts_per_thread = 8 * mmas_m * mmas_n;

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // Dtype checks added for consistency with mha_fwd/mha_bwd.
    TORCH_CHECK(qkv.dtype() == torch::kFloat16);
    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);

    TORCH_CHECK(qkv.is_cuda());
    TORCH_CHECK(cu_seqlens.is_cuda());

    TORCH_CHECK(qkv.is_contiguous());
    TORCH_CHECK(cu_seqlens.is_contiguous());

    TORCH_CHECK(cu_seqlens.dim() == 1);
    TORCH_CHECK(qkv.dim() == 4);

    const auto sizes = qkv.sizes();
    TORCH_CHECK(sizes[THREE_DIM] == 3);

    const int batch_size = cu_seqlens.numel() - 1;
    const int total = sizes[TOTAL_DIM];
    const int num_heads = sizes[H_DIM];
    const int head_size = sizes[D_DIM];
    TORCH_CHECK(batch_size > 0);
    // Only head dimension 64 is compiled in.
    TORCH_CHECK(head_size == 64);

    auto opts = qkv.options();
    // Attention output.
    auto ctx = torch::empty({total, num_heads, head_size}, opts);
    // Softmax/dropout-mask matrix, consumed by the backward pass.
    auto s = torch::empty({batch_size, num_heads, seq_len, seq_len}, opts);

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

    Fused_multihead_attention_fprop_params params;
    set_params(params,
               batch_size,
               seq_len,
               num_heads,
               head_size,
               qkv.data_ptr(),
               cu_seqlens.data_ptr(),
               ctx.data_ptr(),
               s.data_ptr(),
               p_dropout);

    // Number of times random will be generated per thread, to offset philox
    // counter in thc random state.
    int64_t counter_offset = elts_per_thread;
    if( is_training ) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
    }

    // Split the work into num_chunks chunks; tuned by batch size.
    int num_chunks = 3;
    if( batch_size == 3 ) {
        num_chunks = 2;
    }

    launch(params, is_training, num_chunks, stream);

    return {ctx, s};
}
// Backward pass of the fused multi-head attention, "no-loop" (nl) variant
// used for small batches.
//
//   dout       : dO, total x num_heads x head_size (fp16).
//   qkv        : total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
//   softmax    : b x h x s x s softmax and dmask - will be overwritten with dP.
//   cu_seqlens : b+1 cumulative sequence offsets (int32).
//   p_dropout  : probability to drop.
//   max_seq_len: must be 512 -- the only nl dgrad kernel compiled in.
//
// Returns {dqkv, softmax, dkv} where dkv is the per-chunk dK/dV scratch that
// was reduced into dqkv.
std::vector<at::Tensor>
mha_bwd_nl(const at::Tensor &dout,
           const at::Tensor &qkv,
           at::Tensor &softmax,
           const at::Tensor &cu_seqlens,
           const float p_dropout,
           const int max_seq_len) {
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // Dtype checks added for consistency with mha_bwd.
    TORCH_CHECK(qkv.dtype() == torch::kFloat16);
    TORCH_CHECK(dout.dtype() == torch::kFloat16);
    TORCH_CHECK(softmax.dtype() == torch::kFloat16);
    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);

    TORCH_CHECK(qkv.is_cuda());
    TORCH_CHECK(cu_seqlens.is_cuda());

    TORCH_CHECK(qkv.is_contiguous());
    TORCH_CHECK(cu_seqlens.is_contiguous());

    TORCH_CHECK(cu_seqlens.dim() == 1);
    TORCH_CHECK(qkv.dim() == 4);

    const auto sizes = qkv.sizes();
    TORCH_CHECK(sizes[THREE_DIM] == 3);

    const int batch_size = cu_seqlens.numel() - 1;
    const int total = sizes[TOTAL_DIM];
    const int num_heads = sizes[H_DIM];
    const int head_size = sizes[D_DIM];
    TORCH_CHECK(batch_size > 0);
    // Only head dimension 64 is compiled in.
    TORCH_CHECK(head_size == 64);

    int seq_len = 512;
    auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl;
    // The nl kernel only supports sequence length 512; previously max_seq_len
    // was silently ignored (check added for consistency with mha_fwd_nl).
    TORCH_CHECK(max_seq_len == seq_len);

    auto opts = qkv.options();
    auto dqkv = torch::empty_like(qkv);

    // Split the work into num_chunks chunks; tuned by batch size.
    int num_chunks = 2;
    if( batch_size == 1 ) {
        num_chunks = 4;
    } else if( batch_size == 2 ) {
        num_chunks = 3;
    }
    // Per-chunk partial dK/dV, reduced below (split-K style).
    auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts);

    Fused_multihead_attention_fprop_params params;
    set_params(params,
               batch_size,
               seq_len,
               num_heads,
               head_size,
               qkv.data_ptr(),
               cu_seqlens.data_ptr(),
               dout.data_ptr(),     // o_ptr = dout
               softmax.data_ptr(),  // softmax gets overwritten by dP!
               p_dropout);

    params.dkv_ptr = dkv.data_ptr();

    // We're re-using the fprop scaling factors for the dgrad kernel.
    Data_type acc_type = DATA_TYPE_FP32;
    set_alpha(params.scale_bmm1, 1.f, acc_type);
    set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
    set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);
    params.dqkv_ptr = dqkv.data_ptr();

    launch(params, num_chunks, stream);

    // SPLIT-K reduction of num_chunks dK, dV parts.
    // The equivalent of the following Pytorch code:
    //   using namespace torch::indexing;
    //   at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)});
    //   torch::sum_out(view_out, dkv, 1);
    const int hidden_size = num_heads * head_size;
    fmha_run_noloop_reduce(dqkv.data_ptr(),
                           dkv.data_ptr(),
                           cu_seqlens.data_ptr<int>(),
                           hidden_size,
                           batch_size,
                           total,
                           num_chunks,
                           stream);

    return {dqkv, softmax, dkv};
}
// Python bindings: exposes the fused-attention entry points.
// The *_nl ("no-loop") variants are the small-batch specializations above.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "Fused Multi-head Self-attention for BERT";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("bwd", &mha_bwd, "Backward pass");
    m.def("fwd_nl", &mha_fwd_nl, "Forward pass (small-batch)");
    m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
}
apex/contrib/csrc/fmha/src/fmha.h
0 → 100644
View file @
f79993d9
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <fmha_utils.h>
constexpr
int
TOTAL_DIM
=
0
;
constexpr
int
THREE_DIM
=
1
;
constexpr
int
H_DIM
=
2
;
constexpr
int
D_DIM
=
3
;
////////////////////////////////////////////////////////////////////////////////////////////////////
// Host-to-device parameters describing the packed QKV input; shared by the
// forward and backward kernels via Fused_multihead_attention_fprop_params.
struct Qkv_params {
    // The QKV matrices (base pointer of the packed tensor).
    void *qkv_ptr;

    // The stride between rows of the Q, K and V matrices, in bytes.
    size_t qkv_stride_in_bytes;

    // The number of heads.
    int h;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Full kernel-parameter block for the fused attention kernels. The same
// struct is used by fprop and dgrad: the backward pass repurposes o_ptr to
// carry dO and s_ptr to carry the softmax/dP matrix (see mha_bwd).
struct Fused_multihead_attention_fprop_params : public Qkv_params {

    // The dQKV matrices (gradient output of the backward pass).
    void *dqkv_ptr;

    // Temporary for dKV (per-chunk split-K scratch used by the _nl kernels).
    void *dkv_ptr;

    // The O matrix (output). In the backward pass this holds dO.
    void *o_ptr;

    // The stride between rows of O, in bytes.
    int64_t o_stride_in_bytes;

    // The pointer to the S matrix, overwritten by the dP matrix (bwd).
    void *s_ptr;
    // The stride between rows of the S matrix, in bytes.
    int64_t s_stride_in_bytes;

    // The dimensions: batch size, sequence length, head size.
    int b, s, d;

    // The scaling factors for the kernel (packed by set_alpha; bit-encoded,
    // hence uint32_t).
    uint32_t scale_bmm1, scale_softmax, scale_bmm2;

    // array of length b+1 holding starting offset of each sequence.
    int *cu_seqlens;

    // The dropout probability (probability of keeping an activation).
    float p_dropout;

    // Scale factor of 1 / (1 - p_dropout).
    float rp_dropout;

    // Scale factor of 1 / (1 - p_dropout), in half2.
    uint32_t scale_dropout;

    // Random state for the dropout Philox RNG.
    at::PhiloxCudaState philox_args;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Forward (fprop) launchers, one per supported sequence length
// (head size 64, SM80 only).
void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);

// Backward (dgrad) launchers, one per supported sequence length.
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);

// "No-loop" small-batch variants: the work is split into num_chunks chunks
// (see mha_fwd_nl / mha_bwd_nl); seq_len 512 only.
void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream);

void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const int num_chunks, cudaStream_t stream);

// Split-K reduction of the num_chunks partial dK/dV buffers produced by the
// nl dgrad kernel into the dQKV output (see mha_bwd_nl for the equivalent
// PyTorch code).
void fmha_run_noloop_reduce(void *out,
                            const void *in,
                            const int *cu_seqlens,
                            const int hidden_size,
                            const int batch_size,
                            const int total,
                            const int num_chunks,
                            cudaStream_t stream);
apex/contrib/csrc/fmha/src/fmha/gemm.h
0 → 100644
View file @
f79993d9
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <fmha/utils.h>
#define FMHA_DIV_UP(m, n) (((m) + (n)-1) / (n))
namespace
fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
// Compile-time description of a register-resident fragment: derives the
// register count, size and alignment from the element type/count.
template<typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_>
struct Fragment_base_ {

    // The data type.
    using Data_type = Data_type_;
    // default input type
    using Input_type_ = Data_type_;
    // Does it store the array of elements (i.e. elements are byte-addressable).
    enum { HAS_ELTS = BITS_PER_ELT_ >= 8 };
    // The number of elements.
    enum { NUM_ELTS = NUM_ELTS_ };
    // The size of element in bits.
    enum { BITS_PER_ELT = BITS_PER_ELT_ };
    // The size of byte of a single register.
    enum { BYTES_PER_REG = 4 };
    // The size in bits.
    enum { BITS_PER_REG = BYTES_PER_REG * 8 };
    // The number of registers needed to store the fragment.
    enum { NUM_REGS = Div_up<NUM_ELTS * BITS_PER_ELT, BITS_PER_REG>::VALUE };
    // The size in bytes (as returned by sizeof(Fragment_base<>).
    enum { SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG };
    // The alignment: forced when ALIGNMENT_ > 0, otherwise natural size capped at 16B.
    enum { ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : Min<NUM_REGS * BYTES_PER_REG, 16>::VALUE };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// A register-backed fragment: NUM_ELTS_ elements of Data_type_ stored in an
// array of 32-bit registers, with element/register accessors and addition.
template<
    // The type of the elements.
    typename Data_type_,
    // The number of elements.
    int NUM_ELTS_,
    // The alignment if you want to force a value -- use 0 otherwise.
    int ALIGNMENT_ = 0,
    // The base class.
    typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>
>
struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {

    // The size of a load/store.
    enum { BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t) };

    // Clear the fragment. Using PTX in that code seems to produce better SASS...
    inline __device__ void clear() {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) :);
        }
    }

    // Immutable access to a register.
    inline __device__ const uint32_t& reg(int ii) const {
        return this->regs_[ii];
    }

    // Mutable access to a register.
    inline __device__ uint32_t& reg(int ii) {
        return this->regs_[ii];
    }

    // The backing storage: one 32-bit register per NUM_REGS.
    uint32_t regs_[Base_::NUM_REGS];

    // Immutable access to the elements (reinterprets the register storage).
    inline __device__ const Data_type_& elt(int ii) const {
        return reinterpret_cast<const Data_type_ *>(&this->regs_[0])[ii];
    }

    // Mutable access to the elements.
    inline __device__ Data_type_& elt(int ii) {
        return reinterpret_cast<Data_type_ *>(&this->regs_[0])[ii];
    }

    // Immutable access to the elements with a cast.
    template<typename Cast_type>
    inline __device__ const Cast_type& elt_as(int ii) const {
        return reinterpret_cast<const Cast_type *>(&this->regs_[0])[ii];
    }

    // Mutable access to the elements with a cast.
    template<typename Cast_type>
    inline __device__ Cast_type& elt_as(int ii) {
        return reinterpret_cast<Cast_type *>(&this->regs_[0])[ii];
    }

    // Add another fragment element-wise.
    inline __device__ void add(const Fragment &other) {
        #pragma unroll
        for( int ii = 0; ii < NUM_ELTS_; ++ii ) {
            this->elt(ii) += other.elt(ii);
        }
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Operand-A fragment for the fp16 HMMA: 8 half values held as uint16_t.
// Layout is a tag type; it is not used in the body here (dispatch happens at
// the call sites -- e.g. Fragment_accumulator::mma takes both tags).
template<typename Layout>
struct Fragment_a : public Fragment<uint16_t, 8> {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Operand-B fragment for the fp16 HMMA: 8 half values held as uint16_t.
template<typename Layout>
struct Fragment_b : public Fragment<uint16_t, 8> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP32 accumulator fragment (8 floats) with the Ampere fp16 HMMA baked in.
struct Fragment_accumulator : public Fragment<float, 8> {

    // The base class.
    using Base = Fragment<float, 8>;

    // Add two fragments element-wise (accepts any fragment exposing elt()).
    template<typename Other_fragment_>
    inline __device__ void add(const Other_fragment_ &other) {
        for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
            this->elt(ii) = this->elt(ii) + other.elt(ii);
        }
    }

    // Do the HMMA. Two m16n8k16 MMAs cover the 16x16 tile: the first
    // accumulates into elements 0..3 with the first half of B (regs 0,1),
    // the second into elements 4..7 with the second half (regs 2,3).
    template<typename Layout_a, typename Layout_b>
    inline __device__ void mma(const Fragment_a<Layout_a> &a, const Fragment_b<Layout_b> &b) {
        asm volatile( \
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
            "    {%0, %1, %2, %3}, \n" \
            "    {%4, %5, %6, %7}, \n" \
            "    {%8, %9}, \n" \
            "    {%0, %1, %2, %3}; \n" \
            : "+f"(elt(0)), "+f"(elt(1)), "+f"(elt(2)), "+f"(elt(3))
            : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
            , "r"(b.reg(0)), "r"(b.reg(1)));
        asm volatile( \
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
            "    {%0, %1, %2, %3}, \n" \
            "    {%4, %5, %6, %7}, \n" \
            "    {%8, %9}, \n" \
            "    {%0, %1, %2, %3}; \n" \
            : "+f"(elt(4)), "+f"(elt(5)), "+f"(elt(6)), "+f"(elt(7))
            : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
            , "r"(b.reg(2)), "r"(b.reg(3)));
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Zero-initialize every fragment in an M x N array of fragments.
template<typename Fragment, int M, int N>
inline __device__ void clear(Fragment (&frag)[M][N]) {
    #pragma unroll
    for( int row = 0; row < M; ++row ) {
        #pragma unroll
        for( int col = 0; col < N; ++col ) {
            frag[row][col].clear();
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Accumulator-clearing policy, specialized per accumulator type.
// The primary template is intentionally empty: unsupported accumulator types
// fail to compile.
template<typename Accumulator_type, int WARPS_K>
struct Clear_accumulator {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Float accumulators are simply zeroed (the bool flag is accepted but unused
// in this specialization).
template<int WARPS_K>
struct Clear_accumulator<float, WARPS_K> {
    template<typename Acc, int M, int N>
    static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {
        fmha::clear(acc);
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp-level GEMM over fragment arrays: acc[i][j] += a[i] * b[j] via one MMA
// per output tile.
template<typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
    #pragma unroll
    for( int row = 0; row < M; ++row ) {
        #pragma unroll
        for( int col = 0; col < N; ++col ) {
            acc[row][col].mma(a[row], b[col]);
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
// The number of rows in the CTA tile.
int
M_
,
// The number of cols in the CTA tile.
int
N_
,
// The number of elements in the the K dimension of the GEMM loop.
int
K_
,
// The number of rows of warps.
int
WARPS_M_
,
// The number of cols of warps.
int
WARPS_N_
,
// The number of warps in the K dimension of the GEMM loop.
int
WARPS_K_
>
struct
Cta_tile_
{
enum
{
M
=
M_
,
N
=
N_
,
K
=
K_
};
// The number of warps.
enum
{
WARPS_M
=
WARPS_M_
,
WARPS_N
=
WARPS_N_
,
WARPS_K
=
WARPS_K_
};
// The number of warps per CTA.
enum
{
WARPS_PER_CTA
=
WARPS_M
*
WARPS_N
*
WARPS_K
};
// The number of threads per warp.
enum
{
THREADS_PER_WARP
=
32
};
// The number of threads per CTA.
enum
{
THREADS_PER_CTA
=
WARPS_PER_CTA
*
THREADS_PER_WARP
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Derived MMA geometry for a CTA tile: how many 16x16x16 warp-MMAs are needed
// in each dimension and what each warp covers.
template<typename Cta_tile>
struct Hmma_tile {
    // The number of elements computed with a single warp-MMA.
    enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16 };

    // The number of elements computed with a single CTA-MMA.
    enum {
        M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
        N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
        K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K
    };

    // The number of MMAs needed to compute the GEMM (rounded up).
    enum {
        MMAS_M = Div_up<Cta_tile::M, M_PER_MMA_PER_CTA>::VALUE,
        MMAS_N = Div_up<Cta_tile::N, N_PER_MMA_PER_CTA>::VALUE,
        MMAS_K = Div_up<Cta_tile::K, K_PER_MMA_PER_CTA>::VALUE,
    };

    // The number of elements computed per warp.
    enum {
        M_PER_WARP = MMAS_M * M_PER_MMA,
        N_PER_WARP = MMAS_N * N_PER_MMA,
        K_PER_WARP = MMAS_K * K_PER_MMA,
    };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Element types for the fp16 GEMM: half operands/output (stored as uint16_t),
// fp32 accumulation and epilogue.
using A_type = uint16_t;
using B_type = uint16_t;
using C_type = uint16_t;
using Accumulator_type = float;
using Epilogue_type = float;

// Element widths in bits, derived from the types above.
constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;
constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;
constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convenience alias for constructing a CTA tile from its extents and warp
// layout.
template<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Same tile with K padded up to the next power of two (M, N and the warp
// layout are preserved).
template<typename Cta_tile_>
using Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,
                                                   Cta_tile_::N,
                                                   Next_power_of_two<Cta_tile_::K>::VALUE,
                                                   Cta_tile_::WARPS_M,
                                                   Cta_tile_::WARPS_N,
                                                   Cta_tile_::WARPS_K>;
////////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace fmha
apex/contrib/csrc/fmha/src/fmha/gmem_tile.h
0 → 100644
View file @
f79993d9
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
namespace
fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
// Global-memory tile for loading/storing one of the packed Q, K or V
// matrices (selected by qkv_offset in the ctor). Each thread issues 16-byte
// LDG/STG accesses; rows beyond the actual sequence length are predicated off.
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The number of bits per element.
    int BITS_PER_ELEMENT,
    // The number of rows of Q, K or V loaded by this tile.
    int ROWS,
    // The number of columns.
    int COLS,
    // The number of matrics.
    int NUM_MATS = 3
>
struct Gmem_tile_qkv {

    // The size of each LDG.
    enum { BYTES_PER_LDG = 16 };
    // The size of a row in bytes.
    enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 };
    // The number of threads to load a "row" of the matrix.
    enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG };
    // The number of "rows" loaded per LDG.
    enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
    // The number of LDGs needed to load a chunk of the Q matrix.
    enum { LDGS = fmha::Div_up<ROWS, ROWS_PER_LDG>::VALUE };

    // Ctor.
    template<typename Params, typename BInfo>
    inline __device__ Gmem_tile_qkv(const Params &params, int qkv_offset, const BInfo &binfo, int tidx)
        : params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes)
        , actual_seqlen(binfo.actual_seqlen)
        , qkv_ptr_(reinterpret_cast<char *>(params.qkv_ptr)) {

        // Compute the position in the sequence (within the CTA for the moment).
        int row = tidx / THREADS_PER_ROW;
        // Compute the position of the thread in the row.
        int col = tidx % THREADS_PER_ROW;

        // Store the row as we need it to disable the loads.
        row_ = row;

        // The row offset in the batched GEMM. For each seq element, we store QKV in that order.
        int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
        // Add the block index (sequence start, Q/K/V selector and head).
        row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;

        // Assemble the final pointer.
        qkv_ptr_ += row_offset + col * BYTES_PER_LDG;
    }

    // Store previously fetched data to shared memory.
    template<typename Smem_tile>
    inline __device__ void commit(Smem_tile &smem_tile) {
        smem_tile.store(fetch_);
    }

    // Load data from global memory into the fetch registers.
    template<typename Smem_tile>
    inline __device__ void load(Smem_tile &smem_tile) {
        const void *ptrs[LDGS];
        uint32_t preds[LDGS];
        #pragma unroll
        for( int ii = 0; ii < LDGS; ++ii ) {
            ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
            // Predicate off rows past the tile and past the actual sequence.
            preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
            // Zero-fill so predicated-off rows read as zeros.
            fetch_[ii] = make_uint4(0, 0, 0, 0);
        }

        // not packing predicates removes restrictions (e.g. FP16 384, 4 warps)
        Ldg_functor<uint4, LDGS> fct(fetch_, ptrs);
        #pragma unroll
        for( int ii = 0; ii < LDGS; ++ii ) {
            fct.load(ii, preds[ii]);
        }
    }

    // Store data to global memory (same predication as the loads).
    inline __device__ void store(const uint4 (&data)[LDGS]) {
        #pragma unroll
        for( int ii = 0; ii < LDGS; ++ii ) {
            char *ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
            if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
                fmha::stg(ptr, data[ii]);
            }
        }
    }

    // Move the pointer to the next ROWS-row chunk of the sequence.
    inline __device__ void move() {
        qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_;
        actual_seqlen -= ROWS;
    }

    // The stride between rows for the QKV matrice.
    int64_t params_qkv_stride_in_bytes_;
    // The pointer.
    char *qkv_ptr_;
    // The fetch registers.
    uint4 fetch_[LDGS];
    // Keep track of the row the thread is processing as we move the tile.
    int row_;
    // The length of the sequence loaded by that memory tile.
    int actual_seqlen;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Cta_tile
>
struct
Gmem_tile_o
{
// The mma tile.
using
Mma_tile
=
fmha
::
Hmma_tile
<
Cta_tile
>
;
// The size of each element.
enum
{
BYTES_PER_ELEMENT
=
2
};
// The size of a row in bytes.
enum
{
BYTES_PER_ROW
=
Cta_tile
::
N
*
BYTES_PER_ELEMENT
};
// The number of threads to store a "row" of the matrix.
enum
{
THREADS_PER_ROW
=
16
};
// The size of each STG.
enum
{
BYTES_PER_STG
=
BYTES_PER_ROW
/
THREADS_PER_ROW
};
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
enum
{
ROWS
=
Cta_tile
::
M
};
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
enum
{
ROWS_PER_LOOP
=
ROWS
<=
64
?
ROWS
:
(
int
)
Mma_tile
::
M_PER_MMA_PER_CTA
};
// The number of outter loop for the stores.
enum
{
LOOPS
=
ROWS
/
ROWS_PER_LOOP
};
// The number of "rows" stored per STG.
enum
{
ROWS_PER_STG
=
Cta_tile
::
THREADS_PER_CTA
/
THREADS_PER_ROW
};
// Do we have to guard against partial writes/reads.
enum
{
HAS_INCOMPLETE_STG
=
Cta_tile
::
M
%
ROWS_PER_STG
!=
0
};
// The number of STGs needed to store a chunk of the Q matrix.
enum
{
STGS_PER_LOOP
=
fmha
::
Div_up
<
ROWS_PER_LOOP
,
ROWS_PER_STG
>::
VALUE
};
// The number of STGs needed to store a chunk of the Q matrix in total.
enum
{
STGS
=
STGS_PER_LOOP
*
LOOPS
};
// Ctor.
template
<
typename
Params
,
typename
BInfo
>
inline
__device__
Gmem_tile_o
(
const
Params
&
params
,
const
BInfo
&
binfo
,
int
tidx
)
:
params_o_stride_in_bytes_
(
params
.
o_stride_in_bytes
)
,
actual_seqlen_
(
binfo
.
actual_seqlen
)
,
o_ptr_
(
reinterpret_cast
<
char
*>
(
params
.
o_ptr
))
{
// Compute the position in the sequence (within the CTA for the moment).
int
row
=
tidx
/
THREADS_PER_ROW
;
// Compute the position of the thread in the row.
int
col
=
tidx
%
THREADS_PER_ROW
;
// Store the row as we need it to disable loads.
row_
=
row
;
// The row offset in the batched GEMM.
int64_t
row_offset
=
(
int64_t
)
row
*
params
.
o_stride_in_bytes
+
binfo
.
bidx
*
BYTES_PER_ROW
;
// Assemble the final pointer.
o_ptr_
+=
row_offset
+
col
*
BYTES_PER_STG
;
// Is that thread active on the last STG?
if
(
HAS_INCOMPLETE_STG
)
{
is_active_for_last_stg_
=
row
+
(
STGS
-
1
)
*
ROWS_PER_STG
<
Cta_tile
::
M
;
}
}
// Store data to global memory.
//
// src holds STGS_PER_LOOP vectors of 4 packed fp32 values each; they are
// converted via float4_to_half4 (4 x fp32 -> 8 bytes of packed halves) and
// written with one STG per chunk of ROWS_PER_STG rows. mi selects the outer
// loop iteration.
inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
    #pragma unroll
    for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
        // The global STG index across all loop iterations.
        int jj = mi * STGS_PER_LOOP + ii;
        // Never write rows past the actual (unpadded) sequence length.
        if( this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_ ) {
            break;
        }

        // Reinterpret the 4 packed registers as fp32 values.
        float x = reinterpret_cast<const float &>(src[ii].x);
        float y = reinterpret_cast<const float &>(src[ii].y);
        float z = reinterpret_cast<const float &>(src[ii].z);
        float w = reinterpret_cast<const float &>(src[ii].w);
        // Convert the 4 floats to the packed half-precision output format.
        uint2 out = float4_to_half4(x, y, z, w);
        // Guard the last, possibly partial, STG when HAS_INCOMPLETE_STG.
        if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
            fmha::stg(this->o_ptr_ + jj * ROWS_PER_STG * this->params_o_stride_in_bytes_, out);
        }
    }
}
// Advance the tile to the next chunk of ROWS sequence elements.
inline __device__ void move() {
    // Jump the global pointer by ROWS full rows (the stride is in bytes).
    o_ptr_ += static_cast<int64_t>(ROWS) * params_o_stride_in_bytes_;
    // Keep the logical row index in sync so the seqlen bounds check stays correct.
    row_ += ROWS;
}
// The stride in bytes between two consecutive rows of the O matrix.
// NOTE(review): the original comment said "QKV matrice", but this member is
// initialized from params.o_stride_in_bytes (the O stride) -- confirm naming.
int64_t params_o_stride_in_bytes_;
// The per-thread byte pointer into global memory.
char *o_ptr_;
// Is the thread active for the last STG?
int is_active_for_last_stg_;
// Keep track of the row to disable stores past actual_seqlen_.
int row_;
// The length of the sequence handled by that memory tile.
int actual_seqlen_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Global-memory tile to read/write a per-(batch, head) SEQLEN x SEQLEN matrix
// (the S or D matrix of the FMHA computation), one STG/LDG vector per thread.
template<typename Cta_tile, int BYTES_PER_ELEMENT>
struct Gmem_tile_mma_sd {

    // Traits of the underlying HMMA tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // Each STG moves 8 elements at once.
    enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 8 };
    // MMA counts along the M and N dimensions.
    enum { MMAS_M = Mma_tile::MMAS_M };
    enum { MMAS_N = Mma_tile::MMAS_N };
    // Rows/cols computed per MMA by the whole thread block.
    enum { M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA };
    enum { N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA };
    // The number of threads per block.
    enum { THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA };
    // Bytes the whole CTA writes per "row", i.e. per (mi, ni) step.
    enum { BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG };
    // The fixed sequence length.
    enum { SEQLEN = Cta_tile::N };
    // Byte distance between two (batch, head) blocks.
    enum { BLOCK_STRIDE_BYTES = SEQLEN * SEQLEN * BYTES_PER_ELEMENT };
    // Byte distance covered by one iteration of the outer loop.
    enum { LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW };

    // Vector type moved by a single STG/LDG.
    using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;

    // Ctor: position this thread inside its (batch, head) block.
    template<typename Params>
    inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int tidx)
        : ptr_(static_cast<char *>(ptr)) {
        // blockIdx.y indexes the batch, blockIdx.x the head.
        const int batch = blockIdx.y;
        const int head = blockIdx.x;
        // Linear (batch, head) block index.
        size_t block = batch * params.h + head;
        // Each thread owns one STG-sized slot at the start of the loop.
        ptr_ += block * BLOCK_STRIDE_BYTES + tidx * BYTES_PER_STG;
    }

    // Write one vector for MMA (mi, ni).
    inline __device__ void store(const Type &data, const int mi, const int ni) {
        const size_t byte_offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
        fmha::stg(ptr_ + byte_offset, data);
    }

    // Read one vector for MMA (mi, ni).
    inline __device__ void load(Type &data, const int mi, const int ni) {
        const size_t byte_offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
        fmha::ldg(data, ptr_ + byte_offset);
    }

    // Advance to the next loop iteration's slice.
    inline __device__ void move() {
        ptr_ += LOOP_STRIDE_BYTES;
    }

    // The per-thread pointer in global memory.
    char *ptr_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Global-memory tile for the S matrix (fp16 elements): converts the fp32
// softmax fragments to packed halves on store, and reads them back on load.
template<typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)>>
struct Gmem_tile_mma_s : public Base {

    // The number of mmas in the vertical dimension.
    enum { M = Base::MMAS_M };
    // The number of mmas in the horizontal dimension.
    enum { N = Base::MMAS_N };

    // The vector type stored by each STG (inherited from the base tile).
    using Type = typename Base::Type;

    // Ctor: forwards straight to the base tile.
    template<typename Params>
    inline __device__ Gmem_tile_mma_s(void *ptr, const Params &params, const int tidx)
        : Base(ptr, params, tidx) {
    }

    // Convert each 2x4 fp32 fragment to 4 packed half2 values and write the
    // MMAs that the mask marks as inside the actual sequence length.
    template<typename Mask>
    inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask &mask) {
        #pragma unroll
        for( int mi = 0; mi < M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < N; ni++ ) {
                // Pack rows (0, 1) x column pairs (0,1) and (2,3) of the
                // fragment into four half2 registers.
                uint32_t halves[4];
                #pragma unroll
                for( int kk = 0; kk < 4; ++kk ) {
                    const int r = 2 * mi + kk / 2;
                    const int c = 4 * ni + (kk % 2) * 2;
                    halves[kk] = fmha::float2_to_half2(softmax[r][c], softmax[r][c + 1]);
                }
                uint4 dst;
                dst.x = halves[0];
                dst.y = halves[1];
                dst.z = halves[2];
                dst.w = halves[3];
                // Skip fragments that fall outside the actual sequence.
                if( mask.is_valid(mi, ni, 0, 0) ) {
                    Base::store(dst, mi, ni);
                }
            }
        }
    }

    // Load the stored S fragments; masked-out fragments read back as zero.
    template<typename Mask>
    inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {
        #pragma unroll
        for( int mi = 0; mi < M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < N; ni++ ) {
                // Default to zero so invalid fragments contribute nothing.
                regs[mi][ni] = make_uint4(0, 0, 0, 0);
                if( mask.is_valid(mi, ni, 0, 0) ) {
                    Base::load(regs[mi][ni], mi, ni);
                }
            }
        }
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Tile that loads from the O buffer using the QKV loader machinery (used as
// the dO input of the backward pass, per its use in the traits below).
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The base class.
    typename Base = fmha::Gmem_tile_qkv<Cta_tile, fmha::BITS_PER_ELEMENT_A, Cta_tile::M, Cta_tile::K>
>
struct Gmem_tile_dout : public Base {

    // Ctor: rebase the inherited QKV pointer/stride onto the O tensor.
    template<typename Params, typename BInfo>
    inline __device__ Gmem_tile_dout(const Params &params, const BInfo &binfo, int tidx)
        : Base(params, 0, binfo, tidx) {
        // Point at O instead of QKV; the stride is what Base::move() advances by.
        this->qkv_ptr_ = reinterpret_cast<char *>(params.o_ptr);
        this->params_qkv_stride_in_bytes_ = params.o_stride_in_bytes;

        // This thread's position within a row.
        const int col_in_row = tidx % Base::THREADS_PER_ROW;

        // Byte offset of this thread's first row in the batched layout of O.
        const int64_t start_offset = static_cast<int64_t>(this->row_) * params.o_stride_in_bytes
                                   + binfo.bidx * Base::BYTES_PER_ROW;

        // Assemble the final per-thread pointer.
        this->qkv_ptr_ += start_offset + col_in_row * Base::BYTES_PER_LDG;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Cta_tile
,
typename
Base
=
fmha
::
Gmem_tile_o
<
Cta_tile
>
>
struct
Gmem_tile_dq
:
public
Base
{
// Ctor.
template
<
typename
Params
,
typename
BInfo
>
inline
__device__
Gmem_tile_dq
(
const
Params
&
params
,
const
BInfo
&
binfo
,
int
tidx
)
:
Base
(
params
,
binfo
,
tidx
)
{
this
->
o_ptr_
=
reinterpret_cast
<
char
*>
(
params
.
dqkv_ptr
);
this
->
params_o_stride_in_bytes_
=
params
.
qkv_stride_in_bytes
;
// needed for move
// Compute the position of the thread in the row.
int
col
=
tidx
%
Base
::
THREADS_PER_ROW
;
// The row offset in the batched GEMM. For each seq element, we store O in that order.
int64_t
row_offset
=
(
int64_t
)
this
->
row_
*
params
.
qkv_stride_in_bytes
+
(
binfo
.
sum_s
*
3
*
binfo
.
h
+
binfo
.
bidh
)
*
Base
::
BYTES_PER_ROW
;
// Assemble the final pointer.
this
->
o_ptr_
+=
row_offset
+
col
*
Base
::
BYTES_PER_STG
;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace fmha
apex/contrib/csrc/fmha/src/fmha/kernel_traits.h
0 → 100644
View file @
f79993d9
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
// Compile-time configuration of an FMHA kernel.
//   S       : sequence length, D: head dimension, STEP: rows processed per loop.
//   WARPS_M/WARPS_N: warp layout of the CTA.  FLAGS bit 0x8: share smem for K and V.
template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x8u>
struct FMHA_kernel_traits {

    // The CTA description for the 1st GEMM (P = Q x K^T).
    using Cta_tile_p = fmha::Cta_tile_extd<STEP, S, D, WARPS_M, WARPS_N, 1>;
    // The CTA description for the 2nd GEMM (O = S x V).
    using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;

    // Do we use one buffer for K and V.
    enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u };

    // The global memory tile to load Q.
    using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
    // The shared memory tile to swizzle Q.
    using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;

    // The global memory tile to load K.
    using Gmem_tile_k = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_B, S, D>;
    // The shared memory tile to swizzle K.
    using Smem_tile_k = fmha::Smem_tile_b<Cta_tile_p, fmha::Col>;

    // The global memory tile to load V.
    using Gmem_tile_v = fmha::Gmem_tile_qkv<Cta_tile_o, fmha::BITS_PER_ELEMENT_B, S, D>;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = fmha::Smem_tile_v<Cta_tile_o>;

    // The global memory tile to store O.
    using Gmem_tile_o = fmha::Gmem_tile_o<Cta_tile_o>;
    // The shared memory tile for O.
    using Smem_tile_o = fmha::Smem_tile_o<Cta_tile_o>;

    // The global memory tile to load/store S.
    using Gmem_tile_s = fmha::Gmem_tile_mma_s<Cta_tile_p>;
    // The shared memory tile to transpose S.
    using Smem_tile_st = fmha::Smem_tile_mma_transposed<Cta_tile_p>;
    // The global memory tile to load dO (backward pass).
    using Gmem_tile_do = fmha::Gmem_tile_dout<Cta_tile_p>;

    // Make sure the number of threads match.
    static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, "");

    // The number of threads.
    enum { THREADS = Cta_tile_p::THREADS_PER_CTA };
    // Make sure the number of threads matches both CTAs.
    static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, "");

    // The amount of shared memory needed to load Q and K.
    enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE };
    // The extra amount of shared memory needed to load V (zero when V reuses K's buffer).
    enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE };
    // The amount of shared memory needed for Q, K and V.
    enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V };
    // The amount of shared memory needed to load Q and store O.
    enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE };

    // The amount of shared memory needed for Q, K, V and O (the two phases reuse the space).
    enum { BYTES_PER_SMEM = fmha::Max<BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO>::VALUE };
    // Make sure we have enough shared memory.
    static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, "");
};
////////////////////////////////////////////////////////////////////////////////////////////////////
apex/contrib/csrc/fmha/src/fmha/mask.h
0 → 100644
View file @
f79993d9
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
namespace
fmha
{
// Predicate helper that reports which elements of an MMA fragment lie inside
// the actual (unpadded) sequence length.
template<typename Cta_tile>
struct Mask {

    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    template<typename Params, typename BInfo>
    __device__ Mask(const Params &params, const BInfo &blockInfo, int tidx) {
        actual_seqlen = blockInfo.actual_seqlen;

        // Locate this thread inside the CTA.
        const int warp_idx = tidx / Cta_tile::THREADS_PER_WARP;
        const int lane_idx = tidx % Cta_tile::THREADS_PER_WARP;

        static_assert(Cta_tile::WARPS_K == 1, "");

        // Warp coordinates inside the CTA tile.
        const int warp_n = warp_idx / Cta_tile::WARPS_M;
        const int warp_m = warp_idx % Cta_tile::WARPS_M;

        // Decompose the warp into the 8x4 thread layout of an MMA fragment.
        const int quad_row = lane_idx / 4;
        const int pair_col = (lane_idx % 4) * 2;

        // Base coordinates of the first element owned by this thread.
        row = warp_m * 16 + quad_row;
        col = warp_n * 16 + pair_col;
    }

    // True when element (ii, jj) of MMA (mi, ni) is inside the sequence.
    // Only the column is checked -- the row check is disabled here (column-only
    // masking); mi and ii are accepted for interface symmetry.
    inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {
        // jj walks the 2x4 fragment: bit 1 selects the upper group of 8
        // columns, bit 0 the element within a pair.
        const int col_idx = ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
        return col_idx < actual_seqlen;
        // Disabled row check, kept for reference:
        //   && (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen
    }

    // Record the absolute row offset for loop iteration 'it'.
    inline __device__ void load(int it) {
        row_offset = it * Cta_tile::M + row;
    }

    int row_offset;
    int row;
    int col;
    int actual_seqlen;
};
}
// namespace fmha
apex/contrib/csrc/fmha/src/fmha/smem_tile.h
0 → 100644
View file @
f79993d9
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <fmha/utils.h>
#include <fmha/gemm.h>
namespace
fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
// Generic shared-memory tile with XOR swizzling (no padding "skews").
// Threads cooperatively store STS vectors into a 2D smem buffer whose columns
// are XOR-permuted per row to avoid bank conflicts on the matching loads.
template<
    // The description of the tile computed by this CTA.
    typename Cta_tile,
    // The number of rows in the 2D shared memory buffer.
    int M_,
    // The number of cols.
    int N_,
    // The size in bits of each element.
    int BITS_PER_ELEMENT_,
    // The number of bytes per STS.
    int BYTES_PER_STS_ = 16,
    // The number of buffers. (Used in multistage and double buffer cases.)
    int BUFFERS_PER_TILE_ = 1,
    // Do we enable the fast path for LDS.128 and friends.
    int ENABLE_LDS_FAST_PATH_ = 0,
    // The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
    int ROWS_PER_XOR_PATTERN_ = 8,
    // The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
    int COLS_PER_XOR_PATTERN_ = 1,
    // Use or not predicates
    bool USE_PREDICATES_ = true
>
struct Smem_tile_without_skews {

    // The size in bits of each element.
    enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ };
    // The size in bytes of a single STS.
    enum { BYTES_PER_STS = BYTES_PER_STS_ };
    // The number of elements per STS.
    enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT };
    // To support arbitrary N, we pad some values to a power-of-2.
    enum { N_WITH_PADDING = Next_power_of_two<N_>::VALUE };
    // The number of bytes per row without packing of rows.
    enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 };
    // The number of bytes per row -- we want at least 128B per row.
    enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE };
    // The number of rows in shared memory (two rows may be packed into a single one).
    enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW };
    // The number of threads per row, ignoring the CTA size.
    enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS };
    // The number of threads per row, bounded by the CTA size.
    enum { THREADS_PER_ROW = Min<Cta_tile::THREADS_PER_CTA, THREADS_PER_ROW_UNBOUNDED>::VALUE };
    // The number of STS per row.
    enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS };
    // It must be at least one.
    static_assert(STS_PER_ROW >= 1, "");
    // The number of rows written with a single STS.
    enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
    // Make sure we write to at least one row per STS.
    static_assert(ROWS_PER_STS >= 1, "");
    // The number of STS needed to store all rows.
    enum { STS_PER_COL = Div_up<ROWS, ROWS_PER_STS>::VALUE };
    // The number of STS in total.
    enum { STS = STS_PER_COL * STS_PER_ROW };
    // The size of one buffer in bytes in shared memory.
    enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA };
    // The number of buffers.
    enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ };
    // The size in bytes of total buffers.
    enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE };
    // The boundary for smem_read_offset and smem_write_offset increment.
    enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER };
    // Do we enable the LDS.128 fast path? (Not supported in this port.)
    enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ };
    static_assert(ENABLE_LDS_FAST_PATH == 0);
    // The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
    enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ };
    // The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
    enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS };
    // Use or not predicates
    enum { USE_PREDICATES = USE_PREDICATES_ };

    // The type of elements that are stored in shared memory by each thread.
    using Store_type = typename Uint_from_size_in_bytes<BYTES_PER_STS>::Type;

    // Ctor.
    inline __device__ Smem_tile_without_skews(void *smem, int tidx)
        : smem_(__nvvm_get_smem_pointer(smem)) {

        // The row written by a thread. See doc/mma_smem_layout.xlsx.
        int smem_write_row = tidx / THREADS_PER_ROW;

        // The XOR pattern.
        int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN;
        // Compute the column and apply the XOR pattern.
        int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor;

        // The offset.
        this->smem_write_offset_ = smem_write_row * BYTES_PER_ROW + smem_write_col * BYTES_PER_STS;

        // TODO: Why not merge it with the read offset?
        // NOTE(review): the __shfl_sync of the constant 0 presumably forces a
        // uniform (non-divergent) zero into the buffer offsets -- confirm intent.
        this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);
        this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);
    }

    // Compute the store pointers for the N STS issued by this thread.
    template<int N>
    inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) {
        #pragma unroll
        for( int ii = 0; ii < N; ++ii ) {
            // Decompose the STS into row/col.
            int row = ii / STS_PER_ROW;
            int col = ii % STS_PER_ROW;

            // Assemble the offset.
            int offset = smem_write_offset_ + row * ROWS_PER_STS * BYTES_PER_ROW;

            // Take the column into account.
            if( STS_PER_ROW > 1 ) {
                offset += col * THREADS_PER_ROW * BYTES_PER_STS;
            }

            // Apply the XOR pattern if needed.
            if( ROWS_PER_STS < ROWS_PER_XOR_PATTERN ) {
                const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN;
                offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS;
            }

            // Assemble the final pointer.
            ptrs[ii] = smem_ + offset + smem_write_buffer_;
        }
    }

    // Zero the whole tile (debug only; single thread does all the work).
    inline __device__ void debug_reset() {
        for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER ) {
            for( int row = 0; row < ROWS; ++row ) {
                for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
                    if( threadIdx.x == 0 ) {
                        uint32_t val = 0x0;
                        sts(val, smem_ + row * BYTES_PER_ROW + col + buffer);
                    }
                }
            }
        }
    }

    // Print the content of the tile (only for debug ;)).
    inline __device__ void debug_print() const {
        for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER ) {
            for( int row = 0; row < ROWS; ++row ) {
                for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
                    if( threadIdx.x == 0 ) {
                        uint32_t val;
                        lds(val, smem_ + row * BYTES_PER_ROW + col + buffer);
                        printf("block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\n",
                               blockIdx.x,
                               blockIdx.y,
                               blockIdx.z,
                               smem_,
                               buffer,
                               row,
                               col,
                               val);
                    }
                }
            }
        }
    }

    // Move the read offset to next buffer.
    inline __device__ void move_to_next_read_buffer() {
        if( BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
            this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
        } else if( BUFFERS_PER_TILE > 1 ) {
            this->smem_read_buffer_ += BYTES_PER_BUFFER;
        }
    }

    // Move the read offset to next buffer. TODO: Remove this member function!!!
    inline __device__ void move_next_read_buffer() {
        this->move_to_next_read_buffer();
    }

    // Move the read offset to next N buffer (circular-buffer).
    inline __device__ void move_to_next_read_buffer(int N) {
        if( BUFFERS_PER_TILE > 1 ) {
            this->smem_read_buffer_ += N * BYTES_PER_BUFFER;
            this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0;
        }
    }

    // Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!!
    inline __device__ void move_next_read_buffer(int N) {
        this->move_to_next_read_buffer(N);
    }

    // Move the write offset to next buffer.
    inline __device__ void move_to_next_write_buffer() {
        if( BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
            this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
        } else if( BUFFERS_PER_TILE > 1 ) {
            this->smem_write_buffer_ += BYTES_PER_BUFFER;
        }
    }

    // Move the write offset to next buffer. TODO: Remove that member function!
    inline __device__ void move_next_write_buffer() {
        this->move_to_next_write_buffer();
    }

    // Move the read offset.
    inline __device__ void move_read_offset(int delta) {
        this->smem_read_offset_ += delta;
    }

    // Move the write offset.
    inline __device__ void move_write_offset(int delta) {
        this->smem_write_offset_ += delta;
    }

    // Store to the tile in shared memory.
    template<int N>
    inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) {
        uint32_t smem_ptrs[N];
        this->compute_store_pointers(smem_ptrs);
        sts(smem_ptrs, data);
    }

    // Store to the tile in shared memory, with one predicate per STS.
    template<int N, int M>
    inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) {
        uint32_t smem_ptrs[N];
        this->compute_store_pointers(smem_ptrs);
        sts(smem_ptrs, data, preds);
    }

    // Store to the tile in shared memory (scalar predicate).
    // NOTE(review): this forwards to store(data, preds) with a scalar uint32_t,
    // which appears to re-select this same overload by exact match (infinite
    // recursion if instantiated). It is a template, so it only matters if some
    // caller uses it -- confirm it is unused before relying on it.
    template<int N>
    inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) {
        this->store(data, preds);
    }

    // Store to the tile in shared memory from global-memory pointers.
    // NOTE(review): no overload in this class accepts (const void* [N],
    // uint32_t [1]); presumably a derived class provides it, or this template
    // is never instantiated -- confirm.
    template<int N>
    inline __device__ void store(const void *(&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) {
        uint32_t tmp[1] = { preds };
        this->store(gmem_ptrs, tmp);
    }

    // The shared memory pointer.
    uint32_t smem_;
    // The read offset. Reserve 4 offsets if needed.
    int smem_read_offset_;
    // The write offset.
    int smem_write_offset_;
    // The buffer base offset for read.
    int smem_read_buffer_;
    // The buffer base offset for write.
    int smem_write_buffer_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
// The dimensions of the tile computed by the CTA.
typename
Cta_tile
,
// The layout of the tile.
typename
Layout
,
// The size of the STS.
int
BYTES_PER_STS
=
16
,
// The number of buffers per tile.
int
BUFFERS_PER_TILE
=
1
,
// Use or not predicates
bool
USE_PREDICATES
=
true
>
struct
Smem_tile_a
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
MMAS_K
,
int
MMAS_K_WITH_PADDING
>
struct
Compute_reset_mask
{
// The potential mask.
enum
{
HALF
=
MMAS_K_WITH_PADDING
/
2
};
// The remainder.
enum
{
MOD
=
MMAS_K
%
HALF
};
// The final value.
enum
{
VALUE
=
(
MMAS_K
==
MOD
?
0
:
HALF
)
|
Compute_reset_mask
<
MOD
,
HALF
>::
VALUE
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Base case of the recursion: no k-steps left, nothing to reset.
template<int MMAS_K_WITH_PADDING>
struct Compute_reset_mask<0, MMAS_K_WITH_PADDING> {
    enum { VALUE = 0 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Base case: no padding (MMAS_K == MMAS_K_WITH_PADDING); all lower bits toggle.
template<int MMAS_K>
struct Compute_reset_mask<MMAS_K, MMAS_K> {
    enum { VALUE = MMAS_K - 1 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Number of rows participating in the XOR swizzle for an A operand with N
// elements per row.
template<int N>
struct Rows_per_xor_pattern_a {
    // The size in bits.
    enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A };
    // The number of rows: 2 for rows of <= 256 bits, 4 for <= 512, else 8.
    enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Row-major A operands use the same XOR pattern as the generic A case.
template<int N>
struct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
// The dimensions of the tile computed by the CTA.
typename
Cta_tile
,
// The size of the STS.
int
BYTES_PER_STS
,
// The number of buffers per tile.
int
BUFFERS_PER_TILE
,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int
ROWS_PER_XOR_PATTERN_
=
Rows_per_xor_pattern_row_a
<
Cta_tile
::
K
>
::
VALUE
>
struct
Smem_tile_row_a
:
public
Smem_tile_without_skews
<
Cta_tile
,
Cta_tile
::
M
,
Cta_tile
::
K
,
fmha
::
BITS_PER_ELEMENT_A
,
BYTES_PER_STS
,
BUFFERS_PER_TILE
,
0
,
ROWS_PER_XOR_PATTERN_
,
1
>
{
// The MMA tile.
using
Mma_tile
=
fmha
::
Hmma_tile
<
Cta_tile
>
;
// The base class.
using
Base
=
Smem_tile_without_skews
<
Cta_tile
,
Cta_tile
::
M
,
Cta_tile
::
K
,
fmha
::
BITS_PER_ELEMENT_A
,
BYTES_PER_STS
,
BUFFERS_PER_TILE
,
0
,
ROWS_PER_XOR_PATTERN_
,
1
>
;
// The fragment.
using
Fragment
=
Fragment_a
<
Row
>
;
// When we use padding to reach a power of two, special care has to be taken.
using
Cta_tile_with_padding
=
Cta_tile_with_k_with_padding
<
Cta_tile
>
;
// The number of MMAs.
using
Mma_tile_with_padding
=
fmha
::
Hmma_tile
<
Cta_tile_with_padding
>
;
// The size of a single LDS in bytes.
enum
{
BYTES_PER_LDS
=
16
};
// Ctor.
inline
__device__
Smem_tile_row_a
(
void
*
smem
,
int
tidx
)
:
Base
(
smem
,
tidx
)
{
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
// The number of warps.
const
int
WARPS_M
=
Cta_tile
::
WARPS_M
;
const
int
WARPS_N
=
Cta_tile
::
WARPS_N
;
const
int
WARPS_K
=
Cta_tile
::
WARPS_K
;
static_assert
(
WARPS_M
==
1
);
static_assert
(
WARPS_N
==
4
||
WARPS_N
==
8
);
static_assert
(
WARPS_K
==
1
);
static_assert
(
Base
::
ROWS_PER_XOR_PATTERN
==
8
);
// The row and column read by the thread.
int
smem_read_row
=
(
tidx
&
0x0f
);
int
smem_read_col
=
(
tidx
&
0x07
);
smem_read_col
^=
(
tidx
&
0x10
)
/
16
;
// The shared memory offset.
this
->
smem_read_offset_
=
smem_read_row
*
Base
::
BYTES_PER_ROW
+
smem_read_col
*
BYTES_PER_LDS
;
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline
__device__
void
reverse_smem_read_offset
(
int
ki
=
0
)
{
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if
(
Mma_tile_with_padding
::
MMAS_K
>=
2
)
{
this
->
smem_read_offset_
^=
BYTES_PER_LDS
*
2
;
}
}
// Load from shared memory.
inline
__device__
void
load
(
Fragment
(
&
a
)[
Mma_tile
::
MMAS_M
],
int
ki
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
Mma_tile
::
MMAS_M
;
++
mi
)
{
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
int
offset
=
mi
*
Mma_tile
::
M_PER_MMA_PER_CTA
*
Base
::
BYTES_PER_ROW_BEFORE_PACKING
;
// Load using LDSM.M88.4.
uint4
tmp
;
ldsm
(
tmp
,
this
->
smem_
+
this
->
smem_read_offset_
+
this
->
smem_read_buffer_
+
offset
);
// Store the value into the fragment.
a
[
mi
].
reg
(
0
)
=
tmp
.
x
;
a
[
mi
].
reg
(
1
)
=
tmp
.
y
;
a
[
mi
].
reg
(
2
)
=
tmp
.
z
;
a
[
mi
].
reg
(
3
)
=
tmp
.
w
;
}
// Move the offset to the next possition. See doc/mma_smem_layout.xlsx.
static_assert
(
Mma_tile_with_padding
::
MMAS_K
<
64
,
"Not implemented"
);
if
(
Mma_tile_with_padding
::
MMAS_K
>=
32
&&
ki
%
16
==
15
)
{
this
->
smem_read_offset_
^=
31
*
BYTES_PER_LDS
*
2
;
}
else
if
(
Mma_tile_with_padding
::
MMAS_K
>=
16
&&
ki
%
8
==
7
)
{
this
->
smem_read_offset_
^=
15
*
BYTES_PER_LDS
*
2
;
}
else
if
(
Mma_tile_with_padding
::
MMAS_K
>=
8
&&
ki
%
4
==
3
)
{
this
->
smem_read_offset_
^=
7
*
BYTES_PER_LDS
*
2
;
}
else
if
(
Mma_tile_with_padding
::
MMAS_K
>=
4
&&
ki
%
2
==
1
)
{
this
->
smem_read_offset_
^=
3
*
BYTES_PER_LDS
*
2
;
}
else
if
(
Mma_tile_with_padding
::
MMAS_K
>=
2
)
{
this
->
smem_read_offset_
^=
1
*
BYTES_PER_LDS
*
2
;
}
}
// Reset the read offset.
//
// After a full pass over the (padded) K dimension the accumulated XOR
// advances leave smem_read_offset_ displaced by a fixed pattern; XOR-ing
// with Compute_reset_mask<>::VALUE cancels exactly that displacement and
// returns the offset to its ki = 0 position.
inline __device__ void reset_read_offset() {
    // The number of MMAs in the K dimension.
    enum { MMAS_K = Mma_tile::MMAS_K };
    // The number of MMAs in the K dimension when we include padding.
    enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
    // Assemble the mask.
    enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };

    // Reset the read offset.
    this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Partial specialization of Smem_tile_a for row-major A operands: all the
// work is done by Smem_tile_row_a; this struct only forwards the ctor.
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE
>
struct Smem_tile_a<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
    : public Smem_tile_row_a<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE> {

    // The base class.
    using Base = Smem_tile_row_a<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;

    // Ctor. Forwards the shared-memory base pointer and thread index.
    inline __device__ Smem_tile_a(void *smem, int tidx) : Base(smem, tidx) {
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
// The dimensions of the tile computed by the CTA.
typename
Cta_tile
,
// The layout of the tile.
typename
Layout
,
// The size of the STS.
int
BYTES_PER_STS
=
16
,
// The number of buffers per tile.
int
BUFFERS_PER_TILE
=
1
,
// Use or not predicates
bool
USE_PREDICATES
=
true
>
struct
Smem_tile_b
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Compile-time choice of how many shared-memory rows the XOR swizzle for
// the B operand spans, as a function of the row width N (in elements).
// Narrow rows (<= 256 bits) use 2 rows, medium rows (<= 512 bits) use 4,
// and anything wider uses 8.
template< int N >
struct Rows_per_xor_pattern_b {
    // The size in bits of one row of N B-operand elements.
    enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B };
    // The number of rows: 8 / 4 / 2 from widest to narrowest.
    enum { VALUE = N_IN_BITS > 512 ? 8 : (N_IN_BITS > 256 ? 4 : 2) };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Column-major B specialization point: currently identical to the generic
// Rows_per_xor_pattern_b, kept as a distinct name so the col-B tile can be
// tuned independently later.
template< int N >
struct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Shared-memory tile for a column-major B operand. The tile is stored as
// Cta_tile::N "rows" of Cta_tile::K elements (swizzled by the base class to
// avoid bank conflicts) and read back with LDSM.M88.4 into Fragment_b<Col>
// fragments, one per N-MMA.
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE,
    // How many rows to use for the XOR pattern to avoid bank conflicts?
    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b<Cta_tile::K>::VALUE
>
struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
                                                        Cta_tile::N,
                                                        Cta_tile::K,
                                                        fmha::BITS_PER_ELEMENT_B,
                                                        BYTES_PER_STS,
                                                        BUFFERS_PER_TILE,
                                                        0,
                                                        ROWS_PER_XOR_PATTERN_,
                                                        1> {

    // The MMA tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    // The base class.
    using Base = Smem_tile_without_skews<Cta_tile,
                                         Cta_tile::N,
                                         Cta_tile::K,
                                         fmha::BITS_PER_ELEMENT_B,
                                         BYTES_PER_STS,
                                         BUFFERS_PER_TILE,
                                         0,
                                         ROWS_PER_XOR_PATTERN_,
                                         1>;
    // The fragment.
    using Fragment = Fragment_b<Col>;

    // When we use padding to reach a power of two, special care has to be taken.
    using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Cta_tile>;
    // The number of MMAs (over the padded K dimension).
    using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;

    // The size of a single LDS in bytes.
    enum { BYTES_PER_LDS = 16 };

    // The number of STS per thread
    enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
    // The number of STS per thread must be at least 1.
    enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };

    // Ctor. Computes this thread's swizzled read position inside the tile.
    inline __device__ Smem_tile_col_b(void *smem, int tidx) : Base(smem, tidx) {
        // For documentation on the layout, see doc/mma_smem_layout.xlsx.

        // The number of warps.
        const int WARPS_M = Cta_tile::WARPS_M;
        const int WARPS_N = Cta_tile::WARPS_N;
        const int WARPS_K = Cta_tile::WARPS_K;

        // This tile only supports the 1 x {4,8} x 1 warp configuration with
        // an 8-row XOR pattern; anything else needs a different read layout.
        static_assert(Base::ROWS_PER_XOR_PATTERN == 8);
        static_assert(WARPS_M == 1);
        static_assert(WARPS_N == 4 || WARPS_N == 8);
        static_assert(WARPS_K == 1);

        // The masks to select the warps.
        const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
        // The divisor for the warps.
        const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;

        // The row and column read by the thread.
        int smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA +
                            (tidx & 0x07) + (tidx & 0x10) / 2;
        int smem_read_col = (tidx & 0x07);
        // Fold bit 3 of the thread index into the column to complete the swizzle.
        smem_read_col ^= (tidx & 0x08) / 8;

        // The shared memory offset.
        this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
    }

    // Rewind smem_read_offset for last LDS phase in main loop.
    // Must mirror the final XOR step of load(ki = 0) below.
    inline __device__ void reverse_smem_read_offset(int ki = 0) {
        // Undo the pointer increment for the next ni.
        // Should match the load function below for ki = 0.
        if( Mma_tile_with_padding::MMAS_K >= 2 ) {
            this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
        }
    }

    // Load from shared memory.
    //
    // One LDSM.M88.4 per N-MMA, then an XOR advance of smem_read_offset_ to
    // the next K-slice; the 31/15/7/3/1 ladder walks the swizzled layout of
    // the padded K dimension (see doc/mma_smem_layout.xlsx).
    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
            // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
            int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;

            // Load using LDSM.M88.4.
            uint4 tmp;
            ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);

            // Store the value into the fragment.
            b[ni].reg(0) = tmp.x;
            b[ni].reg(1) = tmp.y;
            b[ni].reg(2) = tmp.z;
            b[ni].reg(3) = tmp.w;
        }

        // Move the offset to the next position. See doc/mma_smem_layout.xlsx.
        static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
        if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
            this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) {
            this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) {
            this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) {
            this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
            this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
        }
    }

    // Reset the read offset.
    //
    // Cancels the net XOR displacement accumulated over a full pass of the
    // padded K dimension, returning the offset to its ki = 0 position.
    inline __device__ void reset_read_offset() {
        // The number of MMAs in the K dimension.
        enum { MMAS_K = Mma_tile::MMAS_K };
        // The number of MMAs in the K dimension when we include padding.
        enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
        // Assemble the mask.
        enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };

        // Reset the read offset.
        this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Partial specialization of Smem_tile_b for column-major B operands: all
// the work is done by Smem_tile_col_b; this struct only forwards the ctor.
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE
>
struct Smem_tile_b<Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
    : public Smem_tile_col_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE> {

    // The base class.
    using Base = Smem_tile_col_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;

    // Ctor. Forwards the shared-memory base pointer and thread index.
    inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Row-major B specialization point: currently identical to the generic
// Rows_per_xor_pattern_b, kept as a distinct name so the row-B tile can be
// tuned independently later.
template< int N >
struct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Shared-memory tile for a row-major B operand: Cta_tile::K rows of
// Cta_tile::N elements. Fragments are read back with LDSM.MT88.2 (the
// transposing variant) when elements are 16-bit, or with plain 32-bit LDS
// otherwise, into Fragment_b<Row> fragments, one per N-MMA.
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE,
    // How many rows to use for the XOR pattern to avoid bank conflicts?
    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b<Cta_tile::N>::VALUE,
    // How many cols to use for the XOR pattern to avoid bank conflicts?
    int COLS_PER_XOR_PATTERN_ = 1
>
struct Smem_tile_row_b : public Smem_tile_without_skews<Cta_tile,
                                                        Cta_tile::K,
                                                        Cta_tile::N,
                                                        fmha::BITS_PER_ELEMENT_B,
                                                        BYTES_PER_STS,
                                                        BUFFERS_PER_TILE,
                                                        0,
                                                        ROWS_PER_XOR_PATTERN_,
                                                        COLS_PER_XOR_PATTERN_> {

    // The MMA tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    // The base class.
    using Base = Smem_tile_without_skews<Cta_tile,
                                         Cta_tile::K,
                                         Cta_tile::N,
                                         fmha::BITS_PER_ELEMENT_B,
                                         BYTES_PER_STS,
                                         BUFFERS_PER_TILE,
                                         0,
                                         ROWS_PER_XOR_PATTERN_,
                                         COLS_PER_XOR_PATTERN_>;
    // The fragment.
    using Fragment = Fragment_b<Row>;

    // Can we use LDSM? No if the data type is 32-bit large.
    enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 };
    // The size of a single LDS in bytes (16B ldmatrix vs 4B scalar loads).
    enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 };
    // The number of elements per LDS.
    enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B };

    // The number of STS per thread
    enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
    // The number of STS per thread must be at least 1.
    enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };

    // Ctor. Computes this thread's swizzled read position inside the tile.
    inline __device__ Smem_tile_row_b(void *smem, int tidx) : Base(smem, tidx) {
        // The number of warps.
        const int WARPS_M = Cta_tile::WARPS_M;
        const int WARPS_N = Cta_tile::WARPS_N;
        const int WARPS_K = Cta_tile::WARPS_K;

        // This tile only supports the {4,8} x 1 x 1 warp configuration.
        static_assert(WARPS_K == 1);
        static_assert(WARPS_M == 4 || WARPS_M == 8);
        static_assert(WARPS_N == 1);

        // The masks to select the warps.
        const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
        const int WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;

        // The divisor for the warps.
        const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
        const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;

        // The row/col read by the thread.
        int smem_read_row, smem_read_col;

        // Only the LDSM path with an 8-row XOR pattern is implemented.
        static_assert(USE_LDSMT);
        static_assert(Base::ROWS_PER_XOR_PATTERN == 8);

        smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 +
                        (tidx & 0x07) + (tidx & 0x08);
        smem_read_col = (tidx & 0x07);
        smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16;

        // The shared memory offset.
        this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;

        // Fill zeroes for group conv
    }

    // Rewind smem_read_offset for last LDS phase in main loop.
    //
    // Replays (and thereby cancels) the per-ni XOR advances applied by
    // load() for a full sweep over MMAS_N, plus the odd-MMAS_N correction,
    // so the read offset returns to its initial position.
    inline __device__ void reverse_smem_read_offset(int ki = 0) {
        // The size of each element in bits.
        const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
        // The size in bytes of the data needed to compute an MMA per CTA.
        const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;

        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
            // Undo the pointer increment for the next ni.
            // Should match the load function below for ki = 0.
            if( BYTES_PER_MMA_PER_CTA >= 128 ) {
                // Nothing to do!
            } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {
                this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
            } else if( BYTES_PER_MMA_PER_CTA == 64 ) {
                // Nothing to do!
            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
                this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
                this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
            }
        }

        // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
        if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1 ) {
            this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
        }
    }

    // Load from shared memory.
    //
    // For K-step `ki`, reads one fragment per N-MMA. The per-ni byte offset
    // depends on how many bytes one MMA covers per CTA (32 / 64 / >= 128):
    // small MMAs walk via XOR swizzles, larger ones via plain strides. The
    // non-LDSM fallback issues four 4B LDS at the equivalent positions.
    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
        // The size of each element in bits.
        const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
        // The size in bytes of the data needed to compute an MMA per CTA.
        const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;

        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
            // Prepare the offset.
            int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW;
            if( BYTES_PER_MMA_PER_CTA == 32 ) {
                offset += this->smem_read_offset_;
            } else if( BYTES_PER_MMA_PER_CTA == 64 ) {
                offset += this->smem_read_offset_ + (ni / 2) * BYTES_PER_MMA_PER_CTA * 2;
            } else {
                offset += this->smem_read_offset_ + (ni)*BYTES_PER_MMA_PER_CTA;
            }

            // Load the data using LDSM.MT88.2.
            uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset;
            uint4 tmp;
            if( USE_LDSMT ) {
                ldsmt(tmp, ptr);
            } else {
                lds(tmp.x, (ptr) + 0 * Base::BYTES_PER_ROW);
                lds(tmp.y, (ptr) + 4 * Base::BYTES_PER_ROW);
                lds(tmp.z, (ptr ^ 32) + 0 * Base::BYTES_PER_ROW);
                lds(tmp.w, (ptr ^ 32) + 4 * Base::BYTES_PER_ROW);
            }

            // Store those values in the fragment.
            b[ni].reg(0) = tmp.x;
            b[ni].reg(1) = tmp.y;
            b[ni].reg(2) = tmp.z;
            b[ni].reg(3) = tmp.w;

            // Move the pointer for the next ni. I expect the compiler to not recompute those.
            if( BYTES_PER_MMA_PER_CTA >= 128 ) {
                // Nothing to do!
            } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {
                this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
            } else if( BYTES_PER_MMA_PER_CTA == 64 ) {
                // Nothing to do!
            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
                this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
                this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
            }
        }

        // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
        if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1 ) {
            this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
        }
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Partial specialization of Smem_tile_b for row-major B operands: all the
// work is done by Smem_tile_row_b; this struct only forwards the ctor.
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE
>
struct Smem_tile_b<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
    : public Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE> {

    // The base class.
    using Base = Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;

    // Ctor. Forwards the shared-memory base pointer and thread index.
    inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Shared-memory tile for the V operand of the second GEMM (P x V).
// Stored as Cta_tile::K rows of Cta_tile::N elements with a fixed 8-row
// XOR pattern, single-buffered, and read back transposed with LDSM.MT88.2
// into Fragment_b<Col> fragments.
template<typename Cta_tile>
struct Smem_tile_v : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, 8, 1> {

    // The base class.
    using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, 8, 1>;
    // The MMA tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    // The fragment.
    using Fragment = Fragment_b<fmha::Col>;

    // The size of a single LDS in bytes.
    enum { BYTES_PER_LDS = 16 };

    // Ctor. Computes this thread's swizzled read position inside the tile.
    inline __device__ Smem_tile_v(void *smem, int tidx) : Base(smem, tidx) {
        // The row/col read by the thread.
        int read_row, read_col;

        // Only the split-K configuration 1 x 1 x {4,8} is supported.
        static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 &&
                      (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));

        // Warp id (bits 5..7 of tidx) selects a 16-row slab; lane bits pick the row.
        read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f);
        read_col = (tidx & 0x07);
        read_col ^= (tidx & 0x10) / 16;

        // The shared memory offset.
        this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS;
    }

    // Load from shared memory.
    //
    // For K-step `ki`, each warp reads its 16-row slice (16 * WARPS_K rows
    // ahead per ki) with LDSM.MT88.2, one fragment per N-MMA, XOR-stepping
    // the column between MMAs. Only MMAS_N == 4 is implemented.
    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
            // Jump by 16 * #warps row.
            int row = ki * 16 * Cta_tile::WARPS_K;

            // Load the data using LDSM.MT88.2.
            uint4 tmp;
            fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW);
            b[ni].reg(0) = tmp.x;
            b[ni].reg(1) = tmp.y;
            b[ni].reg(2) = tmp.z;
            b[ni].reg(3) = tmp.w;

            // Move the pointer for the next ni. I expect the compiler to not recompute those.
            if( Mma_tile::MMAS_N == 4 ) {
                this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
            } else {
                assert(false);  // Not implemented!
            }
        }
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Shared-memory tile used for the split-K reduction of the output O.
// Each of the WARPS_K warps stores its partial accumulators into its own
// Cta_tile::N-wide slice of a row (store()); load() then reads all WARPS_K
// slices per row and sums them with fadd4. The M dimension is processed in
// LOOPS chunks of ROWS_PER_LOOP rows so the tile fits in shared memory.
template<typename Cta_tile>
struct Smem_tile_o {

    // The MMA tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    // The accumulators.
    using Accumulator = fmha::Fragment_accumulator;
    // The accumulators.
    using Data_type = typename Accumulator::Data_type;

    // The size of each element.
    enum { BYTES_PER_ELEMENT = sizeof(Data_type) };
    // The size of each STS.
    enum { BYTES_PER_STS = 8 };
    // The size of each row in shared memory (one N-slice per split-K warp).
    enum { BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT };
    // The size of each LDS.
    enum { BYTES_PER_LDS = 16 };
    enum { THREADS_PER_ROW = 16 };

    // The number of rows.
    enum { ROWS = Cta_tile::M };
    // The number of "rows" to process per loop iteration (in the "epilogue").
    enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA };
    // The number of outer loops.
    enum { LOOPS = ROWS / ROWS_PER_LOOP };
    // Make sure it matches our expectations.
    static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, "");

    // The number of rows loaded per LDS.
    enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
    // Do we have to guard against partial writes/reads.
    enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 };
    // The total number of LDS per loop.
    enum { LDS_PER_LOOP = fmha::Div_up<ROWS_PER_LOOP, ROWS_PER_LDS>::VALUE };
    // The amount of shared memory.
    enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW };

    // The write pointer.
    uint32_t smem_write_, smem_read_;
    // Is the thread active for the last LDS of the series?
    int is_active_for_last_lds_;

    static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K);
    static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, "");

    // Ctor. Computes this thread's write and (swizzled) read addresses.
    inline __device__ Smem_tile_o(void *smem, int tidx) {
        // Get a 32-bit value for the shared memory address.
        uint32_t smem_ = __nvvm_get_smem_pointer(smem);

        // Only the split-K configuration 1 x 1 x {4,8} is supported.
        static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 &&
                      (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));

        int write_row = (tidx & 0x1c) / 4;
        int write_col = (tidx);

        // Assemble the write pointer.
        smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;

        // The element read by each thread.
        int read_row = tidx / THREADS_PER_ROW;
        int read_col = tidx % THREADS_PER_ROW;

        // Take the XOR pattern into account for the column.
        read_col ^= 2 * (read_row & 0x7);

        // Assemble the read pointer.
        this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;

        // Is that thread active on the last LDS?
        if( HAS_INCOMPLETE_LDS ) {
            this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M;
        }
    }

    // Load the output fragments.
    //
    // Reads the WARPS_K partial results for each of the LDS_PER_LOOP rows
    // owned by this thread and reduces them (split-K sum) with fadd4.
    // Threads past the last partial row skip the guarded final LDS.
    inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const {
        #pragma unroll
        for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) {

            // Load the elements before the reduction (split-K).
            uint4 tmp[Cta_tile::WARPS_K];
            #pragma unroll
            for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
                int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
                if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {
                    fmha::lds(tmp[jj], this->smem_read_ + imm);
                }
            }

            // Perform the reduction.
            out[ii] = tmp[0];
            #pragma unroll
            for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {
                out[ii] = fmha::fadd4(out[ii], tmp[jj]);
            }
        }
    }

    // Store the accumulators.
    //
    // Writes loop chunk `mi` of the accumulators: for each N-MMA, the two
    // 8-row halves of the MMA are stored as uint2, first registers 0-3,
    // then (after a 32B XOR of the write pointer) registers 4-7; the final
    // XOR undoes the column swizzle and advances to the next ni.
    template<int M, int N>
    inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
        enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA };
        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {

            // The number of MMAs that are stored per loop iteration.
            enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS };

            // Store 1st column of the different MMAs.
            #pragma unroll
            for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
                // Precompute the immediates to jump between rows.
                int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
                int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;

                uint2 tmp0, tmp1;
                tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0);
                tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1);

                tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2);
                tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3);

                // Store.
                fmha::sts(this->smem_write_ + row_0, tmp0);
                fmha::sts(this->smem_write_ + row_1, tmp1);
            }

            // Swizzle the write pointer using a XOR of 16B.
            this->smem_write_ ^= 32;

            // Store 2nd column of the different MMAs.
            #pragma unroll
            for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
                // Precompute the immediates to jump between rows.
                int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
                int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;

                uint2 tmp0, tmp1;
                tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4);
                tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5);

                tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6);
                tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7);

                // Store.
                fmha::sts(this->smem_write_ + row_0, tmp0);
                fmha::sts(this->smem_write_ + row_1, tmp1);
            }

            // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B.
            this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
        }
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Shared-memory staging tile for the softmax matrix P (fp16, one Cta_tile::M
// x Cta_tile::N tile). store() scatters MMA register quads into a 4B-swizzled
// layout; the derived tiles below read it back either transposed (LDSM.T)
// or row-wise for the epilogue.
template<typename Cta_tile>
struct Smem_tile_mma {

    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    using Fragment = fmha::Fragment_a<fmha::Col>;

    enum { COLS = Cta_tile::N };
    // fp16 elements.
    enum { BYTES_PER_ELT = 2 };
    enum { BYTES_PER_STS = 4 };
    enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT };  // TODO
    enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW };

    enum { WARPS_M = Cta_tile::WARPS_M };
    enum { WARPS_N = Cta_tile::WARPS_N };
    enum { WARPS_K = Cta_tile::WARPS_K };
    static_assert(WARPS_K == 1);

    // Ctor. Computes this thread's swizzled write position.
    inline __device__ Smem_tile_mma(char *smem, int tidx) {
        smem_ = __nvvm_get_smem_pointer(smem);

        int write_col, write_row;
        // NOTE(review): the middle clause reads (WARPS_M == 4 || WARPS_N == 8);
        // the surrounding code (and the if below) suggests it was probably meant
        // to be (WARPS_M == 4 || WARPS_M == 8) — TODO confirm against upstream.
        static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ||
                      (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
        if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) {
            // Warps tile the N dimension: lane row bits pick the row, warp id
            // and low lane bits pick the column.
            write_row = (tidx & 0x1c) / 4;
            write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);
        } else {
            // Warps tile the M dimension.
            write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;
            write_col = (tidx & 0x03);
        }
        // XOR swizzle of the column by the row to avoid bank conflicts.
        write_col ^= (write_row & 0x07) * 4;

        write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
    }

    // Store a grid of MMA register quads. For each quad: x/z go to the two
    // 8-row halves at the current column, then y/w at the XOR-ed column.
    template<int M, int N>
    inline __device__ void store(const uint4 (&regs)[M][N]) {
        static_assert(COLS == Cta_tile::N);
        for( int mi = 0; mi < M; mi++ ) {
            for( int ni = 0; ni < N; ni++ ) {
                size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
                fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
                fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
                offset ^= 4 * BYTES_PER_STS;
                fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
                fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
            }
        }
    }

    uint32_t smem_;
    uint32_t write_offset_;
    // NOTE(review): warp_m / warp_n / lane are never assigned or read in this
    // struct as visible here — possibly dead members; confirm before removal.
    uint32_t warp_m;
    uint32_t warp_n;
    uint32_t lane;
};
// Reads the Smem_tile_mma staging tile back TRANSPOSED with LDSM.T, producing
// Fragment_a<Col> fragments (used as the A operand of the second GEMM).
template<typename Cta_tile, typename Base = Smem_tile_mma<Cta_tile>>
struct Smem_tile_mma_transposed : public Base {
    enum { BYTES_PER_LDS = 16 };
    enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };
    enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
    enum { WARPS_M = Base::WARPS_M };
    enum { WARPS_N = Base::WARPS_N };
    // Only the 1 x {4,8} warp configuration is supported.
    static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
    using Fragment = typename Base::Fragment;

    // Ctor. Computes this thread's swizzled read position.
    inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) {
        static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
        int read_row, read_col;
        read_row = (tidx & 0x0f);
        read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;

        // Match the XOR swizzle applied by the writer.
        read_col ^= (read_row & 0x07);
        read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
    }

    // Load an M x N grid of fragments with LDSM.T (transposing ldmatrix).
    template<int M, int N>
    inline __device__ void load(Fragment (&frag)[M][N]) {
        static_assert(Base::COLS == Cta_tile::N);
        for( int mi = 0; mi < M; mi++ ) {
            for( int ni = 0; ni < N; ni++ ) {
                size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
                uint4 dst;
                fmha::ldsmt(dst, this->smem_ + offset);
                frag[mi][ni].reg(0) = dst.x;
                frag[mi][ni].reg(1) = dst.z;  // Fragment A regs col major!
                frag[mi][ni].reg(2) = dst.y;
                frag[mi][ni].reg(3) = dst.w;
            }
        }
    }

    uint32_t read_offset_;
};
// Reads the Smem_tile_mma staging tile back ROW-WISE with plain 16B LDS
// (for the epilogue / gmem store). Also provides store() overloads that
// write accumulators into the tile, converting fp32 -> fp16 on the fly.
template<typename Cta_tile, typename Base = Smem_tile_mma<Cta_tile>>
struct Smem_tile_mma_epilogue : public Base {
    enum { BYTES_PER_LDS = 16 };
    enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };
    enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
    enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS };
    static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW);
    enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
    enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS };
    static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M);
    enum { WARPS_M = Base::WARPS_M };
    enum { WARPS_N = Base::WARPS_N };
    // NOTE(review): this condition mixes WARPS_M and WARPS_N in one clause
    // ((WARPS_M == 4 || WARPS_N == 8)); it may have been intended as
    // (WARPS_M == 4 || WARPS_M == 8) — TODO confirm against upstream.
    static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);

    using Acc = fmha::Fragment_accumulator;

    // Ctor. Computes this thread's swizzled read position.
    inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) {
        const int read_row = tidx / THREADS_PER_ROW;
        int read_col = tidx % THREADS_PER_ROW;
        // Match the XOR swizzle applied by the writer.
        read_col ^= (read_row & 0x07);
        read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
    }

    // Load NUM_LDS 16B vectors, one per ROWS_PER_LDS stride.
    inline __device__ void load(uint4 (&data)[NUM_LDS]) {
        for( int ii = 0; ii < NUM_LDS; ii++ ) {
            size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;
            fmha::lds(data[ii], this->smem_ + offset);
        }
    }

    // Store fp32 accumulators: pack element pairs into fp16x2 words and
    // scatter them with the same x/z-then-y/w, XOR-swizzled pattern as the
    // base tile's store().
    template<int M, int N>
    inline __device__ void store(const Acc (&acc)[M][N]) {
        #pragma unroll
        for( int mi = 0; mi < M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < N; ni++ ) {
                // 1st row - 4 elements per row.
                float tmp00 = acc[mi][ni].elt(0);
                float tmp01 = acc[mi][ni].elt(1);
                float tmp02 = acc[mi][ni].elt(4);
                float tmp03 = acc[mi][ni].elt(5);
                // 2nd row - 4 elements per row.
                float tmp10 = acc[mi][ni].elt(2);
                float tmp11 = acc[mi][ni].elt(3);
                float tmp12 = acc[mi][ni].elt(6);
                float tmp13 = acc[mi][ni].elt(7);

                // Convert to half2 packed words.
                uint32_t x = fmha::float2_to_half2(tmp00, tmp01);
                uint32_t y = fmha::float2_to_half2(tmp02, tmp03);
                uint32_t z = fmha::float2_to_half2(tmp10, tmp11);
                uint32_t w = fmha::float2_to_half2(tmp12, tmp13);

                size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
                fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x);
                fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z);
                offset ^= 4 * Base::BYTES_PER_STS;
                fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y);
                fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);
            }
        }
    }

    // Store pre-packed register quads (same layout as the fp32 overload,
    // without the conversion).
    template<int M, int N>
    inline __device__ void store(const uint4 (&regs)[M][N]) {
        for( int mi = 0; mi < M; mi++ ) {
            for( int ni = 0; ni < N; ni++ ) {
                size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
                fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
                fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
                offset ^= 4 * Base::BYTES_PER_STS;
                fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
                fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
            }
        }
    }

    uint32_t read_offset_;
};
}
// namespace fmha
---- New file in this commit: apex/contrib/csrc/fmha/src/fmha/softmax.h (created, mode 0 → 100644, commit f79993d9) ----
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
namespace
fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

// Reduction functor: addition. IS_SUM lets callers pick a sum-specific
// code path at compile time.
struct Sum_ {
    enum { IS_SUM = 1 };
    // Combine two partial values by adding them.
    static inline __device__ float apply(float lhs, float rhs) {
        return lhs + rhs;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

// Reduction functor: maximum. IS_SUM == 0 marks it as the non-sum path.
struct Max_ {
    enum { IS_SUM = 0 };
    // Combine two partial values by keeping the larger one. Written as the
    // exact equivalent of the original ternary (x > y ? x : y), so the
    // second operand wins on ties and when the first compares unordered.
    static inline __device__ float apply(float x, float y) {
        if( x > y ) {
            return x;
        }
        return y;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

// Numerically-stabilized exponential: exp(x - max) via the fast __expf
// intrinsic. Subtracting the row maximum keeps the argument <= 0 so the
// exponential cannot overflow.
inline __device__ float apply_exp_(float x, float max) {
    const float shifted = x - max;
    return __expf(shifted);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Cta_tile
,
typename
Kernel_traits
>
struct
Softmax_base
{
    // The Mma tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // The number of MMAs in M/N dimensions.
    enum { MMAS_M = Mma_tile::MMAS_M };
    enum { MMAS_N = Mma_tile::MMAS_N };

    // The number of groups of warp such that we have at most 4 warps writing consecutive elements.
    enum { GROUPS = fmha::Div_up<Cta_tile::WARPS_N, 4>::VALUE };
    // The number of elements that we are going to store per row.
    enum { ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS };
    // The number of rows (each group gets its own copy of the M rows).
    enum { ROWS = Cta_tile::M * GROUPS };
    // The total number of elements.
    enum { ELEMENTS = ROWS * ELEMENTS_PER_ROW };
    // Ctor.
    //
    // Decomposes the thread index into warp/lane and the warp index into
    // (warp_m, warp_n), then splits warp_n into a group index and a
    // position inside the group, from which the per-thread smem write and
    // read pointers into the reduction scratchpad are derived.
    template<typename Params>
    inline __device__ Softmax_base(const Params &params, void *smem, int bidb, int tidx)
        :  // packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr)),
          smem_(reinterpret_cast<float *>(smem)), tidx_(tidx) {

        // Move to the 1st mask loaded by the thread+ tidx;
        // packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t);

        // Extract the position in the warp.
        int warp = tidx / Cta_tile::THREADS_PER_WARP;
        int lane = tidx % Cta_tile::THREADS_PER_WARP;

        // Decompose the warp index into M and N.
        int warp_m = warp % Cta_tile::WARPS_M;
        int warp_n = warp / Cta_tile::WARPS_M;

        // Decompose the warp-n index into group/position-inside-the-group.
        int warp_g = warp_n / ELEMENTS_PER_ROW;
        int warp_i = warp_n % ELEMENTS_PER_ROW;

        // The location written by the threads.
        int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4;
        int write_col = warp_i;

        // Assemble the write pointer.
        smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col];

        // Assemble the read pointer.
        smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];
    }
template
<
typename
Mask
>
inline
__device__
void
apply_mask
(
const
Mask
&
mask
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
++
mi
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
2
;
++
ii
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
MMAS_N
;
++
ni
)
{
#pragma unroll
for
(
int
jj
=
0
;
jj
<
4
;
++
jj
)
{
if
(
!
mask
.
is_valid
(
mi
,
ni
,
ii
,
jj
)
)
{
elt_
[
2
*
mi
+
ii
][
4
*
ni
+
jj
]
=
-
INFINITY
;
}
}
}
}
}
}
// Apply the exp to all the elements.
inline
__device__
void
apply_exp
(
const
float
(
&
max
)[
MMAS_M
*
2
])
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
*
2
;
++
mi
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
MMAS_N
*
4
;
++
ni
)
{
elt_
[
mi
][
ni
]
=
apply_exp_
(
elt_
[
mi
][
ni
],
max
[
mi
]);
}
}
}
// Do a CTA-wide reduction.
template
<
typename
Functor
>
inline
__device__
void
reduce_1x4
(
float
(
&
dst
)[
MMAS_M
*
2
])
{
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
if
(
Functor
::
IS_SUM
)
{
// Apply the summation inside the thread.
float
tmp
[
MMAS_M
*
2
][
2
];
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
*
2
;
++
mi
)
{
tmp
[
mi
][
0
]
=
0.
f
;
tmp
[
mi
][
1
]
=
0.
f
;
#pragma unroll
for
(
int
ni
=
0
;
ni
<
MMAS_N
;
++
ni
)
{
tmp
[
mi
][
0
]
+=
elt_
[
mi
][
4
*
ni
+
0
];
tmp
[
mi
][
0
]
+=
elt_
[
mi
][
4
*
ni
+
1
];
tmp
[
mi
][
1
]
+=
elt_
[
mi
][
4
*
ni
+
2
];
tmp
[
mi
][
1
]
+=
elt_
[
mi
][
4
*
ni
+
3
];
}
dst
[
mi
]
=
tmp
[
mi
][
0
]
+
tmp
[
mi
][
1
];
}
}
else
#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
{
// Apply the functor for each row inside a thread.
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
*
2
;
++
mi
)
{
dst
[
mi
]
=
elt_
[
mi
][
0
];
#pragma unroll
for
(
int
ni
=
1
;
ni
<
MMAS_N
*
4
;
++
ni
)
{
dst
[
mi
]
=
Functor
::
apply
(
dst
[
mi
],
elt_
[
mi
][
ni
]);
}
}
}
// Apply the functor for each row inside each group of 4 threads.
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
*
2
;
++
mi
)
{
dst
[
mi
]
=
Functor
::
apply
(
dst
[
mi
],
__shfl_xor_sync
(
uint32_t
(
-
1
),
dst
[
mi
],
1
));
__syncwarp
();
dst
[
mi
]
=
Functor
::
apply
(
dst
[
mi
],
__shfl_xor_sync
(
uint32_t
(
-
1
),
dst
[
mi
],
2
));
__syncwarp
();
}
// Store the different values.
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
++
mi
)
{
if
(
tidx_
%
4
==
0
)
{
smem_write_
[(
mi
*
Mma_tile
::
M_PER_MMA_PER_CTA
+
0
)
*
ELEMENTS_PER_ROW
]
=
dst
[
2
*
mi
+
0
];
smem_write_
[(
mi
*
Mma_tile
::
M_PER_MMA_PER_CTA
+
8
)
*
ELEMENTS_PER_ROW
]
=
dst
[
2
*
mi
+
1
];
}
}
// Make sure the values are in shared memory.
__syncthreads
();
// Load 8 values (one for each warp). The /8 corresponds to /(4*2) where 4 is from the
// float4.
float4
tmp
[
1
];
if
(
tidx_
<
Cta_tile
::
M
)
{
tmp
[
0
]
=
reinterpret_cast
<
const
float4
*>
(
&
smem_
[
0
*
ELEMENTS
/
2
])[
tidx_
];
}
// Compute the reduction of those 8 values in a binary-tree fashion.
tmp
[
0
].
x
=
Functor
::
apply
(
tmp
[
0
].
x
,
tmp
[
0
].
y
);
tmp
[
0
].
z
=
Functor
::
apply
(
tmp
[
0
].
z
,
tmp
[
0
].
w
);
tmp
[
0
].
x
=
Functor
::
apply
(
tmp
[
0
].
x
,
tmp
[
0
].
z
);
// Make sure we can write to shared memory.
__syncthreads
();
// Store the value back to shared memory.
if
(
tidx_
<
Cta_tile
::
M
)
{
smem_
[
tidx_
]
=
tmp
[
0
].
x
;
}
// Make sure the data is in shared memory.
__syncthreads
();
// Finally read the values.
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
++
mi
)
{
dst
[
2
*
mi
+
0
]
=
smem_read_
[
mi
*
Mma_tile
::
M_PER_MMA_PER_CTA
+
0
];
dst
[
2
*
mi
+
1
]
=
smem_read_
[
mi
*
Mma_tile
::
M_PER_MMA_PER_CTA
+
8
];
}
}
// Do a CTA-wide reduction.
template
<
typename
Functor
>
inline
__device__
void
reduce_1x8
(
float
(
&
dst
)[
MMAS_M
*
2
])
{
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
if
(
Functor
::
IS_SUM
)
{
// Apply the summation inside the thread.
float
tmp
[
MMAS_M
*
2
][
2
];
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
*
2
;
++
mi
)
{
tmp
[
mi
][
0
]
=
0.
f
;
tmp
[
mi
][
1
]
=
0.
f
;
#pragma unroll
for
(
int
ni
=
0
;
ni
<
MMAS_N
;
++
ni
)
{
tmp
[
mi
][
0
]
+=
elt_
[
mi
][
4
*
ni
+
0
];
tmp
[
mi
][
0
]
+=
elt_
[
mi
][
4
*
ni
+
1
];
tmp
[
mi
][
1
]
+=
elt_
[
mi
][
4
*
ni
+
2
];
tmp
[
mi
][
1
]
+=
elt_
[
mi
][
4
*
ni
+
3
];
}
dst
[
mi
]
=
tmp
[
mi
][
0
]
+
tmp
[
mi
][
1
];
}
}
else
#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
{
// Apply the functor for each row inside a thread.
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
*
2
;
++
mi
)
{
dst
[
mi
]
=
elt_
[
mi
][
0
];
#pragma unroll
for
(
int
ni
=
1
;
ni
<
MMAS_N
*
4
;
++
ni
)
{
dst
[
mi
]
=
Functor
::
apply
(
dst
[
mi
],
elt_
[
mi
][
ni
]);
}
}
}
// Apply the functor for each row inside each group of 4 threads.
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
*
2
;
++
mi
)
{
dst
[
mi
]
=
Functor
::
apply
(
dst
[
mi
],
__shfl_xor_sync
(
uint32_t
(
-
1
),
dst
[
mi
],
1
));
__syncwarp
();
dst
[
mi
]
=
Functor
::
apply
(
dst
[
mi
],
__shfl_xor_sync
(
uint32_t
(
-
1
),
dst
[
mi
],
2
));
__syncwarp
();
}
// Store the different values.
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
++
mi
)
{
if
(
tidx_
%
4
==
0
)
{
smem_write_
[(
mi
*
Mma_tile
::
M_PER_MMA_PER_CTA
+
0
)
*
ELEMENTS_PER_ROW
]
=
dst
[
2
*
mi
+
0
];
smem_write_
[(
mi
*
Mma_tile
::
M_PER_MMA_PER_CTA
+
8
)
*
ELEMENTS_PER_ROW
]
=
dst
[
2
*
mi
+
1
];
}
}
// Make sure the values are in shared memory.
__syncthreads
();
// Load 8 values (one for each warp). The /8 corresponds to /(4*2) where 4 is from the
// float4.
float4
tmp
[
2
];
if
(
tidx_
<
Cta_tile
::
M
)
{
tmp
[
0
]
=
reinterpret_cast
<
const
float4
*>
(
&
smem_
[
0
*
ELEMENTS
/
2
])[
tidx_
];
tmp
[
1
]
=
reinterpret_cast
<
const
float4
*>
(
&
smem_
[
1
*
ELEMENTS
/
2
])[
tidx_
];
}
// Compute the reduction of those 8 values in a binary-tree fashion.
tmp
[
0
].
x
=
Functor
::
apply
(
tmp
[
0
].
x
,
tmp
[
0
].
y
);
tmp
[
0
].
z
=
Functor
::
apply
(
tmp
[
0
].
z
,
tmp
[
0
].
w
);
tmp
[
1
].
x
=
Functor
::
apply
(
tmp
[
1
].
x
,
tmp
[
1
].
y
);
tmp
[
1
].
z
=
Functor
::
apply
(
tmp
[
1
].
z
,
tmp
[
1
].
w
);
tmp
[
0
].
x
=
Functor
::
apply
(
tmp
[
0
].
x
,
tmp
[
0
].
z
);
tmp
[
1
].
x
=
Functor
::
apply
(
tmp
[
1
].
x
,
tmp
[
1
].
z
);
tmp
[
0
].
x
=
Functor
::
apply
(
tmp
[
0
].
x
,
tmp
[
1
].
x
);
// Make sure we can write to shared memory.
__syncthreads
();
// Store the value back to shared memory.
if
(
tidx_
<
Cta_tile
::
M
)
{
smem_
[
tidx_
]
=
tmp
[
0
].
x
;
}
// Make sure the data is in shared memory.
__syncthreads
();
// Finally read the values.
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
++
mi
)
{
dst
[
2
*
mi
+
0
]
=
smem_read_
[
mi
*
Mma_tile
::
M_PER_MMA_PER_CTA
+
0
];
dst
[
2
*
mi
+
1
]
=
smem_read_
[
mi
*
Mma_tile
::
M_PER_MMA_PER_CTA
+
8
];
}
}
// Do a CTA-wide reduction.
template
<
typename
Functor
>
inline
__device__
void
reduce
(
float
(
&
dst
)[
MMAS_M
*
2
])
{
static_assert
(
Cta_tile
::
WARPS_M
==
1
&&
(
Cta_tile
::
WARPS_N
==
4
||
Cta_tile
::
WARPS_N
==
8
));
if
(
Cta_tile
::
WARPS_M
==
1
&&
Cta_tile
::
WARPS_N
==
4
)
{
reduce_1x4
<
Functor
>
(
dst
);
}
else
if
(
Cta_tile
::
WARPS_M
==
1
&&
Cta_tile
::
WARPS_N
==
8
)
{
reduce_1x8
<
Functor
>
(
dst
);
}
else
{
assert
(
false
);
}
// Make sure we are done reading from shared memory.
__syncthreads
();
}
// Scale all the elements.
inline
__device__
void
scale
(
const
float
(
&
sum
)[
MMAS_M
*
2
])
{
// Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal.
float
inv_sum
[
MMAS_M
*
2
];
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
*
2
;
++
mi
)
{
inv_sum
[
mi
]
=
(
sum
[
mi
]
==
0.
f
||
sum
[
mi
]
!=
sum
[
mi
])
?
1.
f
:
1.
f
/
sum
[
mi
];
}
// Update the values.
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
*
2
;
++
mi
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
MMAS_N
*
4
;
++
ni
)
{
elt_
[
mi
][
ni
]
*=
inv_sum
[
mi
];
}
}
}
// The pointer to the mask.
const
char
*
packed_mask_ptr_
;
// Shared memory for the CTA-wide reduction.
float
*
smem_
,
*
smem_write_
,
*
smem_read_
;
// The current thread index.
int
tidx_
;
// The elements.
float
elt_
[
MMAS_M
*
2
][
MMAS_N
*
4
];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Softmax tile that adds fp16 packing/unpacking on top of Softmax_base:
// unpack() brings BMM1 accumulators into elt_ (scaled), pack()/store() convert
// elt_ back to half2 fragments for BMM2 or for writing to global memory.
template<typename Cta_tile, typename Kernel_traits>
struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {

    // The base class.
    using Base = Softmax_base<Cta_tile, Kernel_traits>;
    // The fragment.
    using Fragment_a = fmha::Fragment_a<fmha::Row>;

    static_assert(Fragment_a::NUM_REGS == 4);

    // The MMAs.
    enum { MMAS_M = Base::MMAS_M };
    enum { MMAS_N = Base::MMAS_N };

    // The accumulators.
    using Accumulator = fmha::Fragment_accumulator;
    using Accumulator_out = Fragment<uint16_t, 8>;
    static_assert(Accumulator_out::NUM_REGS == 4);

    static_assert(std::is_same<Accumulator::Data_type, float>::value);

    // Ctor: forwards to the base and captures the BMM1 scale factor
    // (a float bit-pattern carried as uint32_t, see unpack()).
    template<typename Params>
    inline __device__ Softmax(const Params &params, void *smem, int bidb, int tidx)
        : Base(params, smem, bidb, tidx), params_scale_bmm1_(params.scale_bmm1) {
    }

    // Store the tile after softmax: convert elt_ to half2 accumulators and
    // delegate the global-memory write to the gmem tile.
    template<typename Gmem_tile>
    inline __device__ void store(Gmem_tile &gmem_tile) {
        Accumulator_out acc[MMAS_M][MMAS_N];
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            #pragma unroll
            for( int ni = 0; ni < MMAS_N; ++ni ) {

                // The elements.
                float tmp_00 = this->elt_[2 * mi + 0][4 * ni + 0];
                float tmp_01 = this->elt_[2 * mi + 0][4 * ni + 1];
                float tmp_02 = this->elt_[2 * mi + 0][4 * ni + 2];
                float tmp_03 = this->elt_[2 * mi + 0][4 * ni + 3];
                float tmp_10 = this->elt_[2 * mi + 1][4 * ni + 0];
                float tmp_11 = this->elt_[2 * mi + 1][4 * ni + 1];
                float tmp_12 = this->elt_[2 * mi + 1][4 * ni + 2];
                float tmp_13 = this->elt_[2 * mi + 1][4 * ni + 3];

                // Transform to accumulators. Register order follows the HMMA
                // accumulator layout: regs 0/1 hold cols 0-1 of rows 0/1,
                // regs 2/3 hold cols 2-3.
                acc[mi][ni].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
                acc[mi][ni].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
                acc[mi][ni].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
                acc[mi][ni].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
            }
        }

        // Delegate to the gmem tile to store.
        gmem_tile.store(acc);
    }

    // Pack the data to a fragment for the next GEMM. Same register layout as
    // store(), but indexed by (ki, mi) to serve as the A operand of BMM2.
    template<int K, int M>
    inline __device__ void pack(Fragment_a (&dst)[K][M]) const {
        #pragma unroll
        for( int mi = 0; mi < M; ++mi ) {
            #pragma unroll
            for( int ki = 0; ki < K; ++ki ) {

                // 1st row - 4 elements per row.
                float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];
                float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];
                float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2];
                float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3];

                // 2nd row - 4 elements per row.
                float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0];
                float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1];
                float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2];
                float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];

                // Pack to 4 registers.
                dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
                dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
                dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
                dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
            }
        }
    }

    // Scale FP32 fragments: copy the BMM1 accumulators into elt_, multiplying by
    // the softmax scale. Accumulator element order (0,1,4,5 / 2,3,6,7) reflects
    // the HMMA fp32 accumulator layout.
    inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
        // params_scale_bmm1_ carries the float's bit pattern in a uint32_t.
        const float scalef = reinterpret_cast<const float &>(this->params_scale_bmm1_);

        #pragma unroll
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            #pragma unroll
            for( int ni = 0; ni < MMAS_N; ++ni ) {
                // 1st row - 4 elements per row.
                this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef;
                this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef;
                this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef;
                this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef;
                // 2nd row - 4 elements per row.
                this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef;
                this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef;
                this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef;
                this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef;
            }
        }
    }

    // Softmax scale for BMM1, stored as the bit pattern of a float.
    const uint32_t params_scale_bmm1_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace fmha
apex/contrib/csrc/fmha/src/fmha/utils.h
0 → 100644
View file @
f79993d9
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
// Toolchain-provided intrinsic: converts a generic pointer into a 32-bit
// shared-memory address usable by PTX instructions. Declared here because it
// is not exposed by the public CUDA headers.
extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr);
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace
fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
// Empty tag types used to select a fragment layout (row- vs column-major) at compile time.
struct Row {};
struct Col {};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Compile-time smallest power of two >= M, exposed as Next_power_of_two<M>::VALUE.
//
// The second (defaulted) parameter is true when M itself is a power of two.
// The original code enumerated one specialization per supported non-power-of-two
// value (3, 5..15, 24, 48, 80, 96, 112, 144); this version computes the result
// for ANY positive M with the identity next_pow2(M) = 2 * next_pow2(ceil(M / 2)),
// which reproduces every previously listed value and removes the need to add a
// specialization for each new tile size.
template<int M, bool = (M & (M - 1)) == 0>
struct Next_power_of_two {
};

// M is already a power of two: the answer is M itself.
template<int M>
struct Next_power_of_two<M, true> {
    enum { VALUE = M };
};

// M is not a power of two: recurse on ceil(M / 2). The recursion strictly
// decreases M and terminates at a power-of-two (base case above).
template<int M>
struct Next_power_of_two<M, false> {
    enum { VALUE = 2 * Next_power_of_two<(M + 1) / 2>::VALUE };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Compile-time largest power of two <= N, exposed as Prev_power_of_two<N>::VALUE.
//
// The second (defaulted) parameter is true when N itself is a power of two.
// The original code only had specializations for 3, 5, 6 and 7; this version
// computes the result for ANY positive N with the identity
// prev_pow2(N) = 2 * prev_pow2(N / 2), which reproduces those values
// (3->2, 5->4, 6->4, 7->4) and generalizes to every other input.
template<int N, bool = (N & (N - 1)) == 0>
struct Prev_power_of_two {
};

// N is already a power of two: the answer is N itself.
template<int N>
struct Prev_power_of_two<N, true> {
    enum { VALUE = N };
};

// N is not a power of two: recurse on floor(N / 2). The recursion strictly
// decreases N and terminates at a power of two (base case above).
template<int N>
struct Prev_power_of_two<N, false> {
    enum { VALUE = 2 * Prev_power_of_two<N / 2>::VALUE };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
M
,
int
N
>
struct
Div_up
{
enum
{
VALUE
=
(
M
+
N
-
1
)
/
N
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
A
,
int
B
>
struct
Max
{
enum
{
VALUE
=
A
>=
B
?
A
:
B
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
A
,
int
B
,
int
C
>
struct
Max_3
{
enum
{
VALUE
=
Max
<
Max
<
A
,
B
>::
VALUE
,
C
>::
VALUE
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
A
,
int
B
>
struct
Min
{
enum
{
VALUE
=
A
<=
B
?
A
:
B
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Maps a byte count to an unsigned type of exactly that size, for typed
// loads/stores of packed data. Unsupported sizes fail to compile (empty primary).
template<int SIZE_IN_BYTES>
struct Uint_from_size_in_bytes {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// 1 byte.
template<>
struct Uint_from_size_in_bytes<1> {
    using Type = uint8_t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// 2 bytes.
template<>
struct Uint_from_size_in_bytes<2> {
    using Type = uint16_t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// 4 bytes.
template<>
struct Uint_from_size_in_bytes<4> {
    using Type = uint32_t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// 8 bytes (CUDA built-in vector type).
template<>
struct Uint_from_size_in_bytes<8> {
    using Type = uint2;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// 16 bytes (CUDA built-in vector type).
template<>
struct Uint_from_size_in_bytes<16> {
    using Type = uint4;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
WARPS_M
,
int
WARPS_N
,
int
WARPS_K
>
struct
Warp_masks
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
struct
Warp_masks
<
8
,
1
,
1
>
{
enum
{
M
=
0xe0
,
N
=
0x00
,
K
=
0x00
};
};
template
<
>
struct
Warp_masks
<
4
,
2
,
1
>
{
enum
{
M
=
0x60
,
N
=
0x80
,
K
=
0x00
};
};
template
<
>
struct
Warp_masks
<
4
,
1
,
2
>
{
enum
{
M
=
0x60
,
N
=
0x00
,
K
=
0x80
};
};
template
<
>
struct
Warp_masks
<
4
,
1
,
1
>
{
enum
{
M
=
0x60
,
N
=
0x00
,
K
=
0x00
};
};
template
<
>
struct
Warp_masks
<
2
,
4
,
1
>
{
enum
{
M
=
0x20
,
N
=
0xc0
,
K
=
0x00
};
};
template
<
>
struct
Warp_masks
<
2
,
2
,
2
>
{
enum
{
M
=
0x20
,
N
=
0x40
,
K
=
0x80
};
};
template
<
>
struct
Warp_masks
<
2
,
2
,
1
>
{
enum
{
M
=
0x20
,
N
=
0x40
,
K
=
0x00
};
};
template
<
>
struct
Warp_masks
<
2
,
1
,
2
>
{
enum
{
M
=
0x20
,
N
=
0x00
,
K
=
0x40
};
};
template
<
>
struct
Warp_masks
<
2
,
1
,
1
>
{
enum
{
M
=
0x20
,
N
=
0x00
,
K
=
0x00
};
};
template
<
>
struct
Warp_masks
<
1
,
8
,
1
>
{
enum
{
M
=
0x00
,
N
=
0xe0
,
K
=
0x00
};
};
template
<
>
struct
Warp_masks
<
1
,
4
,
2
>
{
enum
{
M
=
0x00
,
N
=
0x60
,
K
=
0x80
};
};
template
<
>
struct
Warp_masks
<
1
,
4
,
1
>
{
enum
{
M
=
0x00
,
N
=
0x60
,
K
=
0x00
};
};
template
<
>
struct
Warp_masks
<
1
,
2
,
2
>
{
enum
{
M
=
0x00
,
N
=
0x20
,
K
=
0x40
};
};
template
<
>
struct
Warp_masks
<
1
,
2
,
1
>
{
enum
{
M
=
0x00
,
N
=
0x20
,
K
=
0x00
};
};
template
<
>
struct
Warp_masks
<
1
,
1
,
4
>
{
enum
{
M
=
0x00
,
N
=
0x00
,
K
=
0x60
};
};
template
<
>
struct
Warp_masks
<
1
,
1
,
2
>
{
enum
{
M
=
0x00
,
N
=
0x00
,
K
=
0x20
};
};
template
<
>
struct
Warp_masks
<
1
,
1
,
1
>
{
enum
{
M
=
0x00
,
N
=
0x00
,
K
=
0x00
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Runtime ceiling division: returns ceil(m / n) for positive n.
// Callable from both host and device code.
template<typename T>
inline __device__ __host__ T div_up(T m, T n) {
    const T biased = m + n - 1;
    return biased / n;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Count the leading zero bits of the 32-bit value x; returns 32 when x == 0.
inline int clz(int x) {
    int leading = 0;
    // Scan from the most-significant bit downward; stop at the first set bit.
    for( unsigned bit = 0x80000000u; bit != 0u; bit >>= 1, ++leading ) {
        if( x & bit ) {
            return leading;
        }
    }
    return 32;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Integer log2 of x (> 0): floor(log2(x)) by default, ceil(log2(x)) when
// round_up is true.
inline int find_log_2(int x, bool round_up = false) {
    // Position of the highest set bit == floor(log2(x)).
    int log2_x = 31 - clz(x);
    // When rounding up, bump the result unless x is an exact power of two.
    const bool is_pow2 = (x & (x - 1)) == 0;
    if( round_up && !is_pow2 ) {
        ++log2_x;
    }
    return log2_x;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Packed half2 add: element-wise fp16 addition of two half2 values held in uint32_t.
static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) {
    uint32_t c;
    asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Packed half2 min: element-wise fp16 minimum of two half2 values.
static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {
    uint32_t c;
    asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b));
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Packed half2 multiply: element-wise fp16 product of two half2 values.
static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) {
    uint32_t c;
    asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
    return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Element-wise fp16 multiply of 4 halves packed in a uint2.
static inline __device__ uint2 hmul4(uint2 a, uint2 b) {
    uint2 c;
    c.x = hmul2(a.x, b.x);
    c.y = hmul2(a.y, b.y);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Element-wise fp16 multiply of 8 halves packed in a uint4.
static inline __device__ uint4 hmul8(uint4 a, uint4 b) {
    uint4 c;
    c.x = hmul2(a.x, b.x);
    c.y = hmul2(a.y, b.y);
    c.z = hmul2(a.z, b.z);
    c.w = hmul2(a.w, b.w);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Multiply 8 packed halves by the half2 `a` broadcast across all four lanes.
static inline __device__ uint4 hmul8(uint32_t a, uint4 b) {
    uint4 c;
    c.x = hmul2(a, b.x);
    c.y = hmul2(a, b.y);
    c.z = hmul2(a, b.z);
    c.w = hmul2(a, b.w);
    return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Packed half2 ReLU: element-wise max(x, lb) with lb defaulting to 0.
// On SM80+ this maps to a single max.f16x2; on older architectures it is
// emulated with a set/and sequence (which ignores `lb` and clamps at 0).
static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {
    uint32_t res;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(lb));
#else
    const uint32_t zero = 0u;
    asm volatile( \
        "{\n" \
        "\t .reg .f16x2 sela;\n" \
        "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
        "\t and.b32 %0, sela, %1;\n"
        "}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
    return res;
}

// Packed half2 absolute value.
static inline __device__ uint32_t habs2(uint32_t x) {
    uint32_t res;
    asm volatile("abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x));
    return res;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Clamp x into [lb, ub].
template<typename T>
static inline __device__ T clamp(T x, T lb, T ub) {
    return x < lb ? lb : (x > ub ? ub : x);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Interpret x as an fp16 value and zero it when it is not strictly greater
// than 0 (set.gtu produces an all-ones/all-zeros mask that is ANDed with x).
static inline __device__ uint16_t clamp_to_zero(uint16_t x) {
    uint16_t mask;
    asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x));
    return mask & x;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert a float to fp16 (round-to-nearest-even), returned as raw bits.
static inline __device__ uint16_t float_to_half(float f) {
    uint16_t h;
    asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f));
    return h;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Pack two floats into a half2: `a` goes to the low half, `b` to the high half.
// SM80+ uses the fused cvt.rn.f16x2.f32 (note the operand order: high first);
// older architectures convert individually and combine with mov.b32.
static inline __device__ uint32_t float2_to_half2(float a, float b) {
    uint32_t c;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a));
#else
    uint16_t lo = float_to_half(a);
    uint16_t hi = float_to_half(b);
    asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi));
#endif
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Broadcast one float to both halves of a half2.
static inline __device__ uint32_t float_to_half2(float a) {
    return float2_to_half2(a, a);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Pack a float2 into a half2.
static inline __device__ uint32_t float2_to_half2(const float2 &f) {
    return float2_to_half2(f.x, f.y);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Pack four floats into four halves (two half2 words in a uint2).
static inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) {
    uint2 d;
    d.x = float2_to_half2(x, y);
    d.y = float2_to_half2(z, w);
    return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Packed half2 fused multiply-add: d = a * b + c, element-wise.
static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) {
    uint32_t d;
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Packed half2 FMA followed by ReLU. SM80+ has a fused instruction; older
// architectures compose hfma2 and hrelu2.
static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) {
    uint32_t d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
    d = hrelu2(hfma2(a, b, c));
#endif
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Broadcast the LOW half of x into both halves of the result.
static inline __device__ uint32_t h0_h0(uint32_t x) {
    uint32_t y;
    asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n"
                 : "=r"(y)
                 : "r"(x));
    return y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert the LOW half of the half2 h2 to a float.
static inline __device__ float h0_to_float(uint32_t h2) {
    float f;
    asm volatile("{\n" \
                 ".reg .f16 lo, hi;\n" \
                 "mov.b32 {lo, hi}, %1;\n" \
                 "cvt.f32.f16 %0, lo;\n" \
                 "}\n"
                 : "=f"(f)
                 : "r"(h2));
    return f;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Broadcast the HIGH half of x into both halves of the result.
static inline __device__ uint32_t h1_h1(uint32_t x) {
    uint32_t y;
    asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n"
                 : "=r"(y)
                 : "r"(x));
    return y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Scalar fp16 add on raw 16-bit values.
static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) {
    uint16_t d;
    asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// hadd overload for a half2 packed in a uint32_t.
static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) {
    return hadd2(a, b);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Element-wise fp16 add of 4 halves packed in a uint2.
static inline __device__ uint2 hadd4(uint2 a, uint2 b) {
    uint2 c;
    c.x = hadd2(a.x, b.x);
    c.y = hadd2(a.y, b.y);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// hadd overload for 4 packed halves.
static inline __device__ uint2 hadd(uint2 a, uint2 b) {
    return hadd4(a, b);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Element-wise fp16 add of 8 halves packed in a uint4.
static inline __device__ uint4 hadd8(uint4 a, uint4 b) {
    uint4 c;
    c.x = hadd2(a.x, b.x);
    c.y = hadd2(a.y, b.y);
    c.z = hadd2(a.z, b.z);
    c.w = hadd2(a.w, b.w);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Element-wise FLOAT add of two uint4 whose lanes hold fp32 bit patterns.
static inline __device__ uint4 fadd4(uint4 a, uint4 b) {
    float4 c;
    c.x = reinterpret_cast<const float &>(a.x) + reinterpret_cast<const float &>(b.x);
    c.y = reinterpret_cast<const float &>(a.y) + reinterpret_cast<const float &>(b.y);
    c.z = reinterpret_cast<const float &>(a.z) + reinterpret_cast<const float &>(b.z);
    c.w = reinterpret_cast<const float &>(a.w) + reinterpret_cast<const float &>(b.w);
    return reinterpret_cast<const uint4 &>(c);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// hadd overload for 8 packed halves.
static inline __device__ uint4 hadd(uint4 a, uint4 b) {
    return hadd8(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert a raw fp16 value to float.
static inline __device__ float half_to_float(uint16_t h) {
    float f;
    asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
    return f;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Unpack a half2 into a float2 (low half -> .x, high half -> .y).
static inline __device__ float2 half2_to_float2(uint32_t x) {
    uint16_t lo, hi;
    asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x));
    return make_float2(half_to_float(lo), half_to_float(hi));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Unpack a half2 into two float references.
static inline __device__ void half2_to_float2(float &x, float &y, uint32_t h) {
    float2 tmp = half2_to_float2(h);
    x = tmp.x;
    y = tmp.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Scalar fp16 fused multiply-add: d = a * b + c.
static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) {
    uint16_t d;
    asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c));
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Scalar fp16 multiply.
static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) {
    uint16_t d;
    asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Logistic sigmoid in fp32: 1 / (1 + exp(-x)).
static inline __device__ float sigmoid(float x) {
    return 1.f / (1.f + expf(-x));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Zero-fill helpers, overloaded on the register-sized types used for fetches.
inline __device__ void clear(uint16_t &dst) {
    dst = uint16_t(0);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void clear(uint32_t &dst) {
    dst = 0u;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void clear(uint2 &dst) {
    dst = make_uint2(0u, 0u);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void clear(uint4 &dst) {
    dst = make_uint4(0u, 0u, 0u, 0u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// P R E D I C A T E P A C K I N G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
enum
{
BYTES_PER_REG
=
4
,
PREDS_PER_BYTE
=
4
,
PREDS_PER_REG
=
BYTES_PER_REG
*
PREDS_PER_BYTE
};
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// G E N E R I C P R E D I C A T E D L D G S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
// Issues N predicated loads through the functor `fct`. The N predicates are packed
// into the `preds` registers, PREDS_PER_BYTE per byte (see the enum above): predicate
// i lives at bit (i/PREDS_PER_BYTE % BYTES_PER_REG)*8 + i%PREDS_PER_BYTE of register
// preds[i / PREDS_PER_REG]. The functor must provide clear(ii) and load(ii, pred).
template< int N, int M, typename Functor >
inline __device__ void load_(Functor &fct, const uint32_t (&preds)[M]) {

    // The number of complete bytes (where we use all the predicates in a byte).
    enum { COMPLETE = N / PREDS_PER_BYTE };
    // Make sure we did allocate enough predicates.
    static_assert(Div_up<COMPLETE, BYTES_PER_REG>::VALUE <= M, "");
    // The remainder.
    enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE };
    // Make sure we got the math right and the remainder is between 0 and 3.
    static_assert(REMAINDER >= 0 && REMAINDER <= 3, "");
    // The mask to extract the predicates.
    enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 };

    // Clear the fetch registers so that elements with a false predicate read as zero.
    #pragma unroll
    for( int ii = 0; ii < N; ++ii ) {
        fct.clear(ii);
    }

    // Run complete steps: one iteration per fully-used byte of predicates.
    bool p[PREDS_PER_BYTE];
    #pragma unroll
    for( int ii = 0; ii < COMPLETE; ++ii ) {

        // The predicate register holding this byte.
        uint32_t reg = preds[ii / BYTES_PER_REG];

        // Extract the predicates from byte ii%BYTES_PER_REG of the register.
        #pragma unroll
        for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
            uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj);
            p[jj] = (reg & mask) != 0u;
        }

        // Issue the loads.
        #pragma unroll
        for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
            fct.load(ii * PREDS_PER_BYTE + jj, p[jj]);
        }
    }

    // Skip the rest of the code if we do not have a remainder. REMAINDER is a
    // compile-time constant so the dead branch is eliminated.
    if( REMAINDER > 0 ) {

        // The mask to extract the predicates.
        enum { REMAINDER_MASK = (1 << REMAINDER) - 1 };

        // The predicate register holding the partial byte.
        uint32_t reg = preds[COMPLETE / BYTES_PER_REG];

        // Extract the predicates.
        #pragma unroll
        for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
            uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj);
            p[jj] = (reg & mask) != 0u;
        }

        // Issue the remaining REMAINDER loads.
        #pragma unroll
        for( int ii = 0; ii < REMAINDER; ++ii ) {
            fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]);
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Scalar-predicate convenience wrapper: wraps the single predicate register in a
// one-element array and forwards to the array overload (here M is the number of
// loads to issue).
template< int M, typename Functor >
inline __device__ void load_(Functor &fct, uint32_t preds) {
    uint32_t pred_regs[1] = { preds };
    load_<M>(fct, pred_regs);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
// Scalar global-memory loads: each overload reads one value of the destination
// type from ptr. static_cast is the idiomatic conversion from const void*.
inline __device__ void ldg(uint8_t &dst, const void *ptr) {
    dst = *static_cast<const uint8_t *>(ptr);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldg(uint16_t &dst, const void *ptr) {
    dst = *static_cast<const uint16_t *>(ptr);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldg(uint32_t &dst, const void *ptr) {
    dst = *static_cast<const uint32_t *>(ptr);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldg(uint2 &dst, const void *ptr) {
    dst = *static_cast<const uint2 *>(ptr);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldg(uint4 &dst, const void *ptr) {
    dst = *static_cast<const uint4 *>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Functor adapter used with load_: holds references to the caller's fetch array
// and pointer array, clearing or loading one element at a time under a predicate.
template< typename Data_type, int N >
struct Ldg_functor {
    // Ctor. Captures references to the caller-owned arrays; the caller must keep
    // them alive for the lifetime of the functor.
    inline __device__ Ldg_functor(Data_type (&fetch)[N], const void* (&ptrs)[N])
        : fetch_(fetch), ptrs_(ptrs) {
    }

    // Clear the element (zero-fill for a false predicate).
    inline __device__ void clear(int ii) {
        fmha::clear(fetch_[ii]);
    }

    // Trigger the load for element ii, only if the predicate p is true.
    inline __device__ void load(int ii, bool p) {
        if( p ) {
            ldg(fetch_[ii], ptrs_[ii]);
        }
    }

    // The fetch registers (reference to the caller's array).
    Data_type (&fetch_)[N];
    // The pointers (reference to the caller's array).
    const void* (&ptrs_)[N];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Predicated batch load: loads fetch[i] from ptrs[i] for every i whose packed
// predicate bit in preds is set; other elements are zero-cleared.
template< typename Data_type, int N, int M >
inline __device__ void ldg_(Data_type (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
    Ldg_functor<Data_type, N> functor(fetch, ptrs);
    load_<N>(functor, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Typed dispatch overloads for predicated batch loads: each simply forwards to
// ldg_ with the matching Data_type.
template< int N, int M >
inline __device__ void ldg(uint8_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
    ldg_<uint8_t, N>(fetch, ptrs, preds);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N, int M >
inline __device__ void ldg(uint16_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
    ldg_<uint16_t, N>(fetch, ptrs, preds);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N, int M >
inline __device__ void ldg(uint32_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
    ldg_<uint32_t, N>(fetch, ptrs, preds);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N, int M >
inline __device__ void ldg(uint2 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
    ldg_<uint2, N>(fetch, ptrs, preds);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N, int M >
inline __device__ void ldg(uint4 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
    ldg_<uint4, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
// Shared-memory loads. `ptr` is a shared-memory address in the 32-bit shared
// state space (as produced by e.g. __cvta_generic_to_shared or equivalent —
// NOTE(review): the conversion happens at the call sites, not here).
inline __device__ void lds(uint16_t &dst, uint32_t ptr) {
    asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void lds(uint32_t &dst, uint32_t ptr) {
    asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// 64-bit vectorized shared-memory load.
inline __device__ void lds(uint2 &dst, uint32_t ptr) {
    asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n"
        : "=r"(dst.x), "=r"(dst.y)
        : "r"(ptr));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// 128-bit vectorized shared-memory load.
inline __device__ void lds(uint4 &dst, uint32_t ptr) {
    asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n"
        : "=r"(dst.x)
        , "=r"(dst.y)
        , "=r"(dst.z)
        , "=r"(dst.w)
        : "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S M
//
////////////////////////////////////////////////////////////////////////////////////////////////////
// ldmatrix wrappers: warp-collective loads of 8x8 b16 matrix fragments from
// shared memory (the `trans` variants load transposed fragments). The PTX
// `ldmatrix` instruction requires sm_75 (Turing) or newer, so the guard is
// `__CUDA_ARCH__ >= 750`; the previous `>= 730` value named a non-existent
// architecture (no arch exists between 720 and 750, so behavior is unchanged
// on real targets). On older architectures these are no-ops: dst is left
// unmodified.
inline __device__ void ldsm(uint32_t &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
        : "=r"(dst) : "r"(ptr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsmt(uint32_t &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n"
        : "=r"(dst) : "r"(ptr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsm(uint2 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n"
        : "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsmt(uint2 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n"
        : "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsm(uint4 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
        : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsmt(uint4 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
        : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
// Scalar global-memory stores: each overload writes one value of the given type
// to ptr. static_cast is the idiomatic conversion from void*.
inline __device__ void stg(void *ptr, uint8_t val) {
    *static_cast<uint8_t *>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void *ptr, uint16_t val) {
    *static_cast<uint16_t *>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void *ptr, uint32_t val) {
    *static_cast<uint32_t *>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void *ptr, uint2 val) {
    *static_cast<uint2 *>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void *ptr, uint4 val) {
    *static_cast<uint4 *>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
// Shared-memory stores. `ptr` is a shared-memory address in the 32-bit shared
// state space (converted at the call sites, same convention as lds above).
inline __device__ void sts(uint32_t ptr, uint16_t val) {
    asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void sts(uint32_t ptr, uint32_t val) {
    asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// 64-bit vectorized shared-memory store.
inline __device__ void sts(uint32_t ptr, uint2 val) {
    asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n"
        : : "r"(ptr)
          , "r"(val.x)
          , "r"(val.y));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// 128-bit vectorized shared-memory store.
inline __device__ void sts(uint32_t ptr, uint4 val) {
    asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n"
        : : "r"(ptr)
          , "r"(val.x)
          , "r"(val.y)
          , "r"(val.z)
          , "r"(val.w));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Batch shared-memory store: writes data[i] to shared address ptrs[i] for all N
// elements, fully unrolled.
template< typename Data_type, int N >
inline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) {
    #pragma unroll
    for( int idx = 0; idx < N; ++idx ) {
        sts(ptrs[idx], data[idx]);
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Typed dispatch overloads for batch shared-memory stores: each forwards to sts_
// with the matching Data_type.
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) {
    sts_<uint16_t, N>(ptrs, data);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) {
    sts_<uint32_t, N>(ptrs, data);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) {
    sts_<uint2, N>(ptrs, data);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {
    sts_<uint4, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace fmha
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu
0 → 100644
View file @
f79993d9
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
// Kernel configuration for the seqlen-128, head-dim-64 dgrad kernel.
// NOTE(review): argument meanings are not visible here — by the kernel name and
// the static_asserts below they appear to be (seq len = 128, head dim = 64,
// step = 16, warps_m = 1, warps_n = 4, flags = 0x08u); confirm against the
// FMHA_kernel_traits declaration in fmha.h.
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
// Backward (dgrad) kernel for seqlen 128 / head dim 64: first computes dV, then
// dQ and dK, both parameterized by Kernel_traits. extern "C" keeps the symbol
// name unmangled.
extern "C" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::compute_dv_1xN<Kernel_traits>(params);
    fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}
// Host-side launcher: computes the dynamic shared-memory requirement as the max
// of the two phases of the kernel (dV, then dQ/dK), opts in to >48KB of dynamic
// shared memory when needed, and launches one CTA per (head, batch) pair.
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {

    // Per-tile shared-memory sizes taken from the kernel traits.
    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

    // Tile used to exchange the transposed S matrix between the two phases.
    using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;
    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
    // Sanity-check the expected sizes for this 128x64 configuration.
    static_assert(smem_size_s == 16 * 128 * 2);
    static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);

    // Shared memory needed by each phase; the launch must accommodate the max.
    // NOTE(review): the 2x factor on smem_size_q presumably reflects double
    // buffering of Q in the dV phase — confirm against fmha_dgrad_kernel_1xN_reload.h.
    constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
    constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
    constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);

    // Above 48KB, dynamic shared memory must be explicitly opted into.
    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(fmha_dgrad_fp16_128_64_sm80_kernel,
                                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                                             smem_size));
    }

    // One CTA per (head, batch) pair.
    dim3 grid(params.h, params.b);
    fmha_dgrad_fp16_128_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment