Commit 91da9643 authored by limm's avatar limm
Browse files

support v2.1.0

parent 6f674c7e
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings import warnings
from functools import partial
from typing import Dict, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union
import torch import torch
...@@ -14,6 +15,56 @@ from .norm import build_norm_layer ...@@ -14,6 +15,56 @@ from .norm import build_norm_layer
from .padding import build_padding_layer from .padding import build_padding_layer
def efficient_conv_bn_eval_forward(bn: _BatchNorm,
conv: nn.modules.conv._ConvNd,
x: torch.Tensor):
"""
Implementation based on https://arxiv.org/abs/2305.11624
"Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
It leverages the associative law between convolution and affine transform,
i.e., normalize (weight conv feature) = (normalize weight) conv feature.
It works for Eval mode of ConvBN blocks during validation, and can be used
for training as well. It reduces memory and computation cost.
Args:
bn (_BatchNorm): a BatchNorm module.
conv (nn._ConvNd): a conv module
x (torch.Tensor): Input feature map.
"""
# These lines of code are designed to deal with various cases
# like bn without affine transform, and conv without bias
weight_on_the_fly = conv.weight
if conv.bias is not None:
bias_on_the_fly = conv.bias
else:
bias_on_the_fly = torch.zeros_like(bn.running_var)
if bn.weight is not None:
bn_weight = bn.weight
else:
bn_weight = torch.ones_like(bn.running_var)
if bn.bias is not None:
bn_bias = bn.bias
else:
bn_bias = torch.zeros_like(bn.running_var)
# shape of [C_out, 1, 1, 1] in Conv2d
weight_coeff = torch.rsqrt(bn.running_var +
bn.eps).reshape([-1] + [1] *
(len(conv.weight.shape) - 1))
# shape of [C_out, 1, 1, 1] in Conv2d
coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff
# shape of [C_out, C_in, k, k] in Conv2d
weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly
# shape of [C_out] in Conv2d
bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() *\
(bias_on_the_fly - bn.running_mean)
return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly)
@MODELS.register_module() @MODELS.register_module()
class ConvModule(nn.Module): class ConvModule(nn.Module):
"""A conv block that bundles conv/norm/activation layers. """A conv block that bundles conv/norm/activation layers.
...@@ -65,6 +116,9 @@ class ConvModule(nn.Module): ...@@ -65,6 +116,9 @@ class ConvModule(nn.Module):
sequence of "conv", "norm" and "act". Common examples are sequence of "conv", "norm" and "act". Common examples are
("conv", "norm", "act") and ("act", "conv", "norm"). ("conv", "norm", "act") and ("act", "conv", "norm").
Default: ('conv', 'norm', 'act'). Default: ('conv', 'norm', 'act').
efficient_conv_bn_eval (bool): Whether use efficient conv when the
consecutive bn is in eval mode (either training or testing), as
proposed in https://arxiv.org/abs/2305.11624 . Default: `False`.
""" """
_abbr_ = 'conv_block' _abbr_ = 'conv_block'
...@@ -84,7 +138,8 @@ class ConvModule(nn.Module): ...@@ -84,7 +138,8 @@ class ConvModule(nn.Module):
inplace: bool = True, inplace: bool = True,
with_spectral_norm: bool = False, with_spectral_norm: bool = False,
padding_mode: str = 'zeros', padding_mode: str = 'zeros',
order: tuple = ('conv', 'norm', 'act')): order: tuple = ('conv', 'norm', 'act'),
efficient_conv_bn_eval: bool = False):
super().__init__() super().__init__()
assert conv_cfg is None or isinstance(conv_cfg, dict) assert conv_cfg is None or isinstance(conv_cfg, dict)
assert norm_cfg is None or isinstance(norm_cfg, dict) assert norm_cfg is None or isinstance(norm_cfg, dict)
...@@ -155,6 +210,8 @@ class ConvModule(nn.Module): ...@@ -155,6 +210,8 @@ class ConvModule(nn.Module):
else: else:
self.norm_name = None # type: ignore self.norm_name = None # type: ignore
self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval)
# build activation layer # build activation layer
if self.with_activation: if self.with_activation:
act_cfg_ = act_cfg.copy() # type: ignore act_cfg_ = act_cfg.copy() # type: ignore
...@@ -200,13 +257,82 @@ class ConvModule(nn.Module): ...@@ -200,13 +257,82 @@ class ConvModule(nn.Module):
x: torch.Tensor, x: torch.Tensor,
activate: bool = True, activate: bool = True,
norm: bool = True) -> torch.Tensor: norm: bool = True) -> torch.Tensor:
for layer in self.order: layer_index = 0
while layer_index < len(self.order):
layer = self.order[layer_index]
if layer == 'conv': if layer == 'conv':
if self.with_explicit_padding: if self.with_explicit_padding:
x = self.padding_layer(x) x = self.padding_layer(x)
x = self.conv(x) # if the next operation is norm and we have a norm layer in
# eval mode and we have enabled `efficient_conv_bn_eval` for
# the conv operator, then activate the optimized forward and
# skip the next norm operator since it has been fused
if layer_index + 1 < len(self.order) and \
self.order[layer_index + 1] == 'norm' and norm and \
self.with_norm and not self.norm.training and \
self.efficient_conv_bn_eval_forward is not None:
self.conv.forward = partial(
self.efficient_conv_bn_eval_forward, self.norm,
self.conv)
layer_index += 1
x = self.conv(x)
del self.conv.forward
else:
x = self.conv(x)
elif layer == 'norm' and norm and self.with_norm: elif layer == 'norm' and norm and self.with_norm:
x = self.norm(x) x = self.norm(x)
elif layer == 'act' and activate and self.with_activation: elif layer == 'act' and activate and self.with_activation:
x = self.activate(x) x = self.activate(x)
layer_index += 1
return x return x
def turn_on_efficient_conv_bn_eval(self, efficient_conv_bn_eval=True):
# efficient_conv_bn_eval works for conv + bn
# with `track_running_stats` option
if efficient_conv_bn_eval and self.norm \
and isinstance(self.norm, _BatchNorm) \
and self.norm.track_running_stats:
self.efficient_conv_bn_eval_forward = efficient_conv_bn_eval_forward # noqa: E501
else:
self.efficient_conv_bn_eval_forward = None # type: ignore
@staticmethod
def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd,
bn: torch.nn.modules.batchnorm._BatchNorm,
efficient_conv_bn_eval=True) -> 'ConvModule':
"""Create a ConvModule from a conv and a bn module."""
self = ConvModule.__new__(ConvModule)
super(ConvModule, self).__init__()
self.conv_cfg = None
self.norm_cfg = None
self.act_cfg = None
self.inplace = False
self.with_spectral_norm = False
self.with_explicit_padding = False
self.order = ('conv', 'norm', 'act')
self.with_norm = True
self.with_activation = False
self.with_bias = conv.bias is not None
# build convolution layer
self.conv = conv
# export the attributes of self.conv to a higher level for convenience
self.in_channels = self.conv.in_channels
self.out_channels = self.conv.out_channels
self.kernel_size = self.conv.kernel_size
self.stride = self.conv.stride
self.padding = self.conv.padding
self.dilation = self.conv.dilation
self.transposed = self.conv.transposed
self.output_padding = self.conv.output_padding
self.groups = self.conv.groups
# build normalization layers
self.norm_name, norm = 'bn', bn
self.add_module(self.norm_name, norm)
self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval)
return self
...@@ -371,7 +371,7 @@ class GeneralizedAttention(nn.Module): ...@@ -371,7 +371,7 @@ class GeneralizedAttention(nn.Module):
contiguous().\ contiguous().\
view(1, 1, h*w, h_kv*w_kv) view(1, 1, h*w, h_kv*w_kv)
energy = energy.masked_fill_(cur_local_constraint_map, energy = energy.masked_fill_(cur_local_constraint_map.bool(),
float('-inf')) float('-inf'))
attention = F.softmax(energy, 3) attention = F.softmax(energy, 3)
......
...@@ -98,14 +98,17 @@ def build_norm_layer(cfg: Dict, ...@@ -98,14 +98,17 @@ def build_norm_layer(cfg: Dict,
layer_type = cfg_.pop('type') layer_type = cfg_.pop('type')
# Switch registry to the target scope. If `norm_layer` cannot be found if inspect.isclass(layer_type):
# in the registry, fallback to search `norm_layer` in the norm_layer = layer_type
# mmengine.MODELS. else:
with MODELS.switch_scope_and_registry(None) as registry: # Switch registry to the target scope. If `norm_layer` cannot be found
norm_layer = registry.get(layer_type) # in the registry, fallback to search `norm_layer` in the
if norm_layer is None: # mmengine.MODELS.
raise KeyError(f'Cannot find {norm_layer} in registry under scope ' with MODELS.switch_scope_and_registry(None) as registry:
f'name {registry.scope}') norm_layer = registry.get(layer_type)
if norm_layer is None:
raise KeyError(f'Cannot find {norm_layer} in registry under '
f'scope name {registry.scope}')
abbr = infer_abbr(norm_layer) abbr = infer_abbr(norm_layer)
assert isinstance(postfix, (int, str)) assert isinstance(postfix, (int, str))
...@@ -113,7 +116,7 @@ def build_norm_layer(cfg: Dict, ...@@ -113,7 +116,7 @@ def build_norm_layer(cfg: Dict,
requires_grad = cfg_.pop('requires_grad', True) requires_grad = cfg_.pop('requires_grad', True)
cfg_.setdefault('eps', 1e-5) cfg_.setdefault('eps', 1e-5)
if layer_type != 'GN': if norm_layer is not nn.GroupNorm:
layer = norm_layer(num_features, **cfg_) layer = norm_layer(num_features, **cfg_)
if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'): if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
layer._specify_ddp_gpu_num(1) layer._specify_ddp_gpu_num(1)
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict from typing import Dict
import torch.nn as nn import torch.nn as nn
...@@ -27,7 +28,8 @@ def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module: ...@@ -27,7 +28,8 @@ def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
cfg_ = cfg.copy() cfg_ = cfg.copy()
padding_type = cfg_.pop('type') padding_type = cfg_.pop('type')
if inspect.isclass(padding_type):
return padding_type(*args, **kwargs, **cfg_)
# Switch registry to the target scope. If `padding_layer` cannot be found # Switch registry to the target scope. If `padding_layer` cannot be found
# in the registry, fallback to search `padding_layer` in the # in the registry, fallback to search `padding_layer` in the
# mmengine.MODELS. # mmengine.MODELS.
......
...@@ -79,15 +79,18 @@ def build_plugin_layer(cfg: Dict, ...@@ -79,15 +79,18 @@ def build_plugin_layer(cfg: Dict,
cfg_ = cfg.copy() cfg_ = cfg.copy()
layer_type = cfg_.pop('type') layer_type = cfg_.pop('type')
if inspect.isclass(layer_type):
# Switch registry to the target scope. If `plugin_layer` cannot be found plugin_layer = layer_type
# in the registry, fallback to search `plugin_layer` in the else:
# mmengine.MODELS. # Switch registry to the target scope. If `plugin_layer` cannot be
with MODELS.switch_scope_and_registry(None) as registry: # found in the registry, fallback to search `plugin_layer` in the
plugin_layer = registry.get(layer_type) # mmengine.MODELS.
if plugin_layer is None: with MODELS.switch_scope_and_registry(None) as registry:
raise KeyError(f'Cannot find {plugin_layer} in registry under scope ' plugin_layer = registry.get(layer_type)
f'name {registry.scope}') if plugin_layer is None:
raise KeyError(
f'Cannot find {plugin_layer} in registry under scope '
f'name {registry.scope}')
abbr = infer_abbr(plugin_layer) abbr = infer_abbr(plugin_layer)
assert isinstance(postfix, (int, str)) assert isinstance(postfix, (int, str))
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict from typing import Dict
import torch import torch
...@@ -76,15 +77,18 @@ def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module: ...@@ -76,15 +77,18 @@ def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
layer_type = cfg_.pop('type') layer_type = cfg_.pop('type')
if inspect.isclass(layer_type):
upsample = layer_type
# Switch registry to the target scope. If `upsample` cannot be found # Switch registry to the target scope. If `upsample` cannot be found
# in the registry, fallback to search `upsample` in the # in the registry, fallback to search `upsample` in the
# mmengine.MODELS. # mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry: else:
upsample = registry.get(layer_type) with MODELS.switch_scope_and_registry(None) as registry:
if upsample is None: upsample = registry.get(layer_type)
raise KeyError(f'Cannot find {upsample} in registry under scope ' if upsample is None:
f'name {registry.scope}') raise KeyError(f'Cannot find {upsample} in registry under scope '
if upsample is nn.Upsample: f'name {registry.scope}')
cfg_['mode'] = layer_type if upsample is nn.Upsample:
cfg_['mode'] = layer_type
layer = upsample(*args, **kwargs, **cfg_) layer = upsample(*args, **kwargs, **cfg_)
return layer return layer
...@@ -41,7 +41,7 @@ class NewEmptyTensorOp(torch.autograd.Function): ...@@ -41,7 +41,7 @@ class NewEmptyTensorOp(torch.autograd.Function):
class Conv2d(nn.Conv2d): class Conv2d(nn.Conv2d):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_channels] out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size, for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
self.padding, self.stride, self.dilation): self.padding, self.stride, self.dilation):
...@@ -62,7 +62,7 @@ class Conv2d(nn.Conv2d): ...@@ -62,7 +62,7 @@ class Conv2d(nn.Conv2d):
class Conv3d(nn.Conv3d): class Conv3d(nn.Conv3d):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_channels] out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size, for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
self.padding, self.stride, self.dilation): self.padding, self.stride, self.dilation):
...@@ -84,7 +84,7 @@ class Conv3d(nn.Conv3d): ...@@ -84,7 +84,7 @@ class Conv3d(nn.Conv3d):
class ConvTranspose2d(nn.ConvTranspose2d): class ConvTranspose2d(nn.ConvTranspose2d):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_channels] out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size, for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
self.padding, self.stride, self.padding, self.stride,
...@@ -106,7 +106,7 @@ class ConvTranspose2d(nn.ConvTranspose2d): ...@@ -106,7 +106,7 @@ class ConvTranspose2d(nn.ConvTranspose2d):
class ConvTranspose3d(nn.ConvTranspose3d): class ConvTranspose3d(nn.ConvTranspose3d):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_channels] out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size, for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
self.padding, self.stride, self.padding, self.stride,
...@@ -127,7 +127,7 @@ class MaxPool2d(nn.MaxPool2d): ...@@ -127,7 +127,7 @@ class MaxPool2d(nn.MaxPool2d):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# PyTorch 1.9 does not support empty tensor inference yet # PyTorch 1.9 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): if obsolete_torch_version(TORCH_VERSION, (1, 9)) and x.numel() == 0:
out_shape = list(x.shape[:2]) out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size), for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
_pair(self.padding), _pair(self.stride), _pair(self.padding), _pair(self.stride),
...@@ -145,7 +145,7 @@ class MaxPool3d(nn.MaxPool3d): ...@@ -145,7 +145,7 @@ class MaxPool3d(nn.MaxPool3d):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# PyTorch 1.9 does not support empty tensor inference yet # PyTorch 1.9 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): if obsolete_torch_version(TORCH_VERSION, (1, 9)) and x.numel() == 0:
out_shape = list(x.shape[:2]) out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size), for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
_triple(self.padding), _triple(self.padding),
...@@ -164,7 +164,7 @@ class Linear(torch.nn.Linear): ...@@ -164,7 +164,7 @@ class Linear(torch.nn.Linear):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# empty tensor forward of Linear layer is supported in Pytorch 1.6 # empty tensor forward of Linear layer is supported in Pytorch 1.6
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)): if obsolete_torch_version(TORCH_VERSION, (1, 5)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_features] out_shape = [x.shape[0], self.out_features]
empty = NewEmptyTensorOp.apply(x, out_shape) empty = NewEmptyTensorOp.apply(x, out_shape)
if self.training: if self.training:
......
...@@ -16,13 +16,13 @@ except ImportError: ...@@ -16,13 +16,13 @@ except ImportError:
def _scale_size( def _scale_size(
size: Tuple[int, int], size: Tuple[int, int],
scale: Union[float, int, tuple], scale: Union[float, int, Tuple[float, float], Tuple[int, int]],
) -> Tuple[int, int]: ) -> Tuple[int, int]:
"""Rescale a size by a ratio. """Rescale a size by a ratio.
Args: Args:
size (tuple[int]): (w, h). size (tuple[int]): (w, h).
scale (float | tuple(float)): Scaling factor. scale (float | int | tuple(float) | tuple(int)): Scaling factor.
Returns: Returns:
tuple[int]: scaled size. tuple[int]: scaled size.
...@@ -128,7 +128,8 @@ def imresize_to_multiple( ...@@ -128,7 +128,8 @@ def imresize_to_multiple(
img: np.ndarray, img: np.ndarray,
divisor: Union[int, Tuple[int, int]], divisor: Union[int, Tuple[int, int]],
size: Union[int, Tuple[int, int], None] = None, size: Union[int, Tuple[int, int], None] = None,
scale_factor: Union[float, Tuple[float, float], None] = None, scale_factor: Union[float, int, Tuple[float, float], Tuple[int, int],
None] = None,
keep_ratio: bool = False, keep_ratio: bool = False,
return_scale: bool = False, return_scale: bool = False,
interpolation: str = 'bilinear', interpolation: str = 'bilinear',
...@@ -145,9 +146,10 @@ def imresize_to_multiple( ...@@ -145,9 +146,10 @@ def imresize_to_multiple(
divisor. If divisor is a tuple, divisor should be divisor. If divisor is a tuple, divisor should be
(w_divisor, h_divisor). (w_divisor, h_divisor).
size (None | int | tuple[int]): Target size (w, h). Default: None. size (None | int | tuple[int]): Target size (w, h). Default: None.
scale_factor (None | float | tuple[float]): Multiplier for spatial scale_factor (None | float | int | tuple[float] | tuple[int]):
size. Should match input size if it is a tuple and the 2D style is Multiplier for spatial size. Should match input size if it is a
(w_scale_factor, h_scale_factor). Default: None. tuple and the 2D style is (w_scale_factor, h_scale_factor).
Default: None.
keep_ratio (bool): Whether to keep the aspect ratio when resizing the keep_ratio (bool): Whether to keep the aspect ratio when resizing the
image. Default: False. image. Default: False.
return_scale (bool): Whether to return `w_scale` and `h_scale`. return_scale (bool): Whether to return `w_scale` and `h_scale`.
...@@ -215,16 +217,16 @@ def imresize_like( ...@@ -215,16 +217,16 @@ def imresize_like(
def rescale_size(old_size: tuple, def rescale_size(old_size: tuple,
scale: Union[float, int, tuple], scale: Union[float, int, Tuple[int, int]],
return_scale: bool = False) -> tuple: return_scale: bool = False) -> tuple:
"""Calculate the new size to be rescaled to. """Calculate the new size to be rescaled to.
Args: Args:
old_size (tuple[int]): The old size (w, h) of image. old_size (tuple[int]): The old size (w, h) of image.
scale (float | tuple[int]): The scaling factor or maximum size. scale (float | int | tuple[int]): The scaling factor or maximum size.
If it is a float number, then the image will be rescaled by this If it is a float number or an integer, then the image will be
factor, else if it is a tuple of 2 integers, then the image will rescaled by this factor, else if it is a tuple of 2 integers, then
be rescaled as large as possible within the scale. the image will be rescaled as large as possible within the scale.
return_scale (bool): Whether to return the scaling factor besides the return_scale (bool): Whether to return the scaling factor besides the
rescaled image size. rescaled image size.
...@@ -255,7 +257,7 @@ def rescale_size(old_size: tuple, ...@@ -255,7 +257,7 @@ def rescale_size(old_size: tuple,
def imrescale( def imrescale(
img: np.ndarray, img: np.ndarray,
scale: Union[float, Tuple[int, int]], scale: Union[float, int, Tuple[int, int]],
return_scale: bool = False, return_scale: bool = False,
interpolation: str = 'bilinear', interpolation: str = 'bilinear',
backend: Optional[str] = None backend: Optional[str] = None
...@@ -264,10 +266,10 @@ def imrescale( ...@@ -264,10 +266,10 @@ def imrescale(
Args: Args:
img (ndarray): The input image. img (ndarray): The input image.
scale (float | tuple[int]): The scaling factor or maximum size. scale (float | int | tuple[int]): The scaling factor or maximum size.
If it is a float number, then the image will be rescaled by this If it is a float number or an integer, then the image will be
factor, else if it is a tuple of 2 integers, then the image will rescaled by this factor, else if it is a tuple of 2 integers, then
be rescaled as large as possible within the scale. the image will be rescaled as large as possible within the scale.
return_scale (bool): Whether to return the scaling factor besides the return_scale (bool): Whether to return the scaling factor besides the
rescaled image. rescaled image.
interpolation (str): Same as :func:`resize`. interpolation (str): Same as :func:`resize`.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import IS_MLU_AVAILABLE
from .active_rotated_filter import active_rotated_filter from .active_rotated_filter import active_rotated_filter
from .assign_score_withk import assign_score_withk from .assign_score_withk import assign_score_withk
from .ball_query import ball_query from .ball_query import ball_query
...@@ -109,3 +110,9 @@ __all__ = [ ...@@ -109,3 +110,9 @@ __all__ = [
'PrRoIPool', 'prroi_pool', 'bias_act', 'filtered_lrelu', 'conv2d', 'PrRoIPool', 'prroi_pool', 'bias_act', 'filtered_lrelu', 'conv2d',
'conv_transpose2d', 'filter2d', 'upsample2d', 'BezierAlign', 'bezier_align' 'conv_transpose2d', 'filter2d', 'upsample2d', 'BezierAlign', 'bezier_align'
] ]
if IS_MLU_AVAILABLE:
from .deform_conv import DeformConv2dPack_MLU # noqa:F401
from .modulated_deform_conv import \
ModulatedDeformConv2dPack_MLU # noqa:F401
__all__.extend(['ModulatedDeformConv2dPack_MLU', 'DeformConv2dPack_MLU'])
...@@ -116,6 +116,10 @@ def bbox_overlaps(bboxes1: torch.Tensor, ...@@ -116,6 +116,10 @@ def bbox_overlaps(bboxes1: torch.Tensor,
if rows * cols == 0: if rows * cols == 0:
return ious return ious
if bboxes1.device.type == 'cpu' and torch.__version__ == 'parrots':
return _bbox_overlaps_cpu(
bboxes1, bboxes2, mode=mode, aligned=aligned, offset=offset)
ext_module.bbox_overlaps( ext_module.bbox_overlaps(
bboxes1, bboxes2, ious, mode=mode_flag, aligned=aligned, offset=offset) bboxes1, bboxes2, ious, mode=mode_flag, aligned=aligned, offset=offset)
......
...@@ -133,12 +133,20 @@ def box_iou_rotated(bboxes1: torch.Tensor, ...@@ -133,12 +133,20 @@ def box_iou_rotated(bboxes1: torch.Tensor,
if aligned: if aligned:
ious = bboxes1.new_zeros(rows) ious = bboxes1.new_zeros(rows)
else: else:
ious = bboxes1.new_zeros(rows * cols) if bboxes1.device.type == 'mlu':
ious = bboxes1.new_zeros([rows, cols])
else:
ious = bboxes1.new_zeros(rows * cols)
if not clockwise: if not clockwise:
flip_mat = bboxes1.new_ones(bboxes1.shape[-1]) flip_mat = bboxes1.new_ones(bboxes1.shape[-1])
flip_mat[-1] = -1 flip_mat[-1] = -1
bboxes1 = bboxes1 * flip_mat bboxes1 = bboxes1 * flip_mat
bboxes2 = bboxes2 * flip_mat bboxes2 = bboxes2 * flip_mat
if bboxes1.device.type == 'npu':
scale_mat = bboxes1.new_ones(bboxes1.shape[-1])
scale_mat[-1] = 1.0 / 0.01745329252
bboxes1 = bboxes1 * scale_mat
bboxes2 = bboxes2 * scale_mat
bboxes1 = bboxes1.contiguous() bboxes1 = bboxes1.contiguous()
bboxes2 = bboxes2.contiguous() bboxes2 = bboxes2.contiguous()
ext_module.box_iou_rotated( ext_module.box_iou_rotated(
......
...@@ -16,6 +16,7 @@ from typing import Dict, Optional, Tuple, Union ...@@ -16,6 +16,7 @@ from typing import Dict, Optional, Tuple, Union
import torch import torch
from mmengine.utils import digit_version from mmengine.utils import digit_version
from mmengine.utils.dl_utils.parrots_wrapper import is_rocm_pytorch
enabled = True enabled = True
weight_gradients_disabled = False weight_gradients_disabled = False
...@@ -283,28 +284,19 @@ def _conv2d_gradfix( ...@@ -283,28 +284,19 @@ def _conv2d_gradfix(
output_padding=output_padding, output_padding=output_padding,
output_mask=[0, 1, 0])[1] output_mask=[0, 1, 0])[1]
else: else:
is_rocm_pytorch = False if is_rocm_pytorch():
try: name = 'aten::miopen_convolution_transpose_backward_weight'
from torch.utils.cpp_extension import ROCM_HOME if not transpose:
is_rocm_pytorch = True if ((torch.version.hip is not None) and name = 'aten::miopen_convolution_backward_weight'
(ROCM_HOME is not None)) else False
except ImportError:
pass
name=''
flags=[]
if is_rocm_pytorch:
name = ('aten::miopen_convolution_transpose_backward_weight'
if transpose else
'aten::miopen_convolution_backward_weight')
flags = [ flags = [
torch.backends.cudnn.benchmark, torch.backends.cudnn.benchmark,
torch.backends.cudnn.deterministic torch.backends.cudnn.deterministic
] ]
else: else:
# General case => cuDNN. # General case => cuDNN.
name = ('aten::cudnn_convolution_transpose_backward_weight' name = ('aten::cudnn_convolution_transpose_backward_weight'
if transpose else if transpose else
'aten::cudnn_convolution_backward_weight') 'aten::cudnn_convolution_backward_weight')
flags = [ flags = [
torch.backends.cudnn.benchmark, torch.backends.cudnn.benchmark,
torch.backends.cudnn.deterministic, torch.backends.cudnn.deterministic,
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
from torch import Tensor, nn
from mmengine.utils import digit_version from mmengine.utils import digit_version
from torch import Tensor, nn
_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3} _mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}
...@@ -70,7 +71,8 @@ class CornerPool(nn.Module): ...@@ -70,7 +71,8 @@ class CornerPool(nn.Module):
self.mode = mode self.mode = mode
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
if torch.__version__ != 'parrots' and digit_version(torch.__version__) >= digit_version('1.5.0'): if (torch.__version__ != 'parrots' and
digit_version(torch.__version__) >= digit_version('1.5.0')):
dim, flip = self.cummax_dim_flip[self.mode] dim, flip = self.cummax_dim_flip[self.mode]
if flip: if flip:
x = x.flip(dim) x = x.flip(dim)
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#ifndef CARAFE_CUDA_KERNEL_CUH #ifndef CARAFE_CUDA_KERNEL_CUH
#define CARAFE_CUDA_KERNEL_CUH #define CARAFE_CUDA_KERNEL_CUH
#include <ATen/cuda/DeviceUtils.cuh>
#ifdef MMCV_USE_PARROTS #ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp" #include "parrots_cuda_helper.hpp"
#else #else
...@@ -56,7 +58,8 @@ template <> ...@@ -56,7 +58,8 @@ template <>
__device__ __forceinline__ phalf warpReduceSum(phalf val) { __device__ __forceinline__ phalf warpReduceSum(phalf val) {
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
#ifdef MMCV_WITH_HIP #ifdef MMCV_WITH_HIP
__PHALF(val) += __shfl_down(val, offset); // Using PyTorch's macro for half support
__PHALF(val) += WARP_SHFL_DOWN(val, offset);
#else #else
__PHALF(val) += __PHALF(val) +=
__shfl_down_sync(FULL_MASK, __PHALF(val).operator __half(), offset); __shfl_down_sync(FULL_MASK, __PHALF(val).operator __half(), offset);
......
/*************************************************************************
* Copyright (C) 2021 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <float.h>
#include "common_mlu_helper.hpp"
#define COORD_NUM 4
__nram__ char nmem_buf[MAX_NRAM_SIZE];
template <typename T>
__mlu_func__ void computeDiv(void *nram_dst, void *nram_src0, void *nram_src1,
void *nram_addition, const int32_t deal_num) {
__bang_active_reciphp((T *)nram_dst, (T *)nram_src1, deal_num);
__bang_mul((T *)nram_dst, (T *)nram_src0, (T *)nram_dst, deal_num);
}
template <>
__mlu_func__ void computeDiv<half>(void *nram_dst, void *nram_src0,
void *nram_src1, void *nram_addition,
const int32_t deal_num) {
__bang_half2float((float *)nram_addition, (half *)nram_src1, deal_num);
__bang_active_reciphp((float *)nram_addition, (float *)nram_addition,
deal_num);
__bang_float2half_rd((half *)nram_src1, (float *)nram_addition, deal_num);
__bang_mul((half *)nram_dst, (half *)nram_src0, (half *)nram_src1, deal_num);
}
template <typename T>
__mlu_func__ void bboxOverlapsWorkflow(
T *vec_b1_x1, T *vec_b1_y1, T *vec_b1_x2, T *vec_b1_y2, T *vec_b2_x1,
T *vec_b2_y1, T *vec_b2_x2, T *vec_b2_y2, T *vec_left, T *vec_right,
T *vec_top, T *vec_bottom, const T *bbox1, const T *bbox2, void *ious,
const int32_t offset, const int32_t mode, const int32_t batches_stride,
const int32_t num_bbox1, const int32_t num_bbox2, const bool aligned) {
int32_t task_batch_stride = (num_bbox1 + taskDim - 1) / taskDim;
int32_t batch_start = taskId * task_batch_stride;
int32_t batch_per_task = batch_start + task_batch_stride < num_bbox1
? task_batch_stride
: num_bbox1 - batch_start;
batch_per_task = batch_per_task > 0 ? batch_per_task : (0);
if (aligned) {
int32_t num_loop_cpy = batch_per_task / batches_stride;
int32_t num_rem_cpy_batches = batch_per_task % batches_stride;
num_loop_cpy = num_rem_cpy_batches > 0 ? num_loop_cpy + 1 : num_loop_cpy;
for (int32_t i = 0; i < num_loop_cpy; i++) {
int32_t index = batch_start + i * batches_stride;
int32_t handle_batches = index + batches_stride > num_bbox1
? num_rem_cpy_batches
: batches_stride;
int32_t b1 = index;
int32_t b2 = index;
int32_t base1 = b1 * COORD_NUM;
__memcpy(vec_b1_x1, &bbox1[base1], sizeof(T), GDRAM2NRAM, sizeof(T),
COORD_NUM * sizeof(T), handle_batches - 1);
__memcpy(vec_b1_y1, &bbox1[base1 + 1], sizeof(T), GDRAM2NRAM, sizeof(T),
COORD_NUM * sizeof(T), handle_batches - 1);
__memcpy(vec_b1_x2, &bbox1[base1 + 2], sizeof(T), GDRAM2NRAM, sizeof(T),
COORD_NUM * sizeof(T), handle_batches - 1);
__memcpy(vec_b1_y2, &bbox1[base1 + 3], sizeof(T), GDRAM2NRAM, sizeof(T),
COORD_NUM * sizeof(T), handle_batches - 1);
int32_t base2 = b2 * COORD_NUM;
__memcpy(vec_b2_x1, &bbox2[base2], sizeof(T), GDRAM2NRAM, sizeof(T),
COORD_NUM * sizeof(T), handle_batches - 1);
__memcpy(vec_b2_y1, &bbox2[base2 + 1], sizeof(T), GDRAM2NRAM, sizeof(T),
COORD_NUM * sizeof(T), handle_batches - 1);
__memcpy(vec_b2_x2, &bbox2[base2 + 2], sizeof(T), GDRAM2NRAM, sizeof(T),
COORD_NUM * sizeof(T), handle_batches - 1);
__memcpy(vec_b2_y2, &bbox2[base2 + 3], sizeof(T), GDRAM2NRAM, sizeof(T),
COORD_NUM * sizeof(T), handle_batches - 1);
// get the width and height
__bang_maxequal(vec_left, vec_b1_x1, vec_b2_x1, batches_stride);
__bang_minequal(vec_right, vec_b1_x2, vec_b2_x2, batches_stride);
__bang_maxequal(vec_top, vec_b1_y1, vec_b2_y1, batches_stride);
__bang_minequal(vec_bottom, vec_b1_y2, vec_b2_y2, batches_stride);
// right - left + offset ---> left
__bang_sub(vec_left, vec_right, vec_left, batches_stride);
__bang_add_scalar(vec_left, vec_left, (T)offset, batches_stride);
// bottom - top + offset ---> right
__bang_sub(vec_right, vec_bottom, vec_top, batches_stride);
__bang_add_scalar(vec_right, vec_right, (T)offset, batches_stride);
// zero vector ---> bottom
__bang_write_value(vec_bottom, batches_stride, 0.f);
// width --> vec_left
__bang_maxequal(vec_left, vec_bottom, vec_left, batches_stride);
T *width = vec_left;
// height --> vec_right
__bang_maxequal(vec_right, vec_bottom, vec_right, batches_stride);
T *height = vec_right;
// get the b1_area
// (b1_x2 - b1_x1 + offset) ---> vec_top
__bang_sub(vec_top, vec_b1_x2, vec_b1_x1, batches_stride);
__bang_add_scalar(vec_top, vec_top, (T)offset, batches_stride);
// (b1_y2 - b1_y1 + offset) ---> vec_bottom
__bang_sub(vec_bottom, vec_b1_y2, vec_b1_y1, batches_stride);
__bang_add_scalar(vec_bottom, vec_bottom, (T)offset, batches_stride);
// b1_area = (b1_x2 - b1_x1 + offset) * (b1_y2 - b1_y1 + offset)
// ---> vec_top;
__bang_mul(vec_top, vec_top, vec_bottom, batches_stride);
T *b1_area = vec_top;
// get the b2_area
// (b2_x2 - b2_x1 + offset) ---> b2_x1
__bang_sub(vec_b2_x1, vec_b2_x2, vec_b2_x1, batches_stride);
__bang_add_scalar(vec_b2_x1, vec_b2_x1, (T)offset, batches_stride);
// (b2_y2 - b2_y1 + offset) ---> b2_y1
__bang_sub(vec_b2_y1, vec_b2_y2, vec_b2_y1, batches_stride);
__bang_add_scalar(vec_b2_y1, vec_b2_y1, (T)offset, batches_stride);
// b2_area = (b2_x2 - b2_x1 + offset) * (b2_y2 - b2_y1 + offset)
// ---> b2_x1;
__bang_mul(vec_b2_x1, vec_b2_x1, vec_b2_y1, batches_stride);
T *b2_area = vec_b2_x1;
// inter_s = width * height
__bang_mul(height, width, height, batches_stride);
T *inter_s = height;
// offset vector ---> vec_b2_y1
__bang_write_value(vec_b2_y1, batches_stride, T(offset));
T *vec_offset = vec_b2_y1;
if (mode == 0) {
__bang_add(b1_area, b1_area, b2_area, batches_stride);
__bang_sub(b1_area, b1_area, inter_s, batches_stride);
__bang_maxequal(b1_area, vec_offset, b1_area, batches_stride);
} else {
__bang_maxequal(b1_area, vec_offset, b1_area, batches_stride);
}
T *base_s = b1_area;
// ious = inter_s / base_s
computeDiv<T>(width, inter_s, base_s, vec_b2_x2, batches_stride);
__memcpy((T *)ious + index, width, handle_batches * sizeof(T),
NRAM2GDRAM);
}
} else {
int32_t num_loop_cpy = num_bbox2 / batches_stride;
int32_t num_rem_cpy_batches = num_bbox2 % batches_stride;
num_loop_cpy = num_rem_cpy_batches > 0 ? num_loop_cpy + 1 : num_loop_cpy;
for (int32_t i = 0; i < batch_per_task; i++) {
int32_t index1 = batch_start + i;
int32_t b1 = index1;
int32_t base1 = b1 * COORD_NUM;
// set bbox1 and bbox2 to nram
__bang_write_value(vec_b1_x1, batches_stride, bbox1[base1]);
__bang_write_value(vec_b1_y1, batches_stride, bbox1[base1 + 1]);
__bang_write_value(vec_b1_x2, batches_stride, bbox1[base1 + 2]);
__bang_write_value(vec_b1_y2, batches_stride, bbox1[base1 + 3]);
for (int32_t j = 0; j < num_loop_cpy; j++) {
int32_t index2 = j * batches_stride;
int32_t handle_batches = index2 + batches_stride > num_bbox2
? num_rem_cpy_batches
: batches_stride;
int32_t b2 = index2;
int32_t base2 = b2 * COORD_NUM;
// copy bbox2 to nram
__memcpy(vec_b2_x1, &bbox2[base2], sizeof(T), GDRAM2NRAM, sizeof(T),
COORD_NUM * sizeof(T), handle_batches - 1);
__memcpy(vec_b2_y1, &bbox2[base2 + 1], sizeof(T), GDRAM2NRAM, sizeof(T),
COORD_NUM * sizeof(T), handle_batches - 1);
__memcpy(vec_b2_x2, &bbox2[base2 + 2], sizeof(T), GDRAM2NRAM, sizeof(T),
COORD_NUM * sizeof(T), handle_batches - 1);
__memcpy(vec_b2_y2, &bbox2[base2 + 3], sizeof(T), GDRAM2NRAM, sizeof(T),
COORD_NUM * sizeof(T), handle_batches - 1);
// get the width and height
__bang_maxequal(vec_left, vec_b1_x1, vec_b2_x1, batches_stride);
__bang_minequal(vec_right, vec_b1_x2, vec_b2_x2, batches_stride);
__bang_maxequal(vec_top, vec_b1_y1, vec_b2_y1, batches_stride);
__bang_minequal(vec_bottom, vec_b1_y2, vec_b2_y2, batches_stride);
// right - left + offset ---> left
__bang_sub(vec_left, vec_right, vec_left, batches_stride);
__bang_add_scalar(vec_left, vec_left, (T)offset, batches_stride);
// bottom - top + offset ---> right
__bang_sub(vec_right, vec_bottom, vec_top, batches_stride);
__bang_add_scalar(vec_right, vec_right, (T)offset, batches_stride);
// zero vector ---> bottom
__bang_write_value(vec_bottom, batches_stride, (T)0);
// width --> vec_left
__bang_maxequal(vec_left, vec_bottom, vec_left, batches_stride);
T *width = vec_left;
// height --> vec_right
__bang_maxequal(vec_right, vec_bottom, vec_right, batches_stride);
T *height = vec_right;
// get the b1_area
// (b1_x2 - b1_x1 + offset) ---> vec_top
__bang_sub(vec_top, vec_b1_x2, vec_b1_x1, batches_stride);
__bang_add_scalar(vec_top, vec_top, (T)offset, batches_stride);
// (b1_y2 - b1_y1 + offset) ---> vec_bottom
__bang_sub(vec_bottom, vec_b1_y2, vec_b1_y1, batches_stride);
__bang_add_scalar(vec_bottom, vec_bottom, (T)offset, batches_stride);
// b1_area = (b1_x2 - b1_x1 + offset) * (b1_y2 - b1_y1 + offset)
// ---> vec_top;
__bang_mul(vec_top, vec_top, vec_bottom, batches_stride);
T *b1_area = vec_top;
// get the b2_area
// (b2_x2 - b2_x1 + offset) ---> b2_x1
__bang_sub(vec_b2_x1, vec_b2_x2, vec_b2_x1, batches_stride);
__bang_add_scalar(vec_b2_x1, vec_b2_x1, (T)offset, batches_stride);
// (b2_y2 - b2_y1 + offset) ---> b2_y1
__bang_sub(vec_b2_y1, vec_b2_y2, vec_b2_y1, batches_stride);
__bang_add_scalar(vec_b2_y1, vec_b2_y1, (T)offset, batches_stride);
// b2_area = (b2_x2 - b2_x1 + offset) * (b2_y2 - b2_y1 + offset)
// ---> b2_x1;
__bang_mul(vec_b2_x1, vec_b2_x1, vec_b2_y1, batches_stride);
T *b2_area = vec_b2_x1;
// inter_s = width * height
__bang_mul(height, width, height, batches_stride);
T *inter_s = height;
// offset vector ---> vec_b2_y1
__bang_write_value(vec_b2_y1, batches_stride, T(offset));
T *vec_offset = vec_b2_y1;
if (mode == 0) {
__bang_add(b1_area, b1_area, b2_area, batches_stride);
__bang_sub(b1_area, b1_area, inter_s, batches_stride);
__bang_maxequal(b1_area, vec_offset, b1_area, batches_stride);
} else {
__bang_maxequal(b1_area, vec_offset, b1_area, batches_stride);
}
T *base_s = b1_area;
// ious = inter_s / base_s
computeDiv<T>(width, inter_s, base_s, vec_b2_x2, batches_stride);
int32_t gdram_offset = index1 * num_bbox2 + index2;
__memcpy((T *)ious + gdram_offset, width, handle_batches * sizeof(T),
NRAM2GDRAM);
}
}
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelBBoxOverlaps(
const void *bbox1, const void *bbox2, void *ious, const int32_t num_bbox1,
const int32_t num_bbox2, const int32_t mode, const bool aligned,
const int32_t offset) {
/*
* NRAM partition
* |-------------------------------------------------------------|
* | vec_b1_x1 | vec_b1_y1 | vec_b1_x2 | vec_b1_y2 |
* |-------------------------------------------------------------|
* | vec_b2_x1 | vec_b2_y1 | vec_b2_x2 | vec_b2_y2 |
* |-------------------------------------------------------------|
* | vec_left | vec_right | vec_top | vec_bottom |
* |-------------------------------------------------------------|
*
*/
const int32_t align_bytes = PAD_DOWN(MAX_NRAM_SIZE, NFU_ALIGN_SIZE);
const int32_t split_nram_num = 12;
const int32_t nram_stride =
align_bytes / NFU_ALIGN_SIZE / split_nram_num * NFU_ALIGN_SIZE;
void *vec_b1_x1 = nmem_buf;
void *vec_b1_y1 = nmem_buf + nram_stride;
void *vec_b1_x2 = nmem_buf + 2 * nram_stride;
void *vec_b1_y2 = nmem_buf + 3 * nram_stride;
void *vec_b2_x1 = nmem_buf + 4 * nram_stride;
void *vec_b2_y1 = nmem_buf + 5 * nram_stride;
void *vec_b2_x2 = nmem_buf + 6 * nram_stride;
void *vec_b2_y2 = nmem_buf + 7 * nram_stride;
void *vec_left = nmem_buf + 8 * nram_stride;
void *vec_right = nmem_buf + 9 * nram_stride;
void *vec_top = nmem_buf + 10 * nram_stride;
void *vec_bottom = nmem_buf + 11 * nram_stride;
const int32_t vec_length = nram_stride / sizeof(T);
bboxOverlapsWorkflow((T *)vec_b1_x1, (T *)vec_b1_y1, (T *)vec_b1_x2,
(T *)vec_b1_y2, (T *)vec_b2_x1, (T *)vec_b2_y1,
(T *)vec_b2_x2, (T *)vec_b2_y2, (T *)vec_left,
(T *)vec_right, (T *)vec_top, (T *)vec_bottom,
(T *)bbox1, (T *)bbox2, (T *)ious, offset, mode,
vec_length, num_bbox1, num_bbox2, aligned);
}
void KernelBBoxOverlaps(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const cnrtDataType_t d_type,
const void *bbox1, const void *bbox2, void *ious,
const int32_t num_bbox1, const int32_t num_bbox2,
const int32_t mode, const bool aligned,
const int32_t offset) {
if (d_type == CNRT_FLOAT16) {
MLUUnion1KernelBBoxOverlaps<half><<<k_dim, k_type, queue>>>(
bbox1, bbox2, ious, num_bbox1, num_bbox2, mode, aligned, offset);
} else {
MLUUnion1KernelBBoxOverlaps<float><<<k_dim, k_type, queue>>>(
bbox1, bbox2, ious, num_bbox1, num_bbox2, mode, aligned, offset);
}
}
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "carafe_utils.hpp"
#include "common_mlu_helper.hpp"
#define INDEX3(n, h, w, c, strN, strH, strW) \
(strN) * (n) + (strH) * (h) + (strW) * (w) + (c)
#define NRAM_BLOCK PAD_DOWN(MAX_NRAM_SIZE / 5, NRAM_ALIGN_SIZE)
__nram__ char nram_buf[MAX_NRAM_SIZE];
namespace forward {
struct BlockId {
int Ho;
int Wo;
int G;
int Cg;
int Kh;
int Kw;
int Hi;
int Wi;
};
// start indices of block
struct BlockStart {
int Ho;
int Wo;
int G;
int Cg;
int Kh;
int Kw;
int Hi;
int Wi;
int C;
};
struct BlockEnd {
int Ho;
int Wo;
int Kh;
int Kw;
int Hi;
int Wi;
};
struct BlockSize {
int Ho;
int Wo;
int G;
int Cg;
int Kh;
int Kw;
int Hi;
int Wi;
};
template <typename T>
__mlu_func__ void carafeForwardBLOCK(T *input, T *mask,
const CarafeForwardParam param,
const CarafeForwardBlockDim block_dim,
const CarafeForwardGridDim grid_dim,
T *output) {
// data block info
BlockId blkId;
BlockStart blkStart;
BlockEnd blkEnd;
BlockSize blkSize;
// set pointers on NRAM arrays
// input_nram[blkDim_(Hi+Kh)-1, blkDim_(Wi+Kw)-1, blkDim_(G*Cg)]
T *input_nram = (T *)nram_buf;
// mask_nram[blkDim_Ho, blkDim_Wo, blkDim_(G*Kh*Kw)]
T *mask_nram = input_nram + param.input_nram_size;
// output_nram[blkDim_Ho, blkDim_Wo, blkDim_(G*Cg)]
T *output_nram = mask_nram + param.mask_nram_size;
// sum_array[blkDim_(G*Cg)]
T *sum_array = output_nram + param.output_nram_size;
/* ===== loop over N, grid_dim(Ho,Wo,G,Cg)
* iterations are distributed over computing cores
*/
for (int loop_index = taskId; loop_index < param.job_num;
loop_index += taskDim) {
// block idx
blkId.Cg = loop_index;
blkId.G = blkId.Cg / grid_dim.Cg;
blkId.Wo = blkId.G / grid_dim.G;
blkId.Ho = blkId.Wo / grid_dim.Wo;
int sample_idx = blkId.Ho / grid_dim.Ho;
blkId.Cg %= grid_dim.Cg;
blkId.G %= grid_dim.G;
blkId.Wo %= grid_dim.Wo;
blkId.Ho %= grid_dim.Ho;
// block starting indices
blkStart.Ho = blkId.Ho * block_dim.Ho;
blkStart.Wo = blkId.Wo * block_dim.Wo;
blkStart.G = blkId.G * block_dim.G;
blkStart.Cg = blkId.Cg * block_dim.Cg;
blkStart.C = blkStart.G * param.Cg + blkStart.Cg;
// block size
blkSize.Ho = block_dim.Ho;
blkSize.Wo = block_dim.Wo;
blkSize.G = block_dim.G;
blkSize.Cg = block_dim.Cg;
// take care of blocks near the end of each dimension
if (blkId.Ho == (grid_dim.Ho - 1)) {
blkSize.Ho = param.Ho - (grid_dim.Ho - 1) * block_dim.Ho;
}
if (blkId.Wo == (grid_dim.Wo - 1)) {
blkSize.Wo = param.Wo - (grid_dim.Wo - 1) * block_dim.Wo;
}
if (blkId.G == (grid_dim.G - 1)) {
blkSize.G = param.group_size - (grid_dim.G - 1) * block_dim.G;
}
if (blkId.Cg == (grid_dim.Cg - 1)) {
blkSize.Cg = param.Cg - (grid_dim.Cg - 1) * block_dim.Cg;
}
// block end indices
blkEnd.Ho = blkStart.Ho + blkSize.Ho - 1;
blkEnd.Wo = blkStart.Wo + blkSize.Wo - 1;
// set output_nram to zero
__bang_write_value(output_nram, param.output_nram_size, T(0));
// loop blocks of kernel window: grid_dim.(Kh, Kw)
for (blkId.Kh = 0; blkId.Kh < grid_dim.Kh; ++blkId.Kh) {
blkStart.Kh = blkId.Kh * block_dim.Kh;
blkSize.Kh = block_dim.Kh;
if (blkId.Kh == (grid_dim.Kh - 1)) {
blkSize.Kh = param.kernel_size - (grid_dim.Kh - 1) * block_dim.Kh;
}
blkEnd.Kh = blkStart.Kh + blkSize.Kh - 1;
blkStart.Hi = blkStart.Ho / param.scale_factor - param.kernel_size_half +
blkStart.Kh;
blkEnd.Hi =
blkEnd.Ho / param.scale_factor - param.kernel_size_half + blkEnd.Kh;
blkSize.Hi = blkEnd.Hi - blkStart.Hi + 1;
for (blkId.Kw = 0; blkId.Kw < grid_dim.Kw; ++blkId.Kw) {
blkStart.Kw = blkId.Kw * block_dim.Kw;
blkSize.Kw = block_dim.Kw;
if (blkId.Kw == (grid_dim.Kw - 1)) {
blkSize.Kw = param.kernel_size - (grid_dim.Kw - 1) * block_dim.Kw;
}
blkEnd.Kw = blkStart.Kw + blkSize.Kw - 1;
blkStart.Wi = blkStart.Wo / param.scale_factor -
param.kernel_size_half + blkStart.Kw;
blkEnd.Wi =
blkEnd.Wo / param.scale_factor - param.kernel_size_half + blkEnd.Kw;
blkSize.Wi = blkEnd.Wi - blkStart.Wi + 1;
// load input block from gdram2nram
//
// input_nram[ | input[ sample_idx,
// 0:blkSize.Hi-1, | blkStart.Hi + 0:blkSize.Hi-1,
// 0:blkSize.Wi-1, | blkStart.Wi + 0:blkSize.Wi-1,
// 0:blkSize.G-1 | blkStart.G + 0:blkSize.G-1
// 0:blkSize.Cg-1] | blkStart.Cg + 0:blkSize.Cg-1]
//
// To skip out of bound indices:
//
// input_nram[
// hi_start_local:hi_end_local,
// wi_start_local:wi_end_local, ...]
// = input[n,
// hi_start_global:hi_end_global,
// wi_start_global:wi_end_global, ...]
//
int hi_start_local = 0;
int hi_start_global = blkStart.Hi;
if (blkStart.Hi < 0) {
hi_start_local = -blkStart.Hi;
hi_start_global = 0;
}
int wi_start_local = 0;
int wi_start_global = blkStart.Wi;
if (blkStart.Wi < 0) {
wi_start_local = -blkStart.Wi;
wi_start_global = 0;
}
int hi_end_local = blkSize.Hi - 1;
int hi_end_global = blkEnd.Hi;
if (blkEnd.Hi > param.Hi - 1) {
hi_end_global = param.Hi - 1;
hi_end_local -= blkEnd.Hi - hi_end_global;
}
int wi_end_local = blkSize.Wi - 1;
int wi_end_global = blkEnd.Wi;
if (blkEnd.Wi > param.Wi - 1) {
wi_end_global = param.Wi - 1;
wi_end_local -= blkEnd.Wi - wi_end_global;
}
int dst_offset = param.input_nram_stride_h * hi_start_local +
param.input_nram_stride_w * wi_start_local;
T *dst = input_nram + dst_offset;
int src_offset = INDEX3(sample_idx, hi_start_global, wi_start_global,
blkStart.C, param.input_stride_n,
param.input_stride_h, param.input_stride_w);
T *src = input + src_offset;
int input_seg_num_h = hi_end_local - hi_start_local + 1;
int input_seg_num_w = wi_end_local - wi_start_local + 1;
for (int i = 0; i < input_seg_num_h; ++i) {
loadStr3D(dst, src, blkSize.Cg, blkSize.G, input_seg_num_w,
param.input_nram_stride_g, param.input_nram_stride_w,
param.input_stride_g, param.input_stride_w);
dst += param.input_nram_stride_h;
src += param.input_stride_h;
}
/* load mask block from gdram2nram
*
* mask_nram[ | mask[sample_idx,
* 0:blkSize.Ho-1 , | blkStart.Ho + 0:blkSize.Ho-1,
* 0:blkSize.Wo-1, | blkStart.Wo + 0:blkSize.Wo-1,
* 0:blkSize.G-1, | blkStart.G + 0:blkSize.G-1,
* 0:blkSize.Kh-1, | blkStart.Kh + 0:blkSize.Kh-1,
* 0:blkSize.Kw-1] | blkStart.Kw + 0:blkSize.Kw-1]
*/
src_offset = INDEX3(blkStart.Wo, blkStart.G, blkStart.Kh, blkStart.Kw,
param.mask_stride_w, param.mask_stride_g,
param.mask_stride_kh);
src_offset += sample_idx * param.mask_stride_n +
blkStart.Ho * param.mask_stride_h;
for (int ho = 0; ho < blkSize.Ho; ++ho) {
dst = mask_nram + ho * param.mask_nram_stride_h;
src = mask + src_offset + ho * param.mask_stride_h;
for (int wo = 0; wo < blkSize.Wo; ++wo) {
loadStr3D(dst, src, blkSize.Kw, blkSize.Kh, blkSize.G,
param.mask_nram_stride_kh, param.mask_nram_stride_g,
param.mask_stride_kh, param.mask_stride_g);
dst += param.mask_nram_stride_w;
src += param.mask_stride_w;
}
}
// loop each pixel of the output block
for (int ho = 0; ho < blkSize.Ho; ++ho) {
int kernel_hi_start_global = (blkStart.Ho + ho) / param.scale_factor -
param.kernel_size_half + blkStart.Kh;
int kernel_hi_start_local = kernel_hi_start_global - blkStart.Hi;
// int kernel_hi_end_global = kernel_hi_start_global + blkSize.Kh - 1;
// int kernel_hi_end_local = kernel_hi_end_global - blkStart.Hi;
// exclude out of bound indices which should be ignored
int kh_min = hi_start_local - kernel_hi_start_local > 0
? hi_start_local - kernel_hi_start_local
: 0;
int kh_max = hi_end_local - kernel_hi_start_local < blkSize.Kh - 1
? hi_end_local - kernel_hi_start_local
: blkSize.Kh - 1;
for (int wo = 0; wo < blkSize.Wo; ++wo) {
int kernel_wi_start_global =
(blkStart.Wo + wo) / param.scale_factor -
param.kernel_size_half + blkStart.Kw;
int kernel_wi_start_local = kernel_wi_start_global - blkStart.Wi;
// exclude out of bound indices wwich should be ignored
int kw_min = wi_start_local - kernel_wi_start_local > 0
? wi_start_local - kernel_wi_start_local
: 0;
int kw_max = wi_end_local - kernel_wi_start_local < blkSize.Kw - 1
? wi_end_local - kernel_wi_start_local
: blkSize.Kw - 1;
// output_nram[ho, wo, g, c] = sum(mask_nram[ho, wo, g, kh, kw]
// * input_nram[hi+kh, wi+kw, g, c],
// for (kh,kw) in [0:blkSize.Kw-1] x [0:blkSize.Kh-1])
//
// sum(mask_nram[ho, wo, g, kh, kw]
// * input_nram[hi+kh, wi+kw, g, c], (kh,kw))
//
T *mask_array = mask_nram + param.mask_nram_stride_h * ho +
param.mask_nram_stride_w * wo;
for (int kh = kh_min; kh <= kh_max; ++kh) {
for (int kw = kw_min; kw <= kw_max; ++kw) {
T *src =
input_nram +
param.input_nram_stride_h * (kernel_hi_start_local + kh) +
param.input_nram_stride_w * (kernel_wi_start_local + kw);
int mask_index = param.mask_nram_stride_kh * kh + kw;
// mlutiply mask weight with channels for each channel group
T *sum = sum_array;
for (int g = 0; g < blkSize.G; ++g) {
__bang_mul_scalar(sum, src, mask_array[mask_index],
param.block_Cg_NFU);
//
// NOTE: Since block_Cg_NFU >= block_Cg_stride,
// overlapped writing may occur on sum_array.
// So this loop must be executed in order to
// avoid data contamination, as shown below.
//
// |-----block_Cg_NFU---------|
// xxxxxxxxxxxxxxxxxxxxyyyzzzzz------------
// |---block_Cg_stride---|^^^^^will be overwritten
// in the next iteration.
//
// x: actual data used, y: not used, z: overwritten
//
sum += param.input_nram_stride_g;
src += param.input_nram_stride_g;
mask_index += param.mask_nram_stride_g;
} // loop blk_G
// add array[blk_G * blk_C] to output_nram
dst = output_nram + param.output_nram_stride_h * ho +
param.output_nram_stride_w * wo;
__bang_add(dst, dst, sum_array, param.output_nram_stride_w);
} // end loop blk_Kw
} // end loop blk_Kh
} // end loop blk_Wo
} // end loop blk_Ho
} // end loop grid_dim.Kw
} // end loop grid_dim.Kh
/* write output from nram2gdram
*
* output_nram[ | output[sample_idx,
* 0:blkSize.Ho-1, | blkStart.Ho + 0:blkSize.Ho-1,
* 0:blkSize.Wo-1, | blkStart.Wo + 0:blkSize.Wo-1,
* 0:blkSize.G-1, | blkStart.G + 0:blkSize.G-1,
* 0:blkSize.Cg-1] | blkStart.Cg + 0:blkSize.Cg-1]
*/
int dst_offset = INDEX3(sample_idx, blkStart.Ho, blkStart.Wo, blkStart.C,
param.output_stride_n, param.output_stride_h,
param.output_stride_w);
T *dst = output + dst_offset;
T *src = output_nram;
for (int i = 0; i < blkSize.Ho; ++i) {
storeStr3D(dst, src, blkSize.Cg, blkSize.G, blkSize.Wo,
param.output_stride_g, param.output_stride_w,
param.output_nram_stride_g, param.output_nram_stride_w);
dst += param.output_stride_h;
src += param.output_nram_stride_h;
}
} // end loop N, grid_dim.(Hi,Wi,G,Cg)
}
template <typename T>
__mlu_global__ void MLUBLOCKKernelCarafeForward(
const void *input, const void *mask, const CarafeForwardParam param,
const CarafeForwardBlockDim block_dim, const CarafeForwardGridDim grid_dim,
void *output) {
carafeForwardBLOCK((T *)input, (T *)mask, param, block_dim, grid_dim,
(T *)output);
}
} // namespace forward
namespace backward {
template <typename T>
__mlu_func__ void CarafeCompute(T *input, T *mask, T *grad_output,
T *grad_input, T *grad_mask, const int n,
const int hi, const int wi, const int c,
const int k_up, const int group,
const int scale) {
char *input_buff = nram_buf;
char *mask_buff = input_buff + NRAM_BLOCK;
char *grad_input_buff = mask_buff + NRAM_BLOCK;
char *grad_output_buff = grad_input_buff + NRAM_BLOCK;
char *grad_mask_buff = grad_output_buff + NRAM_BLOCK;
int wo = wi * scale;
int ho = hi * scale;
int out_num = n * ho * wo * group;
int group_size = c / group;
int repeat = out_num / taskDim + (int)(taskId < out_num % taskDim);
int num_align = PAD_DOWN(NRAM_BLOCK / sizeof(T), NFU_ALIGN_SIZE / sizeof(T));
int num_per_loop = group_size / num_align;
int rem_for_loop = group_size % num_align;
int rem_for_loop_align = PAD_UP(rem_for_loop, NFU_ALIGN_SIZE / sizeof(T));
for (int k = 0; k < repeat; k++) {
int iter = k * taskDim + taskId;
int group_k = iter % group;
int w_k = (iter / group) % wo;
int h_k = (iter / wo / group) % ho;
int n_k = (iter / ho / wo / group) % n;
int h_i = h_k / scale;
int w_i = w_k / scale;
int start_h = h_i - ((k_up - 1) / 2);
int end_h = h_i + ((k_up - 1) / 2) + 1;
int start_w = w_i - ((k_up - 1) / 2);
int end_w = w_i + ((k_up - 1) / 2) + 1;
T *base_mask = (T *)mask + n_k * ho * wo * group * k_up * k_up +
h_k * wo * group * k_up * k_up + w_k * group * k_up * k_up +
group_k * k_up * k_up;
T *base_grad_mask = (T *)grad_mask + n_k * ho * wo * group * k_up * k_up +
h_k * wo * group * k_up * k_up +
w_k * group * k_up * k_up + group_k * k_up * k_up;
__bang_write_zero((T *)grad_input_buff, NRAM_BLOCK / sizeof(T));
__bang_write_zero((T *)grad_mask_buff, NRAM_BLOCK / sizeof(T));
__bang_write_zero((T *)grad_output_buff, NRAM_BLOCK / sizeof(T));
__memcpy((T *)mask_buff, (T *)base_mask, k_up * k_up * sizeof(T),
GDRAM2NRAM);
for (int i = 0; i < num_per_loop; i++) {
__bang_write_zero((T *)input_buff, NRAM_BLOCK / sizeof(T));
T *base_grad_output = (T *)grad_output + n_k * ho * wo * c +
h_k * wo * c + w_k * c + group_k * group_size +
i * num_align;
__memcpy((T *)grad_output_buff, (T *)base_grad_output,
num_align * sizeof(T), GDRAM2NRAM);
for (int ih = start_h; ih < end_h; ih++) {
for (int iw = start_w; iw < end_w; iw++) {
if (ih < 0 || ih > hi - 1 || iw < 0 || iw > wi - 1) {
continue;
}
int mask_ih = ih - h_i + (k_up - 1) / 2;
int mask_iw = iw - w_i + (k_up - 1) / 2;
int mask_index = mask_ih * k_up + mask_iw;
int input_index = n_k * hi * wi * c + ih * wi * c + iw * c +
group_k * group_size + i * num_align;
T *base_input = (T *)input + input_index;
T *base_grad_input = (T *)grad_input + input_index;
__memcpy((T *)input_buff, (T *)base_input, num_align * sizeof(T),
GDRAM2NRAM);
__bang_mul_scalar((T *)grad_input_buff, (T *)grad_output_buff,
((T *)mask_buff)[mask_index], num_align);
__bang_atomic_add((T *)grad_input_buff, (T *)base_grad_input,
(T *)grad_input_buff, num_align);
__bang_mul((T *)input_buff, (T *)grad_output_buff, (T *)input_buff,
num_align);
__bang_sumpool((T *)input_buff, (T *)input_buff,
NFU_ALIGN_SIZE / sizeof(T),
num_align / (NFU_ALIGN_SIZE / sizeof(T)), 1,
num_align / (NFU_ALIGN_SIZE / sizeof(T)), 1, 1, 1);
__bang_reduce_sum((T *)input_buff, (T *)input_buff,
NFU_ALIGN_SIZE / sizeof(T));
((T *)grad_mask_buff)[mask_index] += ((T *)input_buff)[0];
}
}
}
if (rem_for_loop) {
__bang_write_zero((T *)input_buff, NRAM_BLOCK / sizeof(T));
T *base_grad_output = (T *)grad_output + n_k * ho * wo * c +
h_k * wo * c + w_k * c + group_k * group_size +
num_per_loop * num_align;
__memcpy((T *)grad_output_buff, (T *)base_grad_output,
rem_for_loop * sizeof(T), GDRAM2NRAM);
for (int ih = start_h; ih < end_h; ih++) {
for (int iw = start_w; iw < end_w; iw++) {
if (ih < 0 || ih > hi - 1 || iw < 0 || iw > wi - 1) {
continue;
}
int mask_ih = ih - h_i + (k_up - 1) / 2;
int mask_iw = iw - w_i + (k_up - 1) / 2;
int mask_index = mask_ih * k_up + mask_iw;
int input_index = n_k * hi * wi * c + ih * wi * c + iw * c +
group_k * group_size + num_per_loop * num_align;
T *base_input = (T *)input + input_index;
T *base_grad_input = (T *)grad_input + input_index;
__memcpy((T *)input_buff, (T *)base_input, rem_for_loop * sizeof(T),
GDRAM2NRAM);
__bang_mul_scalar((T *)grad_input_buff, (T *)grad_output_buff,
((T *)mask_buff)[mask_index], rem_for_loop_align);
__bang_atomic_add((T *)grad_input_buff, (T *)base_grad_input,
(T *)grad_input_buff, rem_for_loop);
__bang_mul((T *)input_buff, (T *)grad_output_buff, (T *)input_buff,
rem_for_loop_align);
__bang_sumpool(
(T *)input_buff, (T *)input_buff, NFU_ALIGN_SIZE / sizeof(T),
rem_for_loop_align / (NFU_ALIGN_SIZE / sizeof(T)), 1,
rem_for_loop_align / (NFU_ALIGN_SIZE / sizeof(T)), 1, 1, 1);
__bang_reduce_sum((T *)input_buff, (T *)input_buff,
NFU_ALIGN_SIZE / sizeof(T));
((T *)grad_mask_buff)[mask_index] += ((T *)input_buff)[0];
}
}
}
__memcpy((T *)base_grad_mask, (T *)grad_mask_buff, k_up * k_up * sizeof(T),
NRAM2GDRAM);
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelCarafeBackward(
const void *input, const void *mask, const void *grad_output,
void *grad_input, void *grad_mask, const int n, const int hi, const int wi,
const int c, const int k_up, const int group, const int scale) {
CarafeCompute((T *)input, (T *)mask, (T *)grad_output, (T *)grad_input,
(T *)grad_mask, n, hi, wi, c, k_up, group, scale);
}
} // namespace backward
void KernelCarafeForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const cnrtDataType_t d_type,
const void *input, const void *mask,
const CarafeForwardParam &param,
const CarafeForwardBlockDim &block_dim,
const CarafeForwardGridDim &grid_dim, void *output) {
if (d_type == CNRT_FLOAT16) {
forward::MLUBLOCKKernelCarafeForward<half><<<k_dim, k_type, queue>>>(
input, mask, param, block_dim, grid_dim, output);
} else {
forward::MLUBLOCKKernelCarafeForward<float><<<k_dim, k_type, queue>>>(
input, mask, param, block_dim, grid_dim, output);
}
}
void KernelCarafeBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t dtype,
const void *input, const void *mask,
const void *grad_output, void *grad_input,
void *grad_mask, const int n, const int hi,
const int wi, const int c, const int k_up,
const int group, const int scale) {
if (dtype == CNRT_FLOAT16) {
backward::MLUUnion1KernelCarafeBackward<half><<<k_dim, k_type, queue>>>(
input, mask, grad_output, grad_input, grad_mask, n, hi, wi, c, k_up,
group, scale);
} else {
backward::MLUUnion1KernelCarafeBackward<float><<<k_dim, k_type, queue>>>(
input, mask, grad_output, grad_input, grad_mask, n, hi, wi, c, k_up,
group, scale);
}
}
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CARAFE_UTILS_HPP_
#define CARAFE_UTILS_HPP_
#define NRAM_ALIGN_SIZE 64
struct CarafeForwardParam {
int N; // batch size
int Hi; // input height
int Wi; // input width
int Ci; // input channels
int Ho; // output height
int Wo; // output width
int Cg; // channels per group
int kernel_size; // kernel_size
int group_size; // group_size
int scale_factor; // scale_factor
int kernel_size_half; // kernel half size (K-1)/2
int kernel_size_sq; // square of kernel size
int dtype_size; // size of tensor data type
// Host arrays' geometry
int input_stride_g;
int input_stride_w;
int input_stride_h;
int input_stride_n;
int input_size;
int mask_stride_kh;
int mask_stride_g;
int mask_stride_w;
int mask_stride_h;
int mask_stride_n;
int mask_size;
int output_stride_g;
int output_stride_w;
int output_stride_h;
int output_stride_n;
int output_size;
// NRAM arrays' geometry
int input_nram_stride_g;
int input_nram_stride_w;
int input_nram_stride_h;
int input_nram_size;
int mask_nram_stride_kh;
int mask_nram_stride_g;
int mask_nram_stride_w;
int mask_nram_stride_h;
int mask_nram_size;
int output_nram_stride_g;
int output_nram_stride_w;
int output_nram_stride_h;
int output_nram_size;
// for address/compute alignment
int align_size_NRAM; // for addressing on NRAM
int align_size_NFU; // for NFU operation length
int block_Cg_NFU; // for bang_mul_const
int job_num; // total job number
};
struct CarafeForwardBlockDim {
int Ho; // block size of output height
int Wo; // block size of output width
int Kh; // block size of kernel height
int Kw; // block size of kernel width
int G; // block size of groups
int Cg; // block size of channels within a group
int Hi; // block size of input height
int Wi; // block size of input width
};
struct CarafeForwardGridDim {
int Ho; // number of blocks of output height
int Wo;
int Kh;
int Kw;
int G;
int Cg;
};
#endif // CARAFE_UTILS_HPP_
...@@ -45,148 +45,6 @@ __mlu_func__ inline scalar_t max(scalar_t a, scalar_t b) { ...@@ -45,148 +45,6 @@ __mlu_func__ inline scalar_t max(scalar_t a, scalar_t b) {
return a > b ? a : b; return a > b ? a : b;
} }
/*!
* @brief loads data from global DRAM to NRAM with 2D pattern.
*
* @param[out] dst
* Pointer to NRAM that stores dst data.
* @param[in] src
* Pointer to global DRAM that stores src data.
* @param[in] size
* The byte size of segment in the lower dimension.
* @param[in] dst_str
* The data stride in bytes between segments in the lower dimension of dst.
* @param[in] src_str
* The data stride in bytes between segments in the lower dimension of src.
* @param[in] seg_num
* The total count of data segments in the lower dimension.
*/
template <typename T>
__mlu_func__ void loadStr2D(T *dst, T *src, const int size, const int dst_str,
const int src_str, const int seg_num) {
if (dst_str == src_str && size == src_str) {
__memcpy(dst, src, src_str * seg_num * sizeof(T), GDRAM2NRAM);
} else if ((size == src_str || src_str <= dst_str) &&
src_str * sizeof(T) <= 512) {
// gather data less than 512Bytes to improve IO efficiency
T *tmp = (T *)dst + (dst_str - src_str) * seg_num;
__memcpy(tmp, src, (src_str * (seg_num - 1) + size) * sizeof(T),
GDRAM2NRAM);
if (dst_str != src_str) {
__memcpy(dst, tmp, size * sizeof(T), NRAM2NRAM, dst_str * sizeof(T),
src_str * sizeof(T), seg_num - 1);
}
} else {
__memcpy(dst, src, size * sizeof(T), GDRAM2NRAM, dst_str * sizeof(T),
src_str * sizeof(T), seg_num - 1);
}
}
/*!
* @brief loads data from global DRAM to NRAM with 3D pattern.
*
* @param[out] dst
* Pointer to NRAM that stores dst data.
* @param[in] src
* Pointer to global DRAM that stores src data.
* @param[in] size
* The byte size of segment in the lowest dimension.
* @param[in] seg_num_in
* The total count of data segments in the lowest dimension.
* @param[in] seg_num_out
* The total count of data segments in the middle dimension.
* @param[in] dst_str_in
* The data stride in bytes between segments in the lowest dimension of dst.
* @param[in] dst_str_out
* The data stride in bytes between segments in the middle dimension of dst.
* @param[in] src_str_in
* The data stride in bytes between segments in the lowest dimension of src.
* @param[in] src_str_out
* The data stride in bytes between segments in the middle dimension of src.
*/
template <typename T>
__mlu_func__ void loadStr3D(T *dst, T *src, const int size,
const int seg_num_in, const int seg_num_out,
const int dst_str_in, const int dst_str_out,
const int src_str_in, const int src_str_out) {
T *tmp_dst = dst;
T *tmp_src = src;
for (int i = 0; i < seg_num_out; ++i) {
loadStr2D(tmp_dst, tmp_src, size, dst_str_in, src_str_in, seg_num_in);
tmp_src += src_str_out;
tmp_dst += dst_str_out;
}
}
/*!
* @brief stores data from NRAM to global DRAM with 2D pattern.
*
* @param[out] dst
* Pointer to global DRAM that stores dst data.
* @param[in] src
* Pointer to NRAM that stores src data.
* @param[in] size
* The byte size of segment in the lower dimension.
* @param[in] dst_str
* The data stride in bytes between segments in the lower dimension of dst.
* @param[in] src_str
* The data stride in bytes between segments in the lower dimension of src.
* @param[in] seg_num
* The total count of data segments in the lower dimension.
*/
template <typename T>
__mlu_func__ void storeStr2D(T *dst, T *src, const int size, const int seg_num,
const int dst_str, const int src_str) {
if ((size == dst_str && dst_str <= src_str) && dst_str * sizeof(T) <= 512) {
// gather data less than 512Bytes to improve IO efficiency
if (dst_str != src_str) {
__memcpy(src, src, size * sizeof(T), NRAM2NRAM, dst_str * sizeof(T),
src_str * sizeof(T), seg_num - 1);
}
__memcpy(dst, src, size * seg_num * sizeof(T), NRAM2GDRAM);
} else {
__memcpy(dst, src, size * sizeof(T), NRAM2GDRAM, dst_str * sizeof(T),
src_str * sizeof(T), seg_num - 1);
}
}
/*!
* @brief stores data from NRAM to global DRAM with 3D pattern.
*
* @param[out] dst
* Pointer to global DRAM that stores dst data.
* @param[in] src
* Pointer to NRAM that stores src data.
* @param[in] size
* The byte size of segment in the lowest dimension.
* @param[in] seg_num_in
* The total count of data segments in the lowest dimension.
* @param[in] seg_num_out
* The total count of data segments in the middle dimension.
* @param[in] dst_str_in
* The data stride in bytes between segments in the lowest dimension of dst.
* @param[in] dst_str_out
* The data stride in bytes between segments in the middle dimension of dst.
* @param[in] src_str_in
* The data stride in bytes between segments in the lowest dimension of src.
* @param[in] src_str_out
* The data stride in bytes between segments in the middle dimension of src.
*/
template <typename T>
__mlu_func__ void storeStr3D(T *dst, T *src, const int size,
const int seg_num_in, const int seg_num_out,
const int dst_str_in, const int dst_str_out,
const int src_str_in, const int src_str_out) {
T *tmp_dst = dst;
T *tmp_src = src;
for (int i = 0; i < seg_num_out; ++i) {
storeStr2D(tmp_dst, tmp_src, size, seg_num_in, dst_str_in, src_str_in);
tmp_src += src_str_out;
tmp_dst += dst_str_out;
}
}
/*! /*!
* @brief Converts int32 to float32 data type. * @brief Converts int32 to float32 data type.
* *
......
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <iostream>
#include "common_mlu_helper.hpp"
#define ROI_OFFSET 5
#define FOURSPLIT 4
#define FIVESPLIT 5
#define NINESPLIT 9
#define THIRTEENSPLIT 13
__nram__ char nram_buffer[MAX_NRAM_SIZE];
template <typename T>
static __mlu_func__ void bilinearInterpolate(const int input_width, T y, T x,
T *w1, T *w2, T *w3, T *w4,
int *x_low, int *x_high,
const int y_low, bool *is_empty) {
if (x < -1.0 || x > input_width) {
*is_empty = true;
return;
}
if (x <= 0) x = 0;
*x_low = int(x);
if (*x_low >= input_width - 1) {
*x_high = *x_low = input_width - 1;
x = T(*x_low);
} else {
*x_high = *x_low + 1;
}
T ly = y - y_low;
T lx = x - *x_low;
T hy = 1.0 - ly;
T hx = 1.0 - lx;
*w1 = hy * hx;
*w2 = hy * lx;
*w3 = ly * hx;
*w4 = ly * lx;
}
template <typename T>
__mlu_func__ void MLUUnion1DeformRoIPoolForward(
const T *input, const T *rois, const T *offset, T *output,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const T spatial_scale,
const int sampling_ratio, const T gamma) {
for (int bin_index = taskId;
bin_index < num_rois * pooled_width * pooled_height;
bin_index += taskDim) {
int out_batch = bin_index / pooled_width / pooled_height;
int out_height = bin_index / pooled_width % pooled_height;
int out_width = bin_index % pooled_width;
const T *cur_roi = rois + out_batch * ROI_OFFSET;
T *nram_rois = (T *)nram_buffer;
__memcpy((void *)nram_rois, (void *)cur_roi, ROI_OFFSET * sizeof(T),
GDRAM2NRAM);
const int roi_batch = nram_rois[0];
T roi_x_min = nram_rois[1] * spatial_scale - 0.5;
T roi_y_min = nram_rois[2] * spatial_scale - 0.5;
const T roi_x_max = nram_rois[3] * spatial_scale - 0.5;
const T roi_y_max = nram_rois[4] * spatial_scale - 0.5;
const T roi_width = roi_x_max - roi_x_min;
const T roi_height = roi_y_max - roi_y_min;
const T bin_width = roi_width / static_cast<T>(pooled_width);
const T bin_height = roi_height / static_cast<T>(pooled_height);
const T *offset_input = input + roi_batch * height * width * channels;
int roi_bin_grid_height =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_height / pooled_height));
int roi_bin_grid_width =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_width / pooled_width));
if (offset != NULL) {
const T *offset_cur = offset +
out_batch * pooled_width * pooled_height * 2 +
out_height * pooled_width + out_width;
roi_x_min += gamma * roi_width * offset_cur[0];
roi_y_min +=
gamma * roi_height * offset_cur[pooled_width * pooled_height];
}
int type_align = NFU_ALIGN_SIZE / sizeof(T);
int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T);
int channels_nram_split =
channels_max_num_nram / NINESPLIT / type_align * type_align;
int channel_rem = channels % channels_nram_split;
int channel_loops =
channels / channels_nram_split + (channel_rem != 0 ? 1 : 0);
for (int channel_loop_index = 0; channel_loop_index < channel_loops;
++channel_loop_index) {
int channels_num =
channels_nram_split >= channels ? channels : channels_nram_split;
const int channel_offset = channel_loop_index * channels_num;
if (channel_loop_index + 1 == channel_loops && channel_rem != 0) {
channels_num = channel_rem;
}
int channels_align = CEIL_ALIGN(channels_num, type_align);
int nram_limit = (MAX_NRAM_SIZE / sizeof(T) - channels_align) >> 1;
int c_slice = nram_limit / FOURSPLIT / type_align * type_align;
int c_slice_align = 0;
/* NRAM partition
*
* | | ping | pong |
* |----------|-------------------|-------------------|
* | nram_out | p1 | p2 | p3 | p4 | p1 | p2 | p3 | p4 |
*
*/
T *nram_out = (T *)nram_buffer;
T *nram_ping = nram_out + channels_align;
T *nram_pong = nram_ping + nram_limit;
__bang_write_value((T *)nram_out, channels_align, (T)0);
__bang_write_value((T *)nram_ping, FOURSPLIT * c_slice, (T)0);
__bang_write_value((T *)nram_pong, FOURSPLIT * c_slice, (T)0);
const T num_bins =
static_cast<T>(max(roi_bin_grid_height * roi_bin_grid_width, 1));
const T value_div = 1.0f / num_bins;
bool is_ping_empty = true;
for (int iy = 0; iy < roi_bin_grid_height; ++iy) {
T y = roi_y_min + out_height * bin_height +
static_cast<T>(iy + .5f) * bin_height /
static_cast<T>(roi_bin_grid_height);
if (y < -1.0 || y > height) {
is_ping_empty = true;
continue;
}
if (y <= 0) {
y = 0;
}
int y_low = 0, y_high = 0;
y_low = int(y);
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = T(y_low);
} else {
y_high = y_low + 1;
}
for (int ix = 0; ix < roi_bin_grid_width; ++ix) {
T x = roi_x_min + out_width * bin_width +
static_cast<T>(ix + .5f) * bin_width /
static_cast<T>(roi_bin_grid_width);
const int sample_index = iy * roi_bin_grid_width + ix;
int c_rem = channels_num;
c_slice = nram_limit / FOURSPLIT / type_align * type_align;
c_slice_align = 0;
bool is_empty = false;
T w1, w2, w3, w4;
int x_low = 0, x_high = 0;
bilinearInterpolate(width, y, x, &w1, &w2, &w3, &w4, &x_low, &x_high,
y_low, &is_empty);
if (is_empty) {
is_ping_empty = true;
continue;
}
if (is_ping_empty) {
c_slice = c_slice > c_rem ? c_rem : c_slice;
c_slice_align = CEIL_ALIGN(c_slice, type_align);
__bang_write_value(nram_ping, FOURSPLIT * c_slice_align, (T)0);
__asm__ volatile("sync;");
__memcpy(nram_ping,
offset_input + y_low * width * channels +
x_low * channels + channel_offset,
c_slice * sizeof(T), GDRAM2NRAM);
__memcpy(nram_ping + c_slice_align,
offset_input + y_low * width * channels +
x_high * channels + channel_offset,
c_slice * sizeof(T), GDRAM2NRAM);
__memcpy(nram_ping + 2 * c_slice_align,
offset_input + y_high * width * channels +
x_low * channels + channel_offset,
c_slice * sizeof(T), GDRAM2NRAM);
__memcpy(nram_ping + 3 * c_slice_align,
offset_input + y_high * width * channels +
x_high * channels + channel_offset,
c_slice * sizeof(T), GDRAM2NRAM);
is_ping_empty = false;
}
int c_offset = 0;
int pongc_slice = 0;
int pongc_slice_align = 0;
while (c_rem > 0) {
c_slice = c_slice > c_rem ? c_rem : c_slice;
c_slice_align = CEIL_ALIGN(c_slice, type_align);
if (sample_index + 1 < roi_bin_grid_height * roi_bin_grid_width) {
int iy_tmp = (sample_index + 1) / roi_bin_grid_width;
int ix_tmp = (sample_index + 1) % roi_bin_grid_width;
y = roi_y_min + out_height * bin_height +
static_cast<T>(iy_tmp + .5f) * bin_height /
static_cast<T>(roi_bin_grid_height);
x = roi_x_min + out_width * bin_width +
static_cast<T>(ix_tmp + .5f) * bin_width /
static_cast<T>(roi_bin_grid_width);
if (y < -1.0 || y > height) {
is_empty = true;
} else {
T w1_tmp, w2_tmp, w3_tmp, w4_tmp;
if (y <= 0) {
y = 0;
}
y_low = int(y);
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = T(y_low);
} else {
y_high = y_low + 1;
}
bilinearInterpolate(width, y, x, &w1_tmp, &w2_tmp, &w3_tmp,
&w4_tmp, &x_low, &x_high, y_low, &is_empty);
}
pongc_slice = nram_limit / FOURSPLIT / type_align * type_align;
pongc_slice =
pongc_slice > channels_num ? channels_num : pongc_slice;
pongc_slice_align = CEIL_ALIGN(pongc_slice, type_align);
__bang_write_value(nram_pong, FOURSPLIT * pongc_slice_align,
(T)0);
__asm__ volatile("sync;");
if (!is_empty) {
__memcpy_async(nram_pong,
offset_input + y_low * width * channels +
x_low * channels + channel_offset,
pongc_slice * sizeof(T), GDRAM2NRAM);
__memcpy_async(nram_pong + pongc_slice_align,
offset_input + y_low * width * channels +
x_high * channels + channel_offset,
pongc_slice * sizeof(T), GDRAM2NRAM);
__memcpy_async(nram_pong + 2 * pongc_slice_align,
offset_input + y_high * width * channels +
x_low * channels + channel_offset,
pongc_slice * sizeof(T), GDRAM2NRAM);
__memcpy_async(nram_pong + 3 * pongc_slice_align,
offset_input + y_high * width * channels +
x_high * channels + channel_offset,
pongc_slice * sizeof(T), GDRAM2NRAM);
}
}
__bang_mul_scalar(nram_ping, nram_ping, w1, c_slice_align);
__bang_mul_scalar(nram_ping + c_slice_align,
nram_ping + c_slice_align, w2, c_slice_align);
__bang_add(nram_ping, nram_ping, nram_ping + c_slice_align,
c_slice_align);
__bang_mul_scalar(nram_ping + 2 * c_slice_align,
nram_ping + 2 * c_slice_align, w3, c_slice_align);
__bang_add(nram_ping, nram_ping, nram_ping + 2 * c_slice_align,
c_slice_align);
__bang_mul_scalar(nram_ping + 3 * c_slice_align,
nram_ping + 3 * c_slice_align, w4, c_slice_align);
__bang_add(nram_ping, nram_ping, nram_ping + 3 * c_slice_align,
c_slice_align);
__bang_add(nram_out + c_offset, nram_out + c_offset, nram_ping,
c_slice_align);
T *nram_tmp = nram_ping;
nram_ping = nram_pong;
nram_pong = nram_tmp;
c_rem -= c_slice;
c_offset += c_slice;
__asm__ volatile("sync;");
}
}
}
__bang_mul_scalar(nram_out, nram_out, value_div, channels_align);
__memcpy(output + channels * bin_index + channel_offset, nram_out,
channels_num * sizeof(T), NRAM2GDRAM);
}
}
}
__mlu_global__ void MLUKernelDeformRoIPoolForward(
cnrtDataType_t data_type, const void *input, const void *rois,
const void *offset, void *output, const int channels, const int height,
const int width, const int num_rois, const int pooled_height,
const int pooled_width, const float spatial_scale, const int sampling_ratio,
const float gamma) {
switch (data_type) {
case CNRT_FLOAT16: {
MLUUnion1DeformRoIPoolForward((half *)input, (half *)rois, (half *)offset,
(half *)output, channels, height, width,
num_rois, pooled_height, pooled_width,
static_cast<half>(spatial_scale),
sampling_ratio, static_cast<half>(gamma));
}; break;
case CNRT_FLOAT32: {
MLUUnion1DeformRoIPoolForward(
(float *)input, (float *)rois, (float *)offset, (float *)output,
channels, height, width, num_rois, pooled_height, pooled_width,
static_cast<float>(spatial_scale), sampling_ratio,
static_cast<float>(gamma));
}; break;
default: {
break;
}
}
}
void KernelDeformRoIPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t data_type,
const void *input, const void *rois,
const void *offset, void *output,
const int channels, const int height,
const int width, const int num_rois,
const int pooled_height, const int pooled_width,
const float spatial_scale,
const int sampling_ratio, const float gamma) {
MLUKernelDeformRoIPoolForward<<<k_dim, k_type, queue>>>(
data_type, input, rois, offset, output, channels, height, width, num_rois,
pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma);
}
template <typename T>
__mlu_func__ void MLUUnion1DeformRoIPoolBackward(
const T *grad_output, const T *input, const T *rois, const T *offset,
T *grad_input, T *grad_offset, const int channels, const int height,
const int width, const int num_rois, const int pooled_height,
const int pooled_width, const T spatial_scale, const int sampling_ratio,
const T gamma) {
for (int bin_index = taskId;
bin_index < num_rois * pooled_width * pooled_height;
bin_index += taskDim) {
int out_batch = bin_index / pooled_width / pooled_height;
int out_height = bin_index / pooled_width % pooled_height;
int out_width = bin_index % pooled_width;
const T *cur_roi = rois + out_batch * ROI_OFFSET;
T *nram_rois = (T *)nram_buffer;
__memcpy((void *)nram_rois, (void *)cur_roi, ROI_OFFSET * sizeof(T),
GDRAM2NRAM);
const int roi_batch = nram_rois[0];
T roi_x_min = nram_rois[1] * spatial_scale - 0.5;
T roi_y_min = nram_rois[2] * spatial_scale - 0.5;
const T roi_x_max = nram_rois[3] * spatial_scale - 0.5;
const T roi_y_max = nram_rois[4] * spatial_scale - 0.5;
const T roi_width = roi_x_max - roi_x_min;
const T roi_height = roi_y_max - roi_y_min;
const T bin_width = roi_width / static_cast<T>(pooled_width);
const T bin_height = roi_height / static_cast<T>(pooled_height);
const T *offset_input = input + roi_batch * height * width * channels;
T *offset_grad_input = grad_input + roi_batch * height * width * channels;
int roi_bin_grid_height =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_height / pooled_height));
int roi_bin_grid_width =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_width / pooled_width));
if (offset != NULL) {
const T *offset_cur = offset +
out_batch * pooled_width * pooled_height * 2 +
out_height * pooled_width + out_width;
roi_x_min += gamma * roi_width * offset_cur[0];
roi_y_min +=
gamma * roi_height * offset_cur[pooled_width * pooled_height];
}
/* NRAM partition
*
* If offset != NULL, NRAM partition belows.
* | |
* ping | pong |
* |---------------------------------------------------------------------|-----------|-----------|
* |nram_tmp1|nram_tmp2|nram_tmp3|nram_tmp4|nram_grad_output|nram_sum_tmp|p1|p2|p3|p4|p1|p2|p3|p4|
*
* If offset == NULL, ping and pang will not be needed.
* | |
* |----------------------------------------------------------------------------------|
* | nram_tmp1 | nram_tmp2 | nram_tmp3 | nram_tmp4 | nram_grad_output |
*
*/
int type_align = NFU_ALIGN_SIZE / sizeof(T);
int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T);
int channels_nram_split =
channels_max_num_nram / FIVESPLIT / type_align * type_align;
int channel_rem = channels % channels_nram_split;
int channel_loops =
channels / channels_nram_split + (channel_rem != 0 ? 1 : 0);
if (offset != NULL) {
channels_nram_split =
channels_max_num_nram / THIRTEENSPLIT / type_align * type_align;
channel_rem = channels % channels_nram_split;
channel_loops =
channels / channels_nram_split + (channel_rem != 0 ? 1 : 0);
}
for (int channel_loop_index = 0; channel_loop_index < channel_loops;
++channel_loop_index) {
int channels_num =
channels_nram_split >= channels ? channels : channels_nram_split;
const int channel_offset = channel_loop_index * channels_num;
if (channel_loop_index + 1 == channel_loops && channel_rem != 0) {
channels_num = channel_rem;
}
int channels_align = CEIL_ALIGN(channels_num, type_align);
const int32_t nram_sum_tmp_channel = NFU_ALIGN_SIZE / sizeof(T);
int nram_limit = (MAX_NRAM_SIZE / sizeof(T) - 5 * channels_align -
nram_sum_tmp_channel) >>
1;
int c_slice = 0;
int c_slice_align = 0;
T *nram_tmp1 = (T *)nram_buffer;
T *nram_tmp2 = (T *)nram_buffer + channels_align;
T *nram_tmp3 = (T *)nram_buffer + 2 * channels_align;
T *nram_tmp4 = (T *)nram_buffer + 3 * channels_align;
T *nram_grad_output = nram_tmp4 + channels_align;
T *nram_sum_tmp = NULL;
T *nram_ping_input = NULL;
T *nram_pong_input = NULL;
__bang_write_value((T *)nram_grad_output, channels_align, (T)0);
__asm__ volatile("sync;");
if (offset != NULL) {
c_slice = nram_limit / FOURSPLIT / type_align * type_align;
nram_sum_tmp = nram_grad_output + channels_align;
nram_ping_input = nram_sum_tmp + nram_sum_tmp_channel;
nram_pong_input = nram_ping_input + FOURSPLIT * c_slice;
__bang_write_value((T *)nram_sum_tmp, nram_sum_tmp_channel, (T)0);
__bang_write_value((T *)nram_ping_input, FOURSPLIT * c_slice, (T)0);
__bang_write_value((T *)nram_pong_input, FOURSPLIT * c_slice, (T)0);
__asm__ volatile("sync;");
}
const T num_bins =
static_cast<T>(max(roi_bin_grid_height * roi_bin_grid_width, 1));
const T value_div = 1.0f / num_bins;
bool is_ping_empty = true;
__memcpy(nram_grad_output,
grad_output + channels * bin_index + channel_offset,
channels_num * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(nram_grad_output, nram_grad_output, value_div,
channels_align);
for (int iy = 0; iy < roi_bin_grid_height; ++iy) {
T y = roi_y_min + out_height * bin_height +
static_cast<T>(iy + .5f) * bin_height /
static_cast<T>(roi_bin_grid_height);
T y_tmp = y;
if (y_tmp < -1.0 || y_tmp > height) {
is_ping_empty = true;
continue;
}
if (y_tmp <= 0) {
y_tmp = 0;
}
int y_low = 0, y_high = 0;
y_low = int(y_tmp);
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y_tmp = T(y_low);
} else {
y_high = y_low + 1;
}
for (int ix = 0; ix < roi_bin_grid_width; ++ix) {
T x = roi_x_min + out_width * bin_width +
static_cast<T>(ix + .5f) * bin_width /
static_cast<T>(roi_bin_grid_width);
const int sample_index = iy * roi_bin_grid_width + ix;
int c_rem = channels_num;
bool is_empty = false;
T w1, w2, w3, w4;
int x_low = 0, x_high = 0;
bilinearInterpolate(width, y_tmp, x, &w1, &w2, &w3, &w4, &x_low,
&x_high, y_low, &is_empty);
if (is_empty) {
is_ping_empty = true;
continue;
}
__bang_mul_scalar((T *)nram_tmp1, (T *)nram_grad_output, w1,
channels_align);
__bang_mul_scalar((T *)nram_tmp2, (T *)nram_grad_output, w2,
channels_align);
__bang_mul_scalar((T *)nram_tmp3, (T *)nram_grad_output, w3,
channels_align);
__bang_mul_scalar((T *)nram_tmp4, (T *)nram_grad_output, w4,
channels_align);
__asm__ volatile("sync;");
__bang_atomic_add(
(T *)nram_tmp1,
(T *)(offset_grad_input + (y_low * width + x_low) * channels +
channel_offset),
(T *)nram_tmp1, channels_num);
__bang_atomic_add(
(T *)nram_tmp2,
(T *)(offset_grad_input + (y_low * width + x_high) * channels +
channel_offset),
(T *)nram_tmp2, channels_num);
__bang_atomic_add(
(T *)nram_tmp3,
(T *)(offset_grad_input + (y_high * width + x_low) * channels +
channel_offset),
(T *)nram_tmp3, channels_num);
__bang_atomic_add(
(T *)nram_tmp4,
(T *)(offset_grad_input + (y_high * width + x_high) * channels +
channel_offset),
(T *)nram_tmp4, channels_num);
if (offset != NULL) {
c_slice = nram_limit / FOURSPLIT / type_align * type_align;
c_slice_align = 0;
if (is_ping_empty) {
c_slice = c_slice > c_rem ? c_rem : c_slice;
c_slice_align = CEIL_ALIGN(c_slice, type_align);
__bang_write_value(nram_ping_input, FOURSPLIT * c_slice_align,
(T)0);
__asm__ volatile("sync;");
const T *src_offset1 = offset_input + y_low * width * channels +
x_low * channels + channel_offset;
const T *src_offset2 = offset_input + y_low * width * channels +
x_high * channels + channel_offset;
const T *src_offset3 = offset_input + y_high * width * channels +
x_low * channels + channel_offset;
const T *src_offset4 = offset_input + y_high * width * channels +
x_high * channels + channel_offset;
__memcpy(nram_ping_input, src_offset1, c_slice * sizeof(T),
GDRAM2NRAM);
__memcpy(nram_ping_input + c_slice_align, src_offset2,
c_slice * sizeof(T), GDRAM2NRAM);
__memcpy(nram_ping_input + 2 * c_slice_align, src_offset3,
c_slice * sizeof(T), GDRAM2NRAM);
__memcpy(nram_ping_input + 3 * c_slice_align, src_offset4,
c_slice * sizeof(T), GDRAM2NRAM);
is_ping_empty = false;
}
int c_offset = 0;
int pongc_slice = 0;
int pongc_slice_align = 0;
while (c_rem > 0) {
c_slice = c_slice > c_rem ? c_rem : c_slice;
c_slice_align = CEIL_ALIGN(c_slice, type_align);
if (sample_index + 1 < roi_bin_grid_height * roi_bin_grid_width) {
int iy_tmp = (sample_index + 1) / roi_bin_grid_width;
int ix_tmp = (sample_index + 1) % roi_bin_grid_width;
T y_tmp = roi_y_min + out_height * bin_height +
static_cast<T>(iy_tmp + .5f) * bin_height /
static_cast<T>(roi_bin_grid_height);
T x_tmp = roi_x_min + out_width * bin_width +
static_cast<T>(ix_tmp + .5f) * bin_width /
static_cast<T>(roi_bin_grid_width);
int x_low_tmp = 0, x_high_tmp = 0, y_low_tmp = 0,
y_high_tmp = 0;
if (y_tmp < -1.0 || y_tmp > height) {
is_empty = true;
} else {
T w1_tmp, w2_tmp, w3_tmp, w4_tmp;
if (y_tmp <= 0) {
y_tmp = 0;
}
y_low_tmp = int(y_tmp);
if (y_low_tmp >= height - 1) {
y_high_tmp = y_low_tmp = height - 1;
y_tmp = T(y_low_tmp);
} else {
y_high_tmp = y_low_tmp + 1;
}
bilinearInterpolate(width, y_tmp, x_tmp, &w1_tmp, &w2_tmp,
&w3_tmp, &w4_tmp, &x_low_tmp, &x_high_tmp,
y_low_tmp, &is_empty);
}
pongc_slice = nram_limit / FOURSPLIT / type_align * type_align;
pongc_slice =
pongc_slice > channels_num ? channels_num : pongc_slice;
pongc_slice_align = CEIL_ALIGN(pongc_slice, type_align);
__bang_write_value(nram_pong_input,
FOURSPLIT * pongc_slice_align, (T)0);
__asm__ volatile("sync;");
if (!is_empty) {
const T *src_offset1 = offset_input +
y_low_tmp * width * channels +
x_low_tmp * channels + channel_offset;
const T *src_offset2 = offset_input +
y_low_tmp * width * channels +
x_high_tmp * channels + channel_offset;
const T *src_offset3 = offset_input +
y_high_tmp * width * channels +
x_low_tmp * channels + channel_offset;
const T *src_offset4 = offset_input +
y_high_tmp * width * channels +
x_high_tmp * channels + channel_offset;
__memcpy_async(nram_pong_input, src_offset1,
pongc_slice * sizeof(T), GDRAM2NRAM);
__memcpy_async(nram_pong_input + pongc_slice_align,
src_offset2, pongc_slice * sizeof(T),
GDRAM2NRAM);
__memcpy_async(nram_pong_input + 2 * pongc_slice_align,
src_offset3, pongc_slice * sizeof(T),
GDRAM2NRAM);
__memcpy_async(nram_pong_input + 3 * pongc_slice_align,
src_offset4, pongc_slice * sizeof(T),
GDRAM2NRAM);
}
}
__bang_mul_scalar(nram_tmp1, nram_ping_input + 3 * c_slice_align,
y - y_low, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input + c_slice_align,
y_high - y, c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input + 2 * c_slice_align,
y_low - y, c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input, y - y_high,
c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp1, nram_tmp1, gamma * roi_width,
c_slice_align);
__bang_mul(nram_tmp1, nram_grad_output, nram_tmp1, c_slice_align);
const int32_t kernel_width =
c_slice_align / nram_sum_tmp_channel +
(int32_t)(c_slice_align % nram_sum_tmp_channel > 0);
__bang_sumpool(nram_sum_tmp, nram_tmp1, nram_sum_tmp_channel, 1,
kernel_width, 1, kernel_width, kernel_width, 1);
__bang_reduce_sum(nram_sum_tmp, nram_sum_tmp,
nram_sum_tmp_channel);
__bang_atomic_add(
(T *)nram_sum_tmp,
(T *)(grad_offset +
out_batch * pooled_width * pooled_height * 2 +
out_height * pooled_width + out_width),
(T *)nram_sum_tmp, 1);
__bang_write_value((T *)nram_sum_tmp, nram_sum_tmp_channel, (T)0);
__bang_mul_scalar(nram_tmp1, nram_ping_input + 3 * c_slice_align,
x - x_low, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input + 2 * c_slice_align,
x_high - x, c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input + c_slice_align,
x_low - x, c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input, x - x_high,
c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp1, nram_tmp1, gamma * roi_height,
c_slice_align);
__bang_mul(nram_tmp1, nram_grad_output, nram_tmp1, c_slice_align);
__bang_sumpool(nram_sum_tmp, nram_tmp1, nram_sum_tmp_channel, 1,
kernel_width, 1, kernel_width, kernel_width, 1);
__bang_reduce_sum(nram_sum_tmp, nram_sum_tmp,
NFU_ALIGN_SIZE / sizeof(T));
__bang_atomic_add(
(T *)nram_sum_tmp,
(T *)(grad_offset +
out_batch * pooled_width * pooled_height * 2 +
pooled_width * pooled_height +
out_height * pooled_width + out_width),
(T *)nram_sum_tmp, 1);
T *nram_tmp = nram_ping_input;
nram_ping_input = nram_pong_input;
nram_pong_input = nram_tmp;
c_rem -= c_slice;
c_offset += c_slice;
__asm__ volatile("sync;");
}
}
}
}
}
}
}
__mlu_global__ void MLUKernelDeformRoIPoolBackward(
cnrtDataType_t data_type, const void *grad_output, const void *input,
const void *rois, const void *offset, void *grad_input, void *grad_offset,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const float spatial_scale,
const int sampling_ratio, const float gamma) {
switch (data_type) {
case CNRT_FLOAT16: {
MLUUnion1DeformRoIPoolBackward(
(half *)grad_output, (half *)input, (half *)rois, (half *)offset,
(half *)grad_input, (half *)grad_offset, channels, height, width,
num_rois, pooled_height, pooled_width,
static_cast<half>(spatial_scale), sampling_ratio,
static_cast<half>(gamma));
}; break;
case CNRT_FLOAT32: {
MLUUnion1DeformRoIPoolBackward(
(float *)grad_output, (float *)input, (float *)rois, (float *)offset,
(float *)grad_input, (float *)grad_offset, channels, height, width,
num_rois, pooled_height, pooled_width,
static_cast<float>(spatial_scale), sampling_ratio,
static_cast<float>(gamma));
}; break;
default: {
break;
}
}
}
void KernelDeformRoIPoolBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
cnrtDataType_t data_type, const void *grad_output, const void *input,
const void *rois, const void *offset, void *grad_input, void *grad_offset,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const float spatial_scale,
const int sampling_ratio, const float gamma) {
MLUKernelDeformRoIPoolBackward<<<k_dim, k_type, queue>>>(
data_type, grad_output, input, rois, offset, grad_input, grad_offset,
channels, height, width, num_rois, pooled_height, pooled_width,
spatial_scale, sampling_ratio, gamma);
}
/*************************************************************************
* Copyright (C) 2021 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <float.h>
#include "common_mlu_helper.hpp"
#define PING 0
#define PONG 1
__nram__ char nram_buffer[MAX_NRAM_SIZE];
namespace forward {
template <typename T>
__mlu_func__ void loadInput(char *nram_input, T *dram_input, const int32_t size,
const int32_t dst_stride = 0,
const int32_t src_stride = 0,
const int32_t count = 1) {
if (dst_stride == src_stride) {
__memcpy_async(nram_input, dram_input, size * count, GDRAM2NRAM);
} else {
__memcpy_async(nram_input, dram_input, size, GDRAM2NRAM, dst_stride,
src_stride, count - 1);
}
}
template <typename T>
__mlu_func__ void loadWeight(char *nram_input, T *dram_input, const int32_t t,
const int32_t c, const int32_t has_weight,
const int32_t partition_nc) {
if (has_weight && partition_nc && t >= 0 && t < c) {
__memcpy_async(nram_input, (T *)dram_input + t, sizeof(T), GDRAM2NRAM);
}
}
template <typename T>
__mlu_func__ void storeOutput(T *dram_output, char *nram_output,
const int32_t size, const int32_t dst_stride = 0,
const int32_t src_stride = 0,
const int32_t count = 1) {
if (dst_stride == src_stride) {
__memcpy_async(dram_output, nram_output, size * count, NRAM2GDRAM);
} else {
__memcpy_async(dram_output, nram_output, size, NRAM2GDRAM, dst_stride,
src_stride, count - 1);
}
}
template <typename T>
__mlu_func__ void compute(T *input, const int32_t *target, const T *weight,
const int32_t has_weight, const int32_t partition_nc,
const int32_t deal_num, const int32_t n_seg,
const int32_t c, const int32_t c_seg,
const int32_t c_start_index, const float alpha,
const float gamma, T *compute_a, T *compute_b,
T *output) {
// set params
const int32_t c_num =
has_weight ? PAD_UP(c_seg, NFU_ALIGN_SIZE / sizeof(T)) : c_seg;
const int32_t c_end_index = c_start_index + c_seg;
const int32_t half_epsilon = 0x0400;
const T epsilon_f =
sizeof(T) == sizeof(float) ? FLT_MIN : *((half *)&half_epsilon);
// 0. alpha_t * p_t^r = alpha * (1 - p) ^ gamma if t == c_i
// = (1 - alpha) * p ^ gamma if t != c_i
__nramset((T *)output, deal_num, (T)(1 - alpha));
__bang_active_sigmoid((T *)compute_b, (T *)input, deal_num);
for (int32_t i = 0; i < n_seg; ++i) {
const int32_t t = *((uint32_t *)target + i);
if (t >= c_start_index && t < c_end_index) {
const uint32_t index = i * c_num + t - c_start_index;
*((T *)input + index) = -1.0 * (*((T *)input + index));
*((T *)compute_b + index) = 1.0 - (*((T *)compute_b + index)) + epsilon_f;
*((T *)output + index) = alpha;
}
}
if (sizeof(T) == sizeof(half)) {
__bang_half2float((float *)compute_a, (half *)compute_b, deal_num);
__bang_active_loghp((float *)compute_a, (float *)compute_a, deal_num);
__bang_mul_const((float *)compute_a, (float *)compute_a, (float)gamma,
deal_num);
__bang_active_exphp((float *)compute_a, (float *)compute_a, deal_num);
__bang_float2half_rd((half *)compute_a, (float *)compute_a, deal_num);
} else {
__bang_active_loghp((T *)compute_a, (T *)compute_b, deal_num);
__bang_mul_const((T *)compute_a, (T *)compute_a, (T)gamma, deal_num);
__bang_active_exphp((T *)compute_a, (T *)compute_a, deal_num);
}
__bang_mul((T *)output, (T *)compute_a, (T *)output, deal_num);
// 1. max = max(0, -x) if t == c_i
// = max(0, x) if t != c_i
__nramset((T *)compute_b, deal_num, (T)0);
__bang_maxequal((T *)compute_b, (T *)compute_b, (T *)input, deal_num);
// 2. -log(p_t) = ln(e^(-max)+ e^(-max-x) + max if t == c_i
// = ln(e^(-max)+ e^(-max+x) + max if t != c_i
__bang_mul_const((T *)compute_a, (T *)compute_b, (T)-1.0, deal_num);
__bang_add((T *)input, (T *)compute_a, (T *)input, deal_num);
__bang_active_exphp((T *)compute_a, (T *)compute_a, deal_num);
__bang_active_exphp((T *)input, (T *)input, deal_num);
__bang_add((T *)compute_a, (T *)compute_a, (T *)input, deal_num);
__bang_active_loghp((T *)compute_a, (T *)compute_a, deal_num);
__bang_add((T *)input, (T *)compute_a, (T *)compute_b, deal_num);
// 3. output = alpha_t * p_t^r * [-log(p_t)]
__bang_mul((T *)output, (T *)output, (T *)input, deal_num);
// 4. with weight
if (has_weight) {
for (int32_t i = 0; i < n_seg; ++i) {
int32_t t = *((int32_t *)target + i);
if (t >= 0 && t < c) {
t = partition_nc ? 0 : t;
__bang_mul_const((T *)output + i * c_num, (T *)output + i * c_num,
*((T *)weight + t), c_num);
}
}
}
}
template <typename T>
__mlu_func__ void startPipeline(
const T *input, const int32_t *target, const T *weight,
char *nram_compute_a, char *nram_compute_b, char *nram_input,
char *nram_target, char *nram_weight, char *nram_output,
const int32_t has_weight, const int32_t partition_nc,
const int32_t pingpong_offset, const int32_t pingpong_weight_offset,
const int32_t c_offset_num, const int32_t n, const int32_t n_seg,
const int32_t c, const int32_t c_seg, const float alpha, const float gamma,
T *output) {
// with offset
input = (T *)((char *)input + c_offset_num * sizeof(T));
output = (T *)((char *)output + c_offset_num * sizeof(T));
const int32_t c_seg_align_num = PAD_UP(c_seg, NFU_ALIGN_SIZE / sizeof(T));
const int32_t c_num = has_weight ? c_seg_align_num : c_seg;
const int32_t deal_num = PAD_UP(n_seg * c_num, NFU_ALIGN_SIZE / sizeof(T));
const int32_t load_size = c_seg * sizeof(T);
const int32_t dram_stride = c * sizeof(T);
const int32_t nram_stride = c_num * sizeof(T);
if (has_weight && !partition_nc) {
loadInput<T>(nram_weight, (T *)weight, load_size, nram_stride, dram_stride,
1);
__asm__ volatile("sync;\n\t");
}
const int32_t repeat = n / n_seg;
const int32_t remain = n % n_seg;
/*
* Pipeline: The pipeline is processed in three stages: Load, Compute, Store.
* The allocated memory space of NRAM is divided into two parts:
* PING and Pong. In a single time slice, PING is used to process
* IO stream and PONG is used for computation. Both of them are
* processed synchronously until finished.
*
* diagram of PINGPONG:
* |------|-----------------------------------------------------------------|
* | | space |
* |------|-----------------------------------------------------------------|
* | time | Ping | Pong | Ping | Pong | Ping | Pong |
* |------|-----------------------------------------------------------------|
* | 0 | L0 | | | | | |
* | 1 | C0 | L1 | | | | |
* | 2 | S0 | C1 | L2 | | | |
* | 3 | | S1 | C2 | L3 | | |
* | 4 | | | S2 | C3 | L4 | |
* | 5 | | | | S3 | C4 | L5 |
* | 6 | | | | | S4 | C5 |
* | 7 | | | | | | S5 |
* |------|-----------------------------------------------------------------|
*/
// diagram of PINGPONG: L0
if (repeat > 0) {
loadInput<T>(nram_input, (T *)input, load_size, nram_stride, dram_stride,
n_seg);
loadInput<int32_t>(nram_target, (int32_t *)target, n_seg * sizeof(int32_t));
loadWeight<T>(nram_weight, (T *)weight, *((int32_t *)target), c, has_weight,
partition_nc);
__asm__ volatile("sync;\n\t");
}
// diagram of PINGPONG: C0 and L1
if (repeat > 1) {
compute((T *)nram_input, (int32_t *)nram_target, (T *)nram_weight,
has_weight, partition_nc, deal_num, n_seg, c, c_seg, c_offset_num,
alpha, gamma, (T *)nram_compute_a, (T *)nram_compute_b,
(T *)nram_output);
loadInput<T>((char *)nram_input + pingpong_offset, (T *)input + c * n_seg,
load_size, nram_stride, dram_stride, n_seg);
loadInput<int32_t>((char *)nram_target + pingpong_offset,
(int32_t *)target + n_seg, n_seg * sizeof(int32_t));
loadWeight<T>((char *)nram_weight + pingpong_weight_offset, (T *)weight,
*((int32_t *)target + n_seg), c, has_weight, partition_nc);
__asm__ volatile("sync;\n\t");
}
for (int32_t i = 0; i < repeat - 2; ++i) {
storeOutput<T>((T *)output + i * c * n_seg,
nram_output + (i % 2) * pingpong_offset, load_size,
dram_stride, nram_stride, n_seg);
loadInput<T>((char *)nram_input + (i % 2) * pingpong_offset,
(T *)(input) + (i + 2) * c * n_seg, load_size, nram_stride,
dram_stride, n_seg);
loadInput<int32_t>((char *)nram_target + (i % 2) * pingpong_offset,
(int32_t *)target + (i + 2) * n_seg,
n_seg * sizeof(int32_t));
loadWeight<T>((char *)nram_weight + (i % 2) * pingpong_weight_offset,
(T *)weight, *((int32_t *)target + (i + 2) * n_seg), c,
has_weight, partition_nc);
compute((T *)(nram_input + ((i + 1) % 2) * pingpong_offset),
(int32_t *)(nram_target + ((i + 1) % 2) * pingpong_offset),
(T *)(nram_weight +
partition_nc * ((i + 1) % 2) * pingpong_weight_offset),
has_weight, partition_nc, deal_num, n_seg, c, c_seg, c_offset_num,
alpha, gamma, (T *)nram_compute_a, (T *)nram_compute_b,
(T *)(nram_output + ((i + 1) % 2) * pingpong_offset));
__asm__ volatile("sync;\n\t");
}
if (repeat > 1) {
storeOutput<T>((T *)output + (repeat - 2) * c * n_seg,
(char *)nram_output + (repeat % 2) * pingpong_offset,
load_size, dram_stride, nram_stride, n_seg);
}
if (remain > 0) {
loadInput<T>((char *)nram_input + (repeat % 2) * pingpong_offset,
(T *)input + repeat * c * n_seg, load_size, nram_stride,
dram_stride, remain);
loadInput<int32_t>((char *)nram_target + (repeat % 2) * pingpong_offset,
(int32_t *)target + repeat * n_seg,
remain * sizeof(int32_t));
loadWeight<T>((char *)nram_weight + (repeat % 2) * pingpong_weight_offset,
(T *)weight, *((int32_t *)target + repeat * n_seg), c,
has_weight, partition_nc);
}
if (repeat > 0) {
compute((T *)(nram_input + ((repeat - 1) % 2) * pingpong_offset),
(int32_t *)(nram_target + ((repeat - 1) % 2) * pingpong_offset),
(T *)(nram_weight +
partition_nc * ((repeat - 1) % 2) * pingpong_weight_offset),
has_weight, partition_nc, deal_num, n_seg, c, c_seg, c_offset_num,
alpha, gamma, (T *)nram_compute_a, (T *)nram_compute_b,
(T *)(nram_output + ((repeat - 1) % 2) * pingpong_offset));
}
__asm__ volatile("sync;\n\t");
if (repeat > 0) {
storeOutput<T>((T *)output + (repeat - 1) * c * n_seg,
(char *)nram_output + ((repeat - 1) % 2) * pingpong_offset,
load_size, dram_stride, nram_stride, n_seg);
}
if (remain > 0) {
int32_t rem_num = PAD_UP(remain * c_num, NFU_ALIGN_SIZE / sizeof(T));
compute((T *)(nram_input + (repeat % 2) * pingpong_offset),
(int32_t *)(nram_target + (repeat % 2) * pingpong_offset),
(T *)(nram_weight +
partition_nc * (repeat % 2) * pingpong_weight_offset),
has_weight, partition_nc, rem_num, remain, c, c_seg, c_offset_num,
alpha, gamma, (T *)nram_compute_a, (T *)nram_compute_b,
(T *)(nram_output + (repeat % 2) * pingpong_offset));
__asm__ volatile("sync;\n\t");
storeOutput<T>((T *)output + repeat * c * n_seg,
(char *)nram_output + (repeat % 2) * pingpong_offset,
load_size, dram_stride, nram_stride, remain);
}
__asm__ volatile("sync;\n\t");
}
template <typename T>
__mlu_func__ void focalLossSigmoidForwardBlock(
const T *input, const int32_t *target, const T *weight, const int32_t n,
const int32_t c, const float alpha, const float gamma, T *output) {
/*
* NRAM partition
* |-----------------------------------------------------------------------|
* | weight |
* |------------------------------- COMPUTE -------------------------------|
* | | |
* | computeA | computeB |
* | | |
* |------------- PING ------------------------------- PONG ---------------|
* | | |
* | input | input |
* | | |
* |-----------------------------------|-----------------------------------|
* | | |
* | output | output |
* | | |
* |-----------------------------------|-----------------------------------|
* | target | target |
* |-----------------------------------|-----------------------------------|
*
* split_pipeline_num is 6: COMPUTE(computeA,computeB), PING(input,output),
* PONG(input,output).
* split_target_num is 2: PING(target), PONG(target).
* weight is not NULL:
* The nram-size of weight is equal to c_align_size when partition input-N.
* The nram-size of weight is equal to NFU_ALIGN_SIZE when partition
* input-NC.
*/
// calculate threshold of c
const int32_t split_pipeline_num = 6;
const int32_t split_target_num = 2;
const int32_t has_weight = weight != NULL;
const int32_t threshold_c =
PAD_DOWN((MAX_NRAM_SIZE - split_target_num * sizeof(int32_t)) /
(split_pipeline_num + has_weight),
NFU_ALIGN_SIZE) /
sizeof(T);
const int32_t c_align = PAD_UP(c, NFU_ALIGN_SIZE / sizeof(T));
const int32_t c_align_size = c_align * sizeof(T);
if (c <= threshold_c) {
// partition inputN
int32_t c_num = c;
int32_t reservered_align_size =
(split_target_num + split_pipeline_num) * NFU_ALIGN_SIZE;
int32_t weight_size = 0;
if (has_weight) {
c_num = c_align;
reservered_align_size = split_target_num * NFU_ALIGN_SIZE;
weight_size = c_align_size;
}
const int32_t remain_size =
MAX_NRAM_SIZE - weight_size - reservered_align_size;
const int32_t n_seg =
remain_size / (split_pipeline_num * c_num * sizeof(T) +
split_target_num * sizeof(int32_t));
const int32_t split_pipeline_size =
PAD_UP(c_num * n_seg * sizeof(T), NFU_ALIGN_SIZE);
const int32_t compute_size = 2 * split_pipeline_size;
const int32_t pingpong_offset = (MAX_NRAM_SIZE - weight_size - compute_size) / 2;
char *nram_weight = (char *)nram_buffer;
char *nram_compute_a = nram_weight + has_weight * c_align_size;
char *nram_compute_b = nram_compute_a + split_pipeline_size;
char *nram_input = nram_compute_b + split_pipeline_size;
char *nram_output = nram_input + split_pipeline_size;
char *nram_target = nram_output + split_pipeline_size;
startPipeline<T>(input, target, weight, nram_compute_a, nram_compute_b,
nram_input, nram_target, nram_weight, nram_output,
has_weight, 0, pingpong_offset, 0, 0, n, n_seg, c, c,
alpha, gamma, output);
} else {
// partition inputNC
const int32_t weight_size = has_weight * NFU_ALIGN_SIZE;
const int32_t remain_size = MAX_NRAM_SIZE - weight_size;
const int32_t split_pipeline_size = PAD_DOWN(
(remain_size - split_target_num * NFU_ALIGN_SIZE) / split_pipeline_num,
NFU_ALIGN_SIZE);
const int32_t c_seg = split_pipeline_size / sizeof(T);
const int32_t n_seg = 1;
const int32_t compute_size = 2 * split_pipeline_size;
const int32_t pingpong_offset = (MAX_NRAM_SIZE - weight_size - compute_size) / 2;
const int32_t pingpong_weight_offset = weight_size / 2;
char *nram_weight = (char *)nram_buffer;
char *nram_compute_a = nram_weight + weight_size;
char *nram_compute_b = nram_compute_a + split_pipeline_size;
char *nram_input = nram_compute_b + split_pipeline_size;
char *nram_output = nram_input + split_pipeline_size;
char *nram_target = nram_output + split_pipeline_size;
const int32_t loop_num = (c + c_seg - 1) / c_seg;
const int32_t partition_nc = 1;
for (int32_t i = 0; i < loop_num; ++i) {
const int32_t c_index = i * c_seg;
const int32_t c_seg_curr = i == (loop_num - 1) ? c - c_index : c_seg;
startPipeline<T>(input, target, weight, nram_compute_a, nram_compute_b,
nram_input, nram_target, nram_weight, nram_output,
has_weight, partition_nc, pingpong_offset,
pingpong_weight_offset, c_index, n, n_seg, c, c_seg_curr,
alpha, gamma, output);
}
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelFocalLossSigmoidForward(
const void *input, const void *target, const void *weight, const int32_t N,
const int32_t C, const float alpha, const float gamma, void *output) {
const int32_t n_seg = N / taskDim + (taskId == taskDim - 1) * (N % taskDim);
const T *input_offset = (T *)input + N / taskDim * taskId * C;
const int32_t *target_offset = (int32_t *)target + N / taskDim * taskId;
T *output_offset = (T *)output + N / taskDim * taskId * C;
focalLossSigmoidForwardBlock((T *)input_offset, (int32_t *)target_offset,
(T *)weight, n_seg, C, alpha, gamma,
(T *)output_offset);
}
} // namespace forward
namespace backward {
template <typename T>
__mlu_func__ void loadInput(char *nram_input, char *nram_target,
const T *gdram_input, const int32_t *gdram_target,
const int32_t deal_n, const int32_t total_c,
const bool pingping_flag, const bool has_weight,
const int32_t nram_offset,
const int32_t gdram_offset) {
if (pingping_flag == PONG) {
nram_input += nram_offset;
nram_target += nram_offset;
}
__memcpy_async(nram_target, gdram_target + gdram_offset / total_c,
deal_n * sizeof(int32_t), GDRAM2NRAM);
char *nram_input_load = nram_input;
int32_t compute_align_size = 2 * NFU_ALIGN_SIZE;
if (has_weight) {
if (sizeof(T) == sizeof(half)) {
int32_t compute_align_num = compute_align_size / sizeof(float);
int32_t align_c = PAD_UP(total_c, compute_align_num);
int32_t compute_size = deal_n * align_c * sizeof(float);
nram_input_load += compute_size / 2;
}
int32_t align_c = PAD_UP(total_c, NFU_ALIGN_SIZE / sizeof(T));
int32_t total_c_size = total_c * sizeof(T);
int32_t align_c_size = align_c * sizeof(T);
__memcpy_async(nram_input_load, gdram_input + gdram_offset, total_c_size,
GDRAM2NRAM, align_c_size, total_c_size, deal_n - 1);
} else {
if (sizeof(T) == sizeof(half)) {
int32_t compute_size =
PAD_UP(deal_n * total_c * sizeof(float), compute_align_size);
nram_input_load += compute_size / 2;
}
int32_t load_size = deal_n * total_c * sizeof(T);
__memcpy_async(nram_input_load, gdram_input + gdram_offset, load_size,
GDRAM2NRAM);
}
}
template <typename T>
__mlu_func__ void sigmoid(T *dst_data, const T *src_data,
const int32_t elem_count) {
__bang_mul_const(dst_data, (T *)src_data, T(-1), elem_count);
__bang_active_exphp(dst_data, dst_data, elem_count);
__bang_add_const(dst_data, dst_data, T(1), elem_count);
__bang_active_reciphp(dst_data, dst_data, elem_count);
}
template <typename T>
__mlu_func__ void coreCompute(char *nram_input, const T *nram_weight,
const float *nram_flt_min, char *nram_pt,
char *nram_alpha_t, char *nram_temp,
char *nram_target, const float *nram_gamma,
char *nram_output, const float alpha,
const int32_t compute_num, const int32_t deal_n,
const int32_t total_c, const bool pingpong_flag,
const int32_t nram_offset,
const bool has_weight) {
if (pingpong_flag == PONG) {
nram_input += nram_offset;
nram_pt += nram_offset;
nram_alpha_t += nram_offset;
nram_temp += nram_offset;
nram_output += nram_offset;
nram_target += nram_offset;
}
if (sizeof(T) == sizeof(half)) {
const int32_t compute_size = compute_num * sizeof(float);
char *nram_input_load = nram_input + compute_size / 2;
__bang_half2float((float *)nram_input, (half *)nram_input_load,
compute_num);
}
// 0. alpha_t = alpha - 1
__nramset((float *)nram_alpha_t, compute_num, (float)(alpha - 1.0));
// 1. pt = 1 - sigmoid(x)
sigmoid((float *)nram_pt, (float *)nram_input, compute_num);
__bang_mul_const((float *)nram_pt, (float *)nram_pt, (float)(-1),
compute_num);
__bang_add_const((float *)nram_pt, (float *)nram_pt, (float)1, compute_num);
// 2. pt = target[n] == c ? sigmoid(x) : 1 - sigmoid(x)
// alpha_t = target[n] == c ? alpha : alpha - 1
const int32_t nfu_align_num = NFU_ALIGN_SIZE / sizeof(float);
for (int n = 0; n < deal_n; n++) {
const int32_t target_value = ((int32_t *)nram_target)[n];
if (target_value >= total_c || target_value < 0) continue;
int32_t c_offset = 0;
if (has_weight) {
int32_t c_align_num = nfu_align_num;
if (sizeof(T) == sizeof(half)) {
c_align_num += nfu_align_num;
}
c_offset = PAD_UP(total_c, c_align_num);
} else {
c_offset = total_c;
}
int32_t idx = n * c_offset + target_value;
*((float *)nram_pt + idx) = 1.0 - *((float *)nram_pt + idx);
*((float *)nram_alpha_t + idx) = alpha;
}
// 3. temp = -alpha_t * e^(gamma * log(max(1 - pt, FLT_MIN))
__bang_mul_const((float *)nram_temp, (float *)nram_pt, (float)(-1),
compute_num);
__bang_add_const((float *)nram_temp, (float *)nram_temp, (float)(1),
compute_num);
__bang_cycle_maxequal((float *)nram_temp, (float *)nram_temp,
(float *)nram_flt_min, compute_num, nfu_align_num);
__bang_active_loghp((float *)nram_temp, (float *)nram_temp, compute_num);
__bang_cycle_mul((float *)nram_temp, (float *)nram_temp, (float *)nram_gamma,
compute_num, nfu_align_num);
__bang_active_exphp((float *)nram_temp, (float *)nram_temp, compute_num);
__bang_mul((float *)nram_temp, (float *)nram_temp, (float *)nram_alpha_t,
compute_num);
__bang_mul_const((float *)nram_temp, (float *)nram_temp, (float)(-1),
compute_num);
// 4. output = 1 - pt - gamma * pt * log(max(pt, FLT_MIN))
__bang_cycle_maxequal((float *)nram_output, (float *)nram_pt,
(float *)nram_flt_min, compute_num, nfu_align_num);
__bang_active_loghp((float *)nram_output, (float *)nram_output, compute_num);
__bang_mul((float *)nram_output, (float *)nram_output, (float *)nram_pt,
compute_num);
__bang_cycle_mul((float *)nram_output, (float *)nram_output,
(float *)nram_gamma, compute_num, nfu_align_num);
__bang_add((float *)nram_output, (float *)nram_output, (float *)nram_pt,
compute_num);
__bang_mul_const((float *)nram_output, (float *)nram_output, (float)(-1),
compute_num);
__bang_add_const((float *)nram_output, (float *)nram_output, (float)(1),
compute_num);
// 5. output = output * temp
__bang_mul((float *)nram_output, (float *)nram_output, (float *)nram_temp,
compute_num);
if (sizeof(T) == sizeof(half)) {
__bang_float2half_rd((half *)nram_output, (float *)nram_output,
compute_num);
}
if (has_weight) {
// with weight
for (int n = 0; n < deal_n; n++) {
int32_t c_align_num = nfu_align_num;
if (sizeof(T) == sizeof(half)) {
c_align_num += nfu_align_num;
}
int32_t align_c = PAD_UP(total_c, c_align_num);
int32_t target_value = ((int32_t *)nram_target)[n];
T weight_value = nram_weight[target_value];
__bang_mul_const((T *)nram_output + n * align_c,
(T *)nram_output + n * align_c, weight_value, align_c);
}
}
}
template <typename T>
__mlu_func__ void storeOutput(T *gdram_output, const char *nram_output,
const int32_t deal_n, const int32_t total_c,
const bool pingpong_flag, const bool has_weight,
const int32_t nram_offset,
const int32_t gdram_offset) {
if (pingpong_flag == PONG) {
nram_output += nram_offset;
}
const int32_t store_size = deal_n * total_c * sizeof(T);
if (has_weight) {
int32_t align_c = PAD_UP(total_c, NFU_ALIGN_SIZE / sizeof(T));
int32_t total_c_size = total_c * sizeof(T);
int32_t align_c_size = align_c * sizeof(T);
__memcpy_async(gdram_output + gdram_offset, nram_output, total_c_size,
NRAM2GDRAM, total_c_size, align_c_size, deal_n - 1);
} else {
__memcpy_async(gdram_output + gdram_offset, nram_output, store_size,
NRAM2GDRAM);
}
}
template <typename T>
__mlu_func__ void focalLossSigmoidBackwardBlock(
const T *input, const int32_t *target, const T *weight, const float gamma,
const float alpha, const int32_t total_n, const int32_t deal_n,
const int32_t total_c, T *output) {
// params per time slice
int32_t deal_num = deal_n * total_c;
int32_t deal_size = deal_num * sizeof(float);
int32_t compute_num = 0;
int32_t compute_size = 0;
int32_t compute_align_size = NFU_ALIGN_SIZE;
const int32_t nfu_align_num = NFU_ALIGN_SIZE / sizeof(T);
if (sizeof(T) == sizeof(half)) {
compute_align_size += NFU_ALIGN_SIZE;
}
const int32_t compute_align_num = compute_align_size / sizeof(float);
bool has_weight = false;
if (weight != NULL) {
has_weight = true;
int32_t align_c = PAD_UP(total_c, compute_align_num);
compute_num = deal_n * align_c;
compute_size = compute_num * sizeof(float);
} else {
compute_size = PAD_UP(deal_size, compute_align_size);
compute_num = compute_size / sizeof(float);
}
// params per core
int32_t total_num = total_n * total_c;
int32_t num_per_core = PAD_DOWN(total_num / taskDim, deal_num);
int32_t loop_per_core = num_per_core / deal_num;
/* NRAM partition:
*
* |-----------------ping pong--------------------|
* |input | pt | alpha_t | temp | output | target | flt_min | gamma | weight|
*
* split_pipeline_num is 5: input, pt, alpha_t, temp, output.
* nram_reserved_line_num is 2: flt_min, gamma.
*/
const int32_t split_pipeline_num = 5;
const int32_t nram_reserved_line_num = 2;
int32_t target_deal_size = deal_n * sizeof(int32_t);
int32_t target_deal_size_align = PAD_UP(target_deal_size, NFU_ALIGN_SIZE);
// nram PING/PONG offset
int32_t ping_pong_offset =
compute_size * split_pipeline_num + target_deal_size_align;
// gdram addr
int32_t *base_addr_target =
(int32_t *)target + taskId * loop_per_core * deal_n;
T *base_addr_input = (T *)input + taskId * num_per_core;
T *base_addr_output = output + taskId * num_per_core;
// nram addr
char *nram_input = (char *)nram_buffer;
char *nram_pt = nram_input + compute_size;
char *nram_alpha_t = nram_pt + compute_size;
char *nram_temp = nram_alpha_t + compute_size;
char *nram_output = nram_temp + compute_size;
char *nram_target = nram_output + compute_size;
float *nram_flt_min = NULL;
float *nram_gamma = NULL;
T *nram_weight = NULL;
if (!has_weight) {
nram_flt_min = (float *)(nram_buffer + MAX_NRAM_SIZE -
nram_reserved_line_num * NFU_ALIGN_SIZE);
nram_gamma = nram_flt_min + nfu_align_num;
} else {
int32_t weight_space = PAD_UP(total_c * sizeof(T), NFU_ALIGN_SIZE);
nram_flt_min =
(float *)(nram_buffer + MAX_NRAM_SIZE -
nram_reserved_line_num * NFU_ALIGN_SIZE - weight_space);
nram_gamma = nram_flt_min + nfu_align_num;
nram_weight = (T *)(nram_gamma + nfu_align_num);
__memcpy_async(nram_weight, weight, total_c * sizeof(T), GDRAM2NRAM);
}
// nram set gamma and FLT_MIN
__nramset(nram_gamma, nfu_align_num, gamma);
__nramset(nram_flt_min, nfu_align_num, FLT_MIN);
/*
* Pipeline: The pipeline is processed in three stages: Load, Compute, Store.
* The allocated memory space of NRAM is divided into two parts:
* PING and Pong. In a single time slice, PING is used to process
* IO stream and PONG is used for computation. Both of them are
* processed synchronously until finished.
*
* diagram of PINGPONG:
* |------|-----------------------------------------------------------------|
* | | space |
* |------|-----------------------------------------------------------------|
* | time | Ping | Pong | Ping | Pong | Ping | Pong |
* |------|-----------------------------------------------------------------|
* | 0 | L0 | | | | | |
* | 1 | C0 | L1 | | | | |
* | 2 | S0 | C1 | L2 | | | |
* | 3 | | S1 | C2 | L3 | | |
* | 4 | | | S2 | C3 | L4 | |
* | 5 | | | | S3 | C4 | L5 |
* | 6 | | | | | S4 | C5 |
* | 7 | | | | | | S5 |
* |------|-----------------------------------------------------------------|
*/
// diagram of PINGPONG: L0
if (loop_per_core > 0) {
loadInput(nram_input, nram_target, base_addr_input, base_addr_target,
deal_n, total_c, PING, has_weight, ping_pong_offset, 0);
__asm__ volatile("sync;");
}
// diagram of PINGPONG: C0 and L1
if (loop_per_core > 1) {
coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t,
nram_temp, nram_target, nram_gamma, nram_output, alpha,
compute_num, deal_n, total_c, PING, ping_pong_offset,
has_weight);
loadInput(nram_input, nram_target, base_addr_input, base_addr_target,
deal_n, total_c, PONG, has_weight, ping_pong_offset, deal_num);
__asm__ volatile("sync;");
}
for (int i = 0; i < loop_per_core - 2; ++i) {
if (i % 2 == PING) {
storeOutput(base_addr_output, nram_output, deal_n, total_c, PING,
has_weight, ping_pong_offset, i * deal_num);
coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t,
nram_temp, nram_target, nram_gamma, nram_output, alpha,
compute_num, deal_n, total_c, PONG, ping_pong_offset,
has_weight);
loadInput(nram_input, nram_target, base_addr_input, base_addr_target,
deal_n, total_c, PING, has_weight, ping_pong_offset,
(i + 2) * deal_num);
} else {
storeOutput(base_addr_output, nram_output, deal_n, total_c, PONG,
has_weight, ping_pong_offset, i * deal_num);
coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t,
nram_temp, nram_target, nram_gamma, nram_output, alpha,
compute_num, deal_n, total_c, PING, ping_pong_offset,
has_weight);
loadInput(nram_input, nram_target, base_addr_input, base_addr_target,
deal_n, total_c, PONG, has_weight, ping_pong_offset,
(i + 2) * deal_num);
}
__asm__ volatile("sync;");
}
if (loop_per_core > 1) {
if ((loop_per_core - 2) % 2 == PING) {
storeOutput(base_addr_output, nram_output, deal_n, total_c, PING,
has_weight, ping_pong_offset, (loop_per_core - 2) * deal_num);
coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t,
nram_temp, nram_target, nram_gamma, nram_output, alpha,
compute_num, deal_n, total_c, PONG, ping_pong_offset,
has_weight);
} else {
storeOutput(base_addr_output, nram_output, deal_n, total_c, PONG,
has_weight, ping_pong_offset, (loop_per_core - 2) * deal_num);
coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t,
nram_temp, nram_target, nram_gamma, nram_output, alpha,
compute_num, deal_n, total_c, PING, ping_pong_offset,
has_weight);
}
__asm__ volatile("sync;");
}
if (loop_per_core > 0) {
if (loop_per_core == 1) {
coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t,
nram_temp, nram_target, nram_gamma, nram_output, alpha,
compute_num, deal_n, total_c, PING, ping_pong_offset,
has_weight);
__asm__ volatile("sync;");
}
if ((loop_per_core - 1) % 2 == PING) {
storeOutput(base_addr_output, nram_output, deal_n, total_c, PING,
has_weight, ping_pong_offset, (loop_per_core - 1) * deal_num);
} else {
storeOutput(base_addr_output, nram_output, deal_n, total_c, PONG,
has_weight, ping_pong_offset, (loop_per_core - 1) * deal_num);
}
}
// process the remaining data which N remainder per core is less than deal_n
int32_t rem_for_all = total_num - num_per_core * taskDim;
if (rem_for_all == 0) return;
int32_t rem_n_for_all = rem_for_all / total_c;
int32_t rem_n_per_core = (rem_n_for_all + taskDim - 1) / taskDim;
int32_t rem_num_per_core = rem_n_per_core * total_c;
int32_t rem_num_per_core_align = 0;
int32_t rem_core_num = rem_for_all / rem_num_per_core;
int32_t rem_n_for_last = rem_n_for_all % rem_n_per_core;
int32_t rem_num_for_last = rem_n_for_last * total_c;
int32_t rem_num_for_last_align = 0;
if (has_weight) {
int32_t align_c = PAD_UP(total_c, compute_align_num);
rem_num_per_core_align = rem_n_per_core * align_c;
rem_num_for_last_align = rem_n_for_last * align_c;
} else {
rem_num_per_core_align = PAD_UP(rem_num_per_core, compute_align_num);
rem_num_for_last_align = PAD_UP(rem_num_for_last, compute_align_num);
}
int32_t rem_addr_base = num_per_core * taskDim;
int32_t rem_target_addr_base = loop_per_core * deal_n * taskDim;
base_addr_target = (int32_t *)target + rem_target_addr_base;
base_addr_input = (T *)input + rem_addr_base;
base_addr_output = output + rem_addr_base;
if (taskId < rem_core_num) {
loadInput(nram_input, nram_target, base_addr_input, base_addr_target,
rem_n_per_core, total_c, PING, has_weight, ping_pong_offset,
taskId * rem_num_per_core);
__asm__ volatile("sync;");
coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t,
nram_temp, nram_target, nram_gamma, nram_output, alpha,
rem_num_per_core_align, rem_n_per_core, total_c, PING,
ping_pong_offset, has_weight);
__asm__ volatile("sync;");
storeOutput(base_addr_output, nram_output, rem_n_per_core, total_c, PING,
has_weight, ping_pong_offset, taskId * rem_num_per_core);
} else if (taskId == rem_core_num) {
if (rem_num_for_last == 0) return;
loadInput(nram_input, nram_target, base_addr_input, base_addr_target,
rem_n_for_last, total_c, PING, has_weight, ping_pong_offset,
taskId * rem_num_per_core);
__asm__ volatile("sync;");
coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t,
nram_temp, nram_target, nram_gamma, nram_output, alpha,
rem_num_for_last_align, rem_n_for_last, total_c, PING,
ping_pong_offset, has_weight);
__asm__ volatile("sync;");
storeOutput(base_addr_output, nram_output, rem_n_for_last, total_c, PING,
has_weight, ping_pong_offset, taskId * rem_num_per_core);
} else {
return;
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelFocalLossSigmoidBackward(
const void *input, const void *target, const void *weight,
const float gamma, const float alpha, const int32_t total_n,
const int32_t deal_n, const int32_t total_c, void *output) {
focalLossSigmoidBackwardBlock((T *)input, (int32_t *)target, (T *)weight,
gamma, alpha, total_n, deal_n, total_c,
(T *)output);
}
} // namespace backward
void KernelFocalLossSigmoidForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue,
const cnrtDataType_t d_type,
const void *input, const void *target,
const void *weight, const int32_t N,
const int32_t C, const float alpha,
const float gamma, void *output) {
if (d_type == CNRT_FLOAT16) {
forward::MLUUnion1KernelFocalLossSigmoidForward<
half><<<k_dim, k_type, queue>>>(input, target, weight, N, C, alpha,
gamma, output);
} else {
forward::MLUUnion1KernelFocalLossSigmoidForward<
float><<<k_dim, k_type, queue>>>(input, target, weight, N, C, alpha,
gamma, output);
}
}
void KernelFocalLossSigmoidBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue,
const cnrtDataType_t d_type,
const void *input, const void *target,
const void *weight, const float gamma,
const float alpha, const int32_t dim_n,
const int32_t deal_n, const int32_t dim_c,
void *output) {
if (d_type == CNRT_FLOAT16) {
backward::MLUUnion1KernelFocalLossSigmoidBackward<
half><<<k_dim, k_type, queue>>>(input, target, weight, gamma, alpha,
dim_n, deal_n, dim_c, output);
} else {
backward::MLUUnion1KernelFocalLossSigmoidBackward<
float><<<k_dim, k_type, queue>>>(input, target, weight, gamma, alpha,
dim_n, deal_n, dim_c, output);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment