OpenDAS / apex · Commits · c2b62b7f

Commit c2b62b7f, authored Mar 13, 2025 by JR_ZZU 🌴
parent 2a4864d5

    delete origin files

Changes: 164
Showing 20 changed files with 0 additions and 8169 deletions (+0, -8169)
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu   (+0, -84)
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu   (+0, -137)
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h                 (+0, -531)
apex/contrib/csrc/fmha/src/fmha_kernel.h                           (+0, -179)
apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu                   (+0, -177)
apex/contrib/csrc/fmha/src/fmha_utils.h                            (+0, -92)
apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp                   (+0, -70)
apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu             (+0, -267)
apex/contrib/csrc/groupbn/batch_norm.cu                            (+0, -342)
apex/contrib/csrc/groupbn/batch_norm.h                             (+0, -901)
apex/contrib/csrc/groupbn/batch_norm_add_relu.cu                   (+0, -353)
apex/contrib/csrc/groupbn/batch_norm_add_relu.h                    (+0, -816)
apex/contrib/csrc/groupbn/cuda_utils.h                             (+0, -28)
apex/contrib/csrc/groupbn/dnn.h                                    (+0, -26)
apex/contrib/csrc/groupbn/interface.cpp                            (+0, -175)
apex/contrib/csrc/groupbn/ipc.cu                                   (+0, -129)
apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h                 (+0, -3021)
apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp               (+0, -139)
apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu         (+0, -492)
apex/contrib/csrc/layer_norm/ln.h                                  (+0, -210)
Note: too many changes to show; to preserve performance only 164 of 164+ files are displayed.
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu (deleted, 100644 → 0)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using
Kernel_traits
=
FMHA_kernel_traits
<
384
,
64
,
16
,
1
,
4
,
0x18u
>
;
template
<
bool
Is_training
>
__global__
void
fmha_fprop_fp16_384_64_sm80_kernel
(
Fused_multihead_attention_fprop_params
params
,
const
int
num_full_heads
,
const
int
num_main_groups
,
const
int
main_group_size
,
const
int
main_steps
,
const
int
rest_steps
)
{
fmha
::
device_1xN
<
Kernel_traits
,
Is_training
>
(
params
,
num_full_heads
,
num_main_groups
,
main_group_size
,
main_steps
,
rest_steps
);
}
void
run_fmha_fp16_384_64_sm80
(
Launch_params
<
Fused_multihead_attention_fprop_params
>
&
launch_params
,
const
bool
configure
)
{
auto
kernel
=
launch_params
.
is_training
?
&
fmha_fprop_fp16_384_64_sm80_kernel
<
true
>
:
&
fmha_fprop_fp16_384_64_sm80_kernel
<
false
>
;
constexpr
int
smem_size
=
fmha
::
get_dynamic_smem_size
<
Kernel_traits
>
();
if
(
smem_size
>=
48
*
1024
)
{
FMHA_CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size
));
}
const
int
sm_count
=
launch_params
.
props
->
multiProcessorCount
;
int
ctas_per_sm
;
FMHA_CHECK_CUDA
(
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
ctas_per_sm
,
kernel
,
Kernel_traits
::
THREADS
,
smem_size
));
int
total_ctas
=
sm_count
*
ctas_per_sm
;
if
(
configure
)
{
const
int
heads_total
=
launch_params
.
params
.
b
*
launch_params
.
params
.
h
;
std
::
tie
(
launch_params
.
num_full_heads
,
launch_params
.
num_main_groups
,
launch_params
.
heads_last_wave
,
launch_params
.
main_steps
,
launch_params
.
rest_steps
,
launch_params
.
elts_per_thread
)
=
fmha
::
work_dist
<
Kernel_traits
>
(
total_ctas
,
heads_total
);
return
;
}
dim3
grid
(
total_ctas
);
kernel
<<<
grid
,
Kernel_traits
::
THREADS
,
smem_size
,
launch_params
.
stream
>>>
(
launch_params
.
params
,
launch_params
.
num_full_heads
,
launch_params
.
num_main_groups
,
launch_params
.
heads_last_wave
,
launch_params
.
main_steps
,
launch_params
.
rest_steps
);
FMHA_CHECK_CUDA
(
cudaPeekAtLastError
());
}
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu (deleted, 100644 → 0)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using
Kernel_traits
=
FMHA_kernel_traits
<
512
,
64
,
16
,
1
,
8
,
0x00u
>
;
template
<
bool
Is_training
>
__global__
void
fmha_fprop_fp16_512_64_sm80_kernel
(
Fused_multihead_attention_fprop_params
params
,
const
int
total_heads
)
{
fmha
::
device_1xN
<
Kernel_traits
,
Is_training
>
(
params
,
total_heads
);
}
template
<
bool
Is_training
>
__global__
void
fmha_fprop_fp16_512_64_sm80_kernel_nl
(
Fused_multihead_attention_fprop_params
params
,
const
int
num_full_heads
,
const
int
num_main_groups
,
const
int
main_group_size
,
const
int
main_steps
,
const
int
rest_steps
)
{
fmha
::
device_1xN
<
Kernel_traits
,
Is_training
>
(
params
,
num_full_heads
,
num_main_groups
,
main_group_size
,
main_steps
,
rest_steps
);
}
void
run_fmha_fp16_512_64_sm80_
(
Launch_params
<
Fused_multihead_attention_fprop_params
>
&
launch_params
,
const
bool
configure
)
{
auto
kernel
=
launch_params
.
is_training
?
&
fmha_fprop_fp16_512_64_sm80_kernel
<
true
>
:
&
fmha_fprop_fp16_512_64_sm80_kernel
<
false
>
;
constexpr
int
smem_size
=
fmha
::
get_dynamic_smem_size
<
Kernel_traits
>
();
if
(
smem_size
>=
48
*
1024
)
{
FMHA_CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size
));
}
const
int
sm_count
=
launch_params
.
props
->
multiProcessorCount
;
int
ctas_per_sm
;
FMHA_CHECK_CUDA
(
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
ctas_per_sm
,
kernel
,
Kernel_traits
::
THREADS
,
smem_size
));
int
total_ctas
=
sm_count
*
ctas_per_sm
;
const
int
heads_total
=
launch_params
.
params
.
b
*
launch_params
.
params
.
h
;
if
(
configure
)
{
using
Mma_tile_p
=
fmha
::
Hmma_tile
<
typename
Kernel_traits
::
Cta_tile_p
>
;
constexpr
size_t
STEPS
=
Kernel_traits
::
Cta_tile_p
::
N
/
Kernel_traits
::
Cta_tile_p
::
M
;
constexpr
size_t
MMAS_M
=
Mma_tile_p
::
MMAS_M
;
constexpr
size_t
MMAS_N
=
Mma_tile_p
::
MMAS_N
;
size_t
heads_per_cta
=
((
heads_total
+
total_ctas
-
1
)
/
total_ctas
);
size_t
elts_per_head
=
STEPS
*
MMAS_M
*
MMAS_N
*
8
;
launch_params
.
elts_per_thread
=
heads_per_cta
*
elts_per_head
;
return
;
}
dim3
grid
(
total_ctas
);
kernel
<<<
grid
,
Kernel_traits
::
THREADS
,
smem_size
,
launch_params
.
stream
>>>
(
launch_params
.
params
,
heads_total
);
FMHA_CHECK_CUDA
(
cudaPeekAtLastError
());
}
void
run_fmha_fp16_512_64_sm80_nl_
(
Launch_params
<
Fused_multihead_attention_fprop_params
>
&
launch_params
,
const
bool
configure
)
{
auto
kernel
=
launch_params
.
is_training
?
&
fmha_fprop_fp16_512_64_sm80_kernel_nl
<
true
>
:
&
fmha_fprop_fp16_512_64_sm80_kernel_nl
<
false
>
;
constexpr
int
smem_size
=
fmha
::
get_dynamic_smem_size
<
Kernel_traits
>
();
if
(
smem_size
>=
48
*
1024
)
{
FMHA_CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size
));
}
const
int
sm_count
=
launch_params
.
props
->
multiProcessorCount
;
int
ctas_per_sm
;
FMHA_CHECK_CUDA
(
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
ctas_per_sm
,
kernel
,
Kernel_traits
::
THREADS
,
smem_size
));
int
total_ctas
=
sm_count
*
ctas_per_sm
;
if
(
configure
)
{
const
int
heads_total
=
launch_params
.
params
.
b
*
launch_params
.
params
.
h
;
std
::
tie
(
launch_params
.
num_full_heads
,
launch_params
.
num_main_groups
,
launch_params
.
heads_last_wave
,
launch_params
.
main_steps
,
launch_params
.
rest_steps
,
launch_params
.
elts_per_thread
)
=
fmha
::
work_dist
<
Kernel_traits
>
(
total_ctas
,
heads_total
);
return
;
}
dim3
grid
(
total_ctas
);
kernel
<<<
grid
,
Kernel_traits
::
THREADS
,
smem_size
,
launch_params
.
stream
>>>
(
launch_params
.
params
,
launch_params
.
num_full_heads
,
launch_params
.
num_main_groups
,
launch_params
.
heads_last_wave
,
launch_params
.
main_steps
,
launch_params
.
rest_steps
);
FMHA_CHECK_CUDA
(
cudaPeekAtLastError
());
}
void
run_fmha_fp16_512_64_sm80
(
Launch_params
<
Fused_multihead_attention_fprop_params
>
&
launch_params
,
const
bool
configure
)
{
if
(
launch_params
.
is_nl
)
{
run_fmha_fp16_512_64_sm80_nl_
(
launch_params
,
configure
);
}
else
{
run_fmha_fp16_512_64_sm80_
(
launch_params
,
configure
);
}
}
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h (deleted, 100644 → 0)
/***************************************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once

#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits>
struct Gemm_Q_K_base {
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
    using Fragment_q = typename Smem_tile_q::Fragment;
    using Fragment_k = typename Smem_tile_k::Fragment;

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;

    static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2;

    __device__ inline Gemm_Q_K_base(char *smem_ptr_q, char *smem_ptr_k, const int tidx)
        : smem_q(smem_ptr_q, tidx), smem_k(smem_ptr_k, tidx) {
    }

    __device__ inline void load_q() {
        smem_q.load(frag_q[0], 0);
    }

    __device__ inline void reload_q() {
        smem_q.load(frag_q[0], 0);
    }

    Fragment_q frag_q[2][Mma_tile_p::MMAS_M];
    Smem_tile_q smem_q;
    Smem_tile_k smem_k;
};

template<typename Kernel_traits, bool K_in_regs>
struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {

    using Base = Gemm_Q_K_base<Kernel_traits>;
    using Smem_tile_o = typename Base::Smem_tile_o;
    using Smem_tile_q = typename Base::Smem_tile_q;
    using Smem_tile_k = typename Base::Smem_tile_k;
    using Fragment_k = typename Base::Fragment_k;
    using Mma_tile_p = typename Base::Mma_tile_p;

    enum { SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V };

    enum { SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE };
    enum { SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE) };

    // Q | K / V
    //   | O | SOFTMAX
    static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE
                                    + std::max((SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE,
                                               Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX);

    __device__ inline Gemm_Q_K(char *smem_, const int tidx)
        : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {
    }

    __device__ inline void load_k() {
        #pragma unroll
        for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
            Base::smem_k.load(frag_k[ki], ki);
        }
    }

    template<typename Acc, int M, int N>
    __device__ inline void operator()(Acc (&acc_p)[M][N]) {
        // Do this part of P^T = (Q * K^T)^T.
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            Base::smem_q.load(Base::frag_q[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
        }
        // Do the final stage of math.
        {
            int ki = Mma_tile_p::MMAS_K;
            fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
        }
    }

    __device__ inline void reload_k() {
        // Noop.
    }

    Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
};

template<typename Kernel_traits>
struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
    using Base = Gemm_Q_K_base<Kernel_traits>;
    using Smem_tile_o = typename Base::Smem_tile_o;
    using Smem_tile_q = typename Base::Smem_tile_q;
    using Smem_tile_k = typename Base::Smem_tile_k;
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
    using Fragment_k = typename Base::Fragment_k;
    using Mma_tile_p = typename Base::Mma_tile_p;

    Fragment_k frag_k[2][Mma_tile_p::MMAS_N];

    enum { SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V };

    enum { SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE) };
    static_assert(Smem_tile_v::BYTES_PER_TILE == (int)Smem_tile_k::BYTES_PER_TILE);
    enum { SMEM_OFFSET_O = SMEM_OFFSET_V + Smem_tile_v::BYTES_PER_TILE };

    // Q | K/V + O + SOFTMAX
    static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE
                                    + (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE
                                    + Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX;

    __device__ inline Gemm_Q_K(char *smem_, const int tidx)
        : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {
    }

    __device__ inline void load_k() {
        Base::smem_k.load(frag_k[0], 0);
    }

    template<typename Acc, int M, int N>
    __device__ inline void operator()(Acc (&acc_p)[M][N]) {
        // Do this part of P^T = (Q * K^T)^T.
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            Base::smem_q.load(Base::frag_q[ki & 1], ki);
            Base::smem_k.load(frag_k[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
        }
        // Do the final stage of math.
        {
            int ki = Mma_tile_p::MMAS_K;
            fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
        }
    }

    __device__ inline void reload_k() {
        Base::smem_k.load(frag_k[0], 0);
    }
};

template<typename Kernel_traits>
constexpr size_t get_dynamic_smem_size() {
    return Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>::SMEM_BYTES;
}

template<typename Kernel_traits, bool Is_training, typename Params, typename Prng>
inline __device__ void device_1xN_(const Params &params,
                                   const int bidb,
                                   const int bidh,
                                   const int begin,
                                   const int steps,
                                   Prng &ph) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

    // The global memory tile to store O.
    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
    // The shared memory tile to swizzle O.
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;

    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;

    // The number of threads per row.
    enum { THREADS_PER_ROW = 32 };

    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };

    // Shared memory.
    extern __shared__ char smem_[];

    // The thread index.
    const int tidx = threadIdx.x;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    if( binfo.stop_early() ) return;

    Gemm1 gemm_q_k(smem_, tidx);
    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params, binfo, tidx);
    // Allocate the global memory tile loader for S.
    Gmem_tile_s gmem_s(params, binfo, tidx);

    // Wind gmem tiles to the correct position.
    for( int it = 0; it < begin; it++ ) {
        gmem_q.move();
        gmem_s.move();
        gmem_o.move();
    }

    fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // The base pointer of smem_v;
    char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];

    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);

    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx);

    // Trigger the loads for K.
    gmem_k.load(gemm_q_k.smem_k);
    // Trigger the loads for Q.
    gmem_q.load(gemm_q_k.smem_q);
    // Trigger the loads for V.
    gmem_v.load(smem_v);

    const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t &>(params.scale_bmm1);
    #pragma unroll
    for( int it = 0; it < Gmem_tile_k::LDGS; it++ ) {
        gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
    }

    // Commit the data for Q and V to shared memory.
    gmem_q.commit(gemm_q_k.smem_q);
    gmem_v.commit(smem_v);

    // Commit the data for K to shared memory.
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        gmem_k.commit(gemm_q_k.smem_k);
    }

    __syncthreads();

    // Load the fragments for Q.
    gemm_q_k.load_q();

    // Load the fragments for V. We keep the data in registers during the entire kernel.
    typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
        smem_v.load(frag_v[ki], ki);
    }

    // Commit the data for V to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();

        // Commit the data to shared memory for V.
        gmem_k.commit(gemm_q_k.smem_k);

        // Make sure the data is in shared memory.
        __syncthreads();
    }

    // Load the fragments for K.
    gemm_q_k.load_k();

    // Create the object to do the softmax.
    Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);

    // Load over the entire sequence length.
    for( int l = 0; l < steps; l++ ) {
        if( begin + l * Cta_tile_p::M >= binfo.actual_seqlen )
            break;

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        // Do this part of P^T = (Q * K^T)^T.
        gemm_q_k(acc_p);

        // Trigger the load for the next Q values.
        if( l < steps - 1 ) {
            gemm_q_k.smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load(gemm_q_k.smem_q);
        }

        // Load the mask for that iteration.
        mask.load(begin + l);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack_noscale(acc_p);

        // Apply the mask.
        softmax.apply_mask(mask);

        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
            // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
            __syncthreads();
        }

        // Compute the max.
        float p_max[Mma_tile_p::MMAS_M * 2];
        //softmax.template reduce<fmha::Max_>(p_max);
        softmax.reduce_max(p_max);

        // Compute the exponential value.
        softmax.apply_exp(p_max);

        // Compute the sum.
        float p_sum[Mma_tile_p::MMAS_M * 2];
        softmax.reduce_sum(p_sum);

        // Finalize softmax on the accumulators of P^T.
        softmax.scale(p_sum);

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        if( Is_training ) {
            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < 2; ii++ ) {
                    #pragma unroll
                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                        float4 tmp = uniform4(ph());
                        // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from pre-existing zeros
                        softmax.elt_[2 * mi + ii][4 * ni + 0] =
                            encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
                        softmax.elt_[2 * mi + ii][4 * ni + 1] =
                            encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
                        softmax.elt_[2 * mi + ii][4 * ni + 2] =
                            encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
                        softmax.elt_[2 * mi + ii][4 * ni + 3] =
                            encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
                    }
                }
            }
            softmax.pack(frag_p);
            gmem_s.store(frag_p, mask);
            gmem_s.move();
        } else {
            softmax.pack(frag_p);
        }

        // Commit the values for Q into shared memory.
        if( l < steps - 1 ) {
            gmem_q.commit(gemm_q_k.smem_q);
        }

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
                    //"Apply" the dropout.
                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
                }
            }
        }

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);

        // Do this part of O = P^T * V^T.
        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
            fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
        }

        // Loop over MMAS_M.
        #pragma unroll
        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {

            // Swizzle the elements and do the final reduction.
            smem_o.store(acc_o, ii);

            // Make sure the data is in shared memory.
            __syncthreads();

            // Load from shared memory.
            uint4 out[Gmem_tile_o::STGS_PER_LOOP];
            smem_o.load(out);

            // Make sure the data was read from shared memory.
            if( ii < Gmem_tile_o::LOOPS - 1 ) {
                __syncthreads();
            }

            // Output the values.
            gmem_o.store(out, ii);
        }

        // Move to the next part of the output.
        gmem_o.move();

        gemm_q_k.reload_k();

        // Commit the values for Q into shared memory.
        if( l < steps - 1 ) {
            gemm_q_k.reload_q();
        }

    }  // Outer loop over the sequence length.
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN(const Params &params,
                                  const int num_full_heads,
                                  const int num_main_groups,
                                  const int main_group_size,
                                  const int main_steps,
                                  const int rest_steps) {

    constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
    const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x;
    auto seeds = at::cuda::philox::unpack(params.philox_args);
    Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
    for( int it = 0; it < num_full_heads; it++ ) {
        const int bidx = it * gridDim.x + blockIdx.x;
        const int bidh = bidx % params.h;
        const int bidb = bidx / params.h;
        fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, 0, STEPS, ph);
        __syncthreads();
    }
    if( main_group_size == 0 )
        return;
    const int head_offset = num_full_heads * gridDim.x;

    if( blockIdx.x < main_group_size * num_main_groups ) {
        // process within heads
        const int group = blockIdx.x % num_main_groups;
        const int bidx = blockIdx.x / num_main_groups;
        const int bidh = (head_offset + bidx) % params.h;
        const int bidb = (head_offset + bidx) / params.h;
        const int offset = group * main_steps;
        fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, offset, main_steps, ph);
    } else {
        if( rest_steps == 0 )
            return;
        // process across heads
        const int bidx = blockIdx.x - main_group_size * num_main_groups;
        const int offset = num_main_groups * main_steps;
        const int total_heads = params.b * params.h;
        const int rest_ctas = gridDim.x - main_group_size * num_main_groups;
        for( int it = head_offset + bidx; it < total_heads; it += rest_ctas ) {
            const int bidh = it % params.h;
            const int bidb = it / params.h;
            fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, offset, rest_steps, ph);
            __syncthreads();
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN(const Params &params, const int total_heads) {

    const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x;
    auto seeds = at::cuda::philox::unpack(params.philox_args);
    Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
    constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;

    for( int bidx = blockIdx.x; bidx < total_heads; bidx += gridDim.x ) {
        const int bidh = bidx % params.h;
        const int bidb = bidx / params.h;
        fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, 0, STEPS, ph);
        __syncthreads();
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
apex/contrib/csrc/fmha/src/fmha_kernel.h (deleted, 100644 → 0)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once

#include <multihead_attn/philox.cuh>
#include <fmha.h>

#include <fmha/utils.h>
#include <fmha/smem_tile.h>
#include <fmha/gmem_tile.h>
#include <fmha/mask.h>
#include <fmha/softmax.h>

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS_PER_CTA>
struct BlockInfoPadded {

    template<typename Params>
    __device__ BlockInfoPadded(const Params &params, const int bidb, const int bidh, const int tidx)
        : bidb(bidb), bidh(bidh), h(params.h) {

        // The block index.
        sum_s = params.cu_seqlens[bidb];
        actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s;
        bidx = sum_s * params.h + bidh;

        tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
    }

    __device__ bool stop_early() const {
        return actual_seqlen == 0;
    }

    int actual_seqlen;
    int bidx;
    int sum_s;
    int bidh;
    int bidb;
    int tidx_global;
    int h;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int CHUNKS, typename Cta_tile>
struct Noloop_traits {
    // Interpretation of Cta_tile dims, i.e. Cta_tile_p:
    enum { STEP = Cta_tile::M };
    enum { SEQLEN = Cta_tile::N };

    template<typename Block_info>
    inline __device__ Noloop_traits(const int bidc, const Block_info &binfo)
        : bidc_(bidc) {
        const int seqlen = binfo.actual_seqlen;
        const int steps = (seqlen + STEP - 1) / STEP;
        const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS;

        const int step_begin = bidc_ * steps_per_chunk;
        const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk);
        const int actual_steps = max(0, step_end - step_begin);
        loop_offset_ = step_begin;
        num_steps_ = actual_steps;
    }

    template<typename... Tiles>
    inline __device__ void move_all(Tiles &...tiles) const {
        using expand_type = int[];
        for( int s = 0; s < loop_offset_; s++ ) {
            expand_type{ (tiles.move(), 0)... };
        }
    }

    inline __device__ int get_idx_dk() const {
        //return bidc_;
        return bidc_ * 2 + 0;
    }

    inline __device__ int get_idx_dv() const {
        //return CHUNKS + bidc_;
        return bidc_ * 2 + 1;
    }

    inline __device__ int offset_loop_count(const int l) {
        // convert loop counter to position in the outer sequence
        return (loop_offset_ + l) * STEP;
    }

    const uint32_t bidc_;
    int loop_offset_;
    int num_steps_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits>
std::tuple<int, int, int, int, int, int> work_dist(const int total_ctas, const int heads_total) {

    constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;

    const int num_full_heads = heads_total / total_ctas;
    const int heads_last_wave = heads_total % total_ctas;

    int num_main_groups = 0;
    int main_steps = 0;
    int rest_steps = 0;
    if( heads_last_wave > 0 ) {
        // Number of CTA groups that process within heads.
        num_main_groups = total_ctas / heads_last_wave;
        // Remaining CTAs that process between heads.
        const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups);
        if( rest_ctas == 0 ) {
            // We have exactly "num_main_groups" CTAs to process each of the remaining heads.
            main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups;
            num_main_groups = STEPS_PER_HEAD / main_steps;
            // Here: main_step > 0
            rest_steps = STEPS_PER_HEAD % main_steps;

        } else {
            // Ideal number of steps if we could load-balance as evenly as possible.
            const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas;
            // Iterations that a "rest" CTA has to do at most.
            const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas;
            // Find the first step distribution, s.t. the maximum work of the "rest" CTAs is less than the work of the main CTAs.
            main_steps = steps_ideal;
            rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
            for( ; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++ ) {
                rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
                const int max_rest_total_steps = rest_steps * max_rest_iters;
                if( max_rest_total_steps < main_steps )
                    break;
            }
            rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
        }
    }

    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;

    const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps);
    const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8;
    const int elts_per_thread = max_steps * elts_per_thread_per_step;

    return { num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread };
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu (deleted, 100644 → 0)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
inline
__device__
float4
ldg128
(
const
void
*
ptr
)
{
return
*
static_cast
<
const
float4
*>
(
ptr
);
}
inline
__device__
void
stg128
(
void
*
ptr
,
const
float4
&
data
)
{
*
static_cast
<
float4
*>
(
ptr
)
=
data
;
}
template
<
typename
T
,
int
THREADS
,
int
HIDDEN_SIZE
,
int
CHUNKS
>
__global__
__launch_bounds__
(
THREADS
)
void
fmha_noloop_reduce_kernel
(
void
*
__restrict__
out
,
const
void
*
__restrict__
in
,
const
int
*
__restrict__
cu_seqlens
,
const
int
batch_size
)
{
enum
{
BYTES_PER_LDG
=
16
};
enum
{
NUM_ELTS
=
BYTES_PER_LDG
/
sizeof
(
T
)
};
// One CTA hidden vector for K and V
enum
{
BYTES_PER_ROW
=
HIDDEN_SIZE
*
sizeof
(
T
)
*
2
};
// The stride in bytes in dQKV
enum
{
OUT_STRIDE_BYTES
=
3
*
HIDDEN_SIZE
*
sizeof
(
T
)
};
// The offset in bytes in dQKV to the dKV part for non-interleaved heads
enum
{
OUT_OFFSET_KV_BYTES
=
HIDDEN_SIZE
*
sizeof
(
T
)
};
static_assert
(
BYTES_PER_ROW
==
HIDDEN_SIZE
*
2
*
sizeof
(
T
));
// Size in bytes of the input tile
enum
{
BYTES_PER_TILE
=
CHUNKS
*
BYTES_PER_ROW
};
enum
{
BYTES_PER_CTA
=
THREADS
*
BYTES_PER_LDG
};
enum
{
LDGS
=
BYTES_PER_ROW
/
BYTES_PER_CTA
};
static_assert
(
BYTES_PER_CTA
*
LDGS
==
BYTES_PER_ROW
);
union
Vec_t
{
float4
raw
;
T
elt
[
NUM_ELTS
];
};
// ZERO-OUT invalid positions in dQKV
const
int
total
=
cu_seqlens
[
batch_size
];
if
(
blockIdx
.
x
>=
total
){
enum
{
BYTES_PER_QKV_ROW
=
3
*
HIDDEN_SIZE
*
sizeof
(
T
)
};
enum
{
STGS
=
BYTES_PER_QKV_ROW
/
BYTES_PER_LDG
};
const
float4
zeros
=
make_float4
(
0.
f
,
0.
f
,
0.
f
,
0.
f
);
char
*
base_ptr
=
static_cast
<
char
*>
(
out
)
+
blockIdx
.
x
*
OUT_STRIDE_BYTES
;
for
(
int
tidx
=
threadIdx
.
x
;
tidx
<
STGS
;
tidx
+=
THREADS
){
stg128
(
base_ptr
+
tidx
*
BYTES_PER_LDG
,
zeros
);
}
return
;
}
// SETUP
const
int
offset_in
=
blockIdx
.
x
*
BYTES_PER_TILE
+
threadIdx
.
x
*
BYTES_PER_LDG
;
const
char
*
ptr_in
=
static_cast
<
const
char
*>
(
in
)
+
offset_in
;
const
int
offset_out
=
blockIdx
.
x
*
OUT_STRIDE_BYTES
+
threadIdx
.
x
*
BYTES_PER_LDG
;
char
*
ptr_out
=
static_cast
<
char
*>
(
out
)
+
OUT_OFFSET_KV_BYTES
+
offset_out
;
// LOAD
Vec_t
local_in
[
CHUNKS
][
LDGS
];
#pragma unroll
for
(
int
c
=
0
;
c
<
CHUNKS
;
c
++
)
{
#pragma unroll
for
(
int
l
=
0
;
l
<
LDGS
;
l
++
)
{
int
offset
=
c
*
BYTES_PER_ROW
+
l
*
BYTES_PER_CTA
;
local_in
[
c
][
l
].
raw
=
ldg128
(
ptr_in
+
offset
);
}
}
// UNPACK
float
acc
[
LDGS
][
NUM_ELTS
];
#pragma unroll
for
(
int
l
=
0
;
l
<
LDGS
;
l
++
)
{
#pragma unroll
for
(
int
e
=
0
;
e
<
NUM_ELTS
;
e
++
)
{
acc
[
l
][
e
]
=
float
(
local_in
[
0
][
l
].
elt
[
e
]);
}
}
// COMPUTE
#pragma unroll
for
(
int
c
=
1
;
c
<
CHUNKS
;
c
++
)
{
#pragma unroll
for
(
int
l
=
0
;
l
<
LDGS
;
l
++
)
{
#pragma unroll
for
(
int
e
=
0
;
e
<
NUM_ELTS
;
e
++
)
{
acc
[
l
][
e
]
+=
float
(
local_in
[
c
][
l
].
elt
[
e
]);
}
}
}
// PACK
Vec_t
local_out
[
LDGS
];
#pragma unroll
for
(
int
l
=
0
;
l
<
LDGS
;
l
++
)
{
#pragma unroll
for
(
int
e
=
0
;
e
<
NUM_ELTS
;
e
++
)
{
local_out
[
l
].
elt
[
e
]
=
T
(
acc
[
l
][
e
]);
}
}
// STORE
#pragma unroll
for
(
int
l
=
0
;
l
<
LDGS
;
l
++
)
{
const
int
offset
=
l
*
BYTES_PER_CTA
;
stg128
(
ptr_out
+
offset
,
local_out
[
l
].
raw
);
}
}
void
fmha_run_noloop_reduce
(
void
*
out
,
const
void
*
in
,
const
int
*
cu_seqlens
,
const
int
hidden_size
,
const
int
batch_size
,
const
int
total
,
const
int
num_chunks
,
cudaStream_t
stream
)
{
const
int
blocks
=
total
;
if
(
hidden_size
==
1024
){
constexpr
int
HIDDEN_SIZE
=
1024
;
constexpr
int
THREADS
=
256
;
if
(
num_chunks
==
2
)
{
fmha_noloop_reduce_kernel
<
half
,
THREADS
,
HIDDEN_SIZE
,
2
><<<
blocks
,
THREADS
,
0
,
stream
>>>
(
out
,
in
,
cu_seqlens
,
batch_size
);
}
else
if
(
num_chunks
==
3
)
{
fmha_noloop_reduce_kernel
<
half
,
THREADS
,
HIDDEN_SIZE
,
3
><<<
blocks
,
THREADS
,
0
,
stream
>>>
(
out
,
in
,
cu_seqlens
,
batch_size
);
}
else
{
assert
(
false
&&
"Unsupported num_chunks"
);
}
}
else
{
assert
(
false
&&
"Unsupported hidden_size"
);
}
FMHA_CHECK_CUDA
(
cudaPeekAtLastError
());
}
apex/contrib/csrc/fmha/src/fmha_utils.h (deleted, 100644 → 0)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

#include <cuda_runtime_api.h>
#include <cuda_fp16.h>

////////////////////////////////////////////////////////////////////////////////////////////////////

#define FMHA_CHECK_CUDA( call )                                                \
    do {                                                                       \
        cudaError_t status_ = call;                                            \
        if( status_ != cudaSuccess ) {                                         \
            fprintf( stderr,                                                   \
                     "CUDA error (%s:%d): %s\n",                               \
                     __FILE__,                                                 \
                     __LINE__,                                                 \
                     cudaGetErrorString( status_ ) );                          \
            exit( 1 );                                                         \
        }                                                                      \
    } while( 0 )

////////////////////////////////////////////////////////////////////////////////////////////////////

enum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline void set_alpha(uint32_t &alpha, float norm, Data_type dtype) {
    if( dtype == DATA_TYPE_FP16 ) {
        half x = __float2half_rn(norm);
        uint16_t h = reinterpret_cast<const uint16_t &>(x);
        ushort2 h2 = { h, h };
        alpha = reinterpret_cast<const uint32_t &>(h2);
    } else if( dtype == DATA_TYPE_FP32 ) {
        alpha = reinterpret_cast<const uint32_t &>(norm);
    } else if( dtype == DATA_TYPE_INT32 ) {
        int32_t inorm = static_cast<int32_t>(norm);
        alpha = reinterpret_cast<const uint32_t &>(inorm);
    } else {
        assert(false);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline size_t get_size_in_bytes(size_t n, Data_type dtype) {
    switch( dtype ) {
    case DATA_TYPE_FP32:
        return n * 4;
    case DATA_TYPE_FP16:
        return n * 2;
    case DATA_TYPE_INT32:
        return n * 4;
    case DATA_TYPE_INT8:
        return n;
    default:
        assert(false);
        return 0;
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////
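Editor's note (not part of the commit): the FMHA_CHECK_CUDA macro in the deleted header above wraps CUDA runtime calls and aborts with the file/line and error string on failure, as seen in the launchers earlier in this diff. A minimal, self-contained sketch of that usage pattern follows; the macro body is copied from the header, while main() and the 1024-byte buffer are illustrative only.

// sketch.cu — hypothetical standalone example of the error-checking pattern
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime_api.h>

#define FMHA_CHECK_CUDA( call )                                                \
    do {                                                                       \
        cudaError_t status_ = call;                                            \
        if( status_ != cudaSuccess ) {                                         \
            fprintf( stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__,   \
                     cudaGetErrorString( status_ ) );                          \
            exit( 1 );                                                         \
        }                                                                      \
    } while( 0 )

int main() {
    void *buf = nullptr;
    // Each runtime call is checked; on failure the program prints the error and exits.
    FMHA_CHECK_CUDA(cudaMalloc(&buf, 1024));
    FMHA_CHECK_CUDA(cudaFree(buf));
    return 0;
}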
apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp (deleted, 100644 → 0)
#include <torch/torch.h>
#include <vector>
#include <cstdint>
// CUDA forward declarations
std
::
vector
<
at
::
Tensor
>
focal_loss_forward_cuda
(
const
at
::
Tensor
&
cls_output
,
const
at
::
Tensor
&
cls_targets_at_level
,
const
at
::
Tensor
&
num_positives_sum
,
const
int64_t
num_real_classes
,
const
float
alpha
,
const
float
gamma
,
const
float
smoothing_factor
);
at
::
Tensor
focal_loss_backward_cuda
(
const
at
::
Tensor
&
grad_output
,
const
at
::
Tensor
&
partial_grad
,
const
at
::
Tensor
&
num_positives_sum
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std
::
vector
<
at
::
Tensor
>
focal_loss_forward
(
const
at
::
Tensor
&
cls_output
,
const
at
::
Tensor
&
cls_targets_at_level
,
const
at
::
Tensor
&
num_positives_sum
,
const
int64_t
num_real_classes
,
const
float
alpha
,
const
float
gamma
,
const
float
smoothing_factor
)
{
CHECK_INPUT
(
cls_output
);
CHECK_INPUT
(
cls_targets_at_level
);
CHECK_INPUT
(
num_positives_sum
);
return
focal_loss_forward_cuda
(
cls_output
,
cls_targets_at_level
,
num_positives_sum
,
num_real_classes
,
alpha
,
gamma
,
smoothing_factor
);
}
at
::
Tensor
focal_loss_backward
(
const
at
::
Tensor
&
grad_output
,
const
at
::
Tensor
&
partial_grad
,
const
at
::
Tensor
&
num_positives_sum
)
{
CHECK_INPUT
(
grad_output
);
CHECK_INPUT
(
partial_grad
);
return
focal_loss_backward_cuda
(
grad_output
,
partial_grad
,
num_positives_sum
);
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"forward"
,
&
focal_loss_forward
,
"Focal loss calculation forward (CUDA)"
);
m
.
def
(
"backward"
,
&
focal_loss_backward
,
"Focal loss calculation backward (CUDA)"
);
}
apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu (deleted, 100644 → 0)
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#define ASSERT_UINT4_ALIGNED(PTR) \
TORCH_INTERNAL_ASSERT(is_aligned<uint4>(PTR), "Tensor " #PTR " is not uint4 aligned")
template
<
class
T
>
bool
is_aligned
(
const
void
*
ptr
)
noexcept
{
auto
iptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
ptr
);
return
!
(
iptr
%
alignof
(
T
));
}
template
<
bool
SMOOTHING
,
int
ILP
,
typename
scalar_t
,
typename
labelscalar_t
,
typename
accscalar_t
,
typename
outscalar_t
>
__global__
void
focal_loss_forward_cuda_kernel
(
outscalar_t
*
loss
,
scalar_t
*
partial_grad
,
const
scalar_t
*
__restrict__
cls_output
,
const
labelscalar_t
*
__restrict__
cls_targets_at_level
,
const
float
*
__restrict__
num_positives_sum
,
const
int64_t
num_examples
,
const
int64_t
num_classes
,
const
int64_t
num_real_classes
,
const
float
alpha
,
const
float
gamma
,
const
float
smoothing_factor
)
{
extern
__shared__
unsigned
char
shm
[];
accscalar_t
*
loss_shm
=
reinterpret_cast
<
accscalar_t
*>
(
shm
);
loss_shm
[
threadIdx
.
x
]
=
0
;
accscalar_t
loss_acc
=
0
;
accscalar_t
one
=
accscalar_t
(
1.0
);
accscalar_t
K
=
accscalar_t
(
2.0
);
accscalar_t
normalizer
=
one
/
static_cast
<
accscalar_t
>
(
num_positives_sum
[
0
]);
accscalar_t
nn_norm
,
np_norm
,
pn_norm
,
pp_norm
;
// *_norm is used for label smoothing only
if
(
SMOOTHING
)
{
nn_norm
=
one
-
smoothing_factor
/
K
;
np_norm
=
smoothing_factor
/
K
;
pn_norm
=
smoothing_factor
-
smoothing_factor
/
K
;
pp_norm
=
one
-
smoothing_factor
+
smoothing_factor
/
K
;
}
uint4
p_vec
,
grad_vec
;
// Accumulate loss on each thread
for
(
int64_t
i
=
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
)
*
ILP
;
i
<
num_examples
*
num_classes
;
i
+=
gridDim
.
x
*
blockDim
.
x
*
ILP
)
{
int64_t
idy
=
i
/
num_classes
;
labelscalar_t
y
=
cls_targets_at_level
[
idy
];
int64_t
base_yid
=
i
%
num_classes
;
int64_t
pos_idx
=
idy
*
num_classes
+
y
;
p_vec
=
*
(
uint4
*
)
&
cls_output
[
i
];
// Skip ignored matches
if
(
y
==
-
2
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
ILP
;
j
++
)
{
*
((
scalar_t
*
)(
&
grad_vec
)
+
j
)
=
0
;
}
*
(
uint4
*
)
&
partial_grad
[
i
]
=
grad_vec
;
continue
;
}
#pragma unroll
for
(
int
j
=
0
;
j
<
ILP
;
j
++
)
{
// Skip the pad classes
if
(
base_yid
+
j
>=
num_real_classes
)
{
*
((
scalar_t
*
)(
&
grad_vec
)
+
j
)
=
0
;
continue
;
}
accscalar_t
p
=
static_cast
<
accscalar_t
>
(
*
((
scalar_t
*
)(
&
p_vec
)
+
j
));
accscalar_t
exp_np
=
::
exp
(
-
p
);
accscalar_t
exp_pp
=
::
exp
(
p
);
accscalar_t
sigma
=
one
/
(
one
+
exp_np
);
accscalar_t
logee
=
(
p
>=
0
)
?
exp_np
:
exp_pp
;
accscalar_t
addee
=
(
p
>=
0
)
?
0
:
-
p
;
accscalar_t
off_a
=
addee
+
::
log
(
one
+
logee
);
// Negative matches
accscalar_t
base
=
SMOOTHING
?
nn_norm
*
p
:
p
;
accscalar_t
off_b
=
(
SMOOTHING
?
np_norm
:
0
)
-
sigma
;
accscalar_t
coeff_f1
=
one
-
alpha
;
accscalar_t
coeff_f2
=
sigma
;
accscalar_t
coeff_b1
=
gamma
;
accscalar_t
coeff_b2
=
one
-
sigma
;
// Positive matches
if
(
y
>=
0
&&
(
i
+
j
==
pos_idx
))
{
base
=
SMOOTHING
?
pn_norm
*
p
:
0
;
off_b
=
(
SMOOTHING
?
pp_norm
:
one
)
-
sigma
;
coeff_f1
=
alpha
;
coeff_f2
=
one
-
sigma
;
coeff_b1
=
-
gamma
;
coeff_b2
=
sigma
;
}
accscalar_t
coeff_f
=
coeff_f1
*
::
pow
(
coeff_f2
,
gamma
);
accscalar_t
coeff_b
=
coeff_b1
*
coeff_b2
;
accscalar_t
loss_t
=
coeff_f
*
(
base
+
off_a
);
accscalar_t
grad
=
coeff_f
*
(
coeff_b
*
(
base
+
off_a
)
-
off_b
);
// Delay the normalize of partial gradient by num_positives_sum to back
// propagation because scalar_t reduces precision. Focal loss is very
// sensitive to the small gradient. No worry on overflow here since
// gradient has relative smaller range than input.
loss_acc
+=
loss_t
;
*
((
scalar_t
*
)(
&
grad_vec
)
+
j
)
=
static_cast
<
scalar_t
>
(
grad
);
}
// This can't ensure to generate stg.128 and may be two stg.64.
*
(
uint4
*
)
&
partial_grad
[
i
]
=
grad_vec
;
}
loss_shm
[
threadIdx
.
x
]
=
loss_acc
;
// Intra-CTA reduction
__syncthreads
();
for
(
unsigned
int
s
=
blockDim
.
x
/
2
;
s
>
0
;
s
>>=
1
)
{
if
(
threadIdx
.
x
<
s
)
{
loss_shm
[
threadIdx
.
x
]
+=
loss_shm
[
threadIdx
.
x
+
s
];
}
__syncthreads
();
}
// Inter-CTA reduction
if
(
threadIdx
.
x
==
0
)
{
loss_acc
=
loss_shm
[
0
]
*
normalizer
;
atomicAdd
(
loss
,
loss_acc
);
}
}
template
<
int
ILP
,
typename
scalar_t
,
typename
accscalar_t
,
typename
outscalar_t
>
__global__
void
focal_loss_backward_cuda_kernel
(
scalar_t
*
partial_grad
,
const
outscalar_t
*
__restrict__
grad_output
,
const
float
*
__restrict__
num_positives_sum
,
const
uint64_t
numel
)
{
int64_t
idx
=
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
)
*
ILP
;
accscalar_t
normalizer
=
static_cast
<
accscalar_t
>
(
grad_output
[
0
])
/
static_cast
<
accscalar_t
>
(
num_positives_sum
[
0
]);
// The input is enforced to pad to use vector load, thus there's no need to
// check whether the last element of ILP can out of bound.
if
(
idx
>=
numel
)
return
;
uint4
grad_vec
;
grad_vec
=
*
(
uint4
*
)
&
partial_grad
[
idx
];
#pragma unroll(ILP)
for
(
int
i
=
0
;
i
<
ILP
;
i
++
)
{
auto
grad
=
static_cast
<
accscalar_t
>
(
*
((
scalar_t
*
)(
&
grad_vec
)
+
i
));
grad
*=
normalizer
;
*
((
scalar_t
*
)(
&
grad_vec
)
+
i
)
=
static_cast
<
scalar_t
>
(
grad
);
}
*
(
uint4
*
)
&
partial_grad
[
idx
]
=
grad_vec
;
}
std::vector<at::Tensor> focal_loss_forward_cuda(
    const at::Tensor &cls_output,
    const at::Tensor &cls_targets_at_level,
    const at::Tensor &num_positives_sum,
    const int64_t num_real_classes,
    const float alpha,
    const float gamma,
    const float smoothing_factor) {
  // Checks required for correctness
  TORCH_INTERNAL_ASSERT(cls_output.size(-1) >= num_real_classes,
                        "Incorrect number of real classes.");
  TORCH_INTERNAL_ASSERT(cls_targets_at_level.scalar_type() == at::kLong,
                        "Invalid label type.");
  TORCH_INTERNAL_ASSERT(
      (num_positives_sum.numel() == 1) &&
          (num_positives_sum.scalar_type() == at::kFloat),
      "Expect num_positives_sum to be a float32 tensor with only one element.");
  TORCH_INTERNAL_ASSERT(cls_output.dim() == cls_targets_at_level.dim() + 1,
                        "Mis-matched dimensions between class output and label.");
  for (int64_t i = 0; i < cls_targets_at_level.dim(); i++)
    TORCH_INTERNAL_ASSERT(cls_output.size(i) == cls_targets_at_level.size(i),
                          "Mis-matched shape between class output and label.");

  // Checks required for better performance
  const int ILP = sizeof(uint4) / cls_output.element_size();
  ASSERT_UINT4_ALIGNED(cls_output.data_ptr());
  TORCH_INTERNAL_ASSERT(cls_output.size(-1) % ILP == 0,
                        "Pad number of classes first to take advantage of 128 bit load.");
  TORCH_INTERNAL_ASSERT(num_real_classes >= ILP, "Too few classes.");

  int64_t num_classes = cls_output.size(-1);
  int64_t num_examples = cls_output.numel() / num_classes;
  at::Tensor loss = at::zeros({}, cls_output.options().dtype(at::kFloat));

  // Compute the incomplete gradient during fprop since most of the heavy
  // functions of bprop are the same as fprop; trading memory for compute
  // therefore helps with focal loss.
  at::Tensor partial_grad = at::empty_like(cls_output);

  // The grid contains 2 CTAs per SM; each CTA loops over the input with a
  // stride until the last item.
  cudaDeviceProp props;
  cudaGetDeviceProperties(&props, at::cuda::current_device());
  dim3 block(512);
  dim3 grid(2 * props.multiProcessorCount);

  // Specialize on label smoothing or not to reduce redundant operations
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (smoothing_factor == 0.0f) {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(cls_output.scalar_type(), "focal_loss_fprop", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      using labelscalar_t = int64_t;
      using outscalar_t = float;
      const int ILP = sizeof(uint4) / sizeof(scalar_t);
      focal_loss_forward_cuda_kernel<false, ILP, scalar_t, labelscalar_t,
                                     accscalar_t, outscalar_t>
          <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
              loss.data_ptr<outscalar_t>(),
              partial_grad.data_ptr<scalar_t>(),
              cls_output.data_ptr<scalar_t>(),
              cls_targets_at_level.data_ptr<labelscalar_t>(),
              num_positives_sum.data_ptr<float>(),
              num_examples, num_classes, num_real_classes,
              alpha, gamma, smoothing_factor);
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(cls_output.scalar_type(), "focal_loss_fprop", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      using labelscalar_t = int64_t;
      using outscalar_t = float;
      const int ILP = sizeof(uint4) / sizeof(scalar_t);
      focal_loss_forward_cuda_kernel<true, ILP, scalar_t, labelscalar_t,
                                     accscalar_t, outscalar_t>
          <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
              loss.data_ptr<outscalar_t>(),
              partial_grad.data_ptr<scalar_t>(),
              cls_output.data_ptr<scalar_t>(),
              cls_targets_at_level.data_ptr<labelscalar_t>(),
              num_positives_sum.data_ptr<float>(),
              num_examples, num_classes, num_real_classes,
              alpha, gamma, smoothing_factor);
    });
  }
  AT_CUDA_CHECK(cudaGetLastError());
  return {loss, partial_grad};
}
at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output,
                                    const at::Tensor &partial_grad,
                                    const at::Tensor &num_positives_sum) {
  // Each thread processes ILP elements
  const int ILP = sizeof(uint4) / partial_grad.element_size();
  dim3 block(512);
  dim3 grid((partial_grad.numel() + block.x * ILP - 1) / (block.x * ILP));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(partial_grad.scalar_type(), "focal_loss_bprop", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    using outscalar_t = float;
    const int ILP = sizeof(uint4) / sizeof(scalar_t);
    focal_loss_backward_cuda_kernel<ILP, scalar_t, accscalar_t, outscalar_t>
        <<<grid, block, 0, stream>>>(partial_grad.data_ptr<scalar_t>(),
                                     grad_output.data_ptr<outscalar_t>(),
                                     num_positives_sum.data_ptr<float>(),
                                     partial_grad.numel());
  });
  AT_CUDA_CHECK(cudaGetLastError());
  return partial_grad;
}
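A minimal standalone sketch (not part of the deleted file) of the vectorized launch arithmetic used by the backward entry point above: each thread consumes ILP = sizeof(uint4) / element_size values through one 128-bit load, so a 512-thread block covers block.x * ILP elements and the grid rounds up. The element count and dtype below are hypothetical.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t numel = 1048576;        // hypothetical padded element count
  const int element_size = 2;           // fp16
  const int ILP = 16 / element_size;    // sizeof(uint4) / element_size = 8
  const int block_x = 512;
  const int64_t grid_x =
      (numel + int64_t(block_x) * ILP - 1) / (int64_t(block_x) * ILP);
  std::printf("ILP=%d grid.x=%lld\n", ILP, (long long)grid_x);  // ILP=8 grid.x=256
  return 0;
}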
apex/contrib/csrc/groupbn/batch_norm.cu deleted 100644 → 0 View file @ 2a4864d5
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include "batch_norm.h"
#include <cuda.h>
#include "compat.h"
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
static size_t round_up_to_multiple(size_t x, int multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

struct Workspace {
  Workspace(size_t size) : size(size), data(NULL) {
    auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
    dataPtr = allocator.allocate(size);
    data = dataPtr.get();
  }
  Workspace(const Workspace&) = delete;
  Workspace(Workspace&&) = default;
  Workspace& operator=(Workspace&&) = default;
  ~Workspace() = default;

  size_t size;
  void* data;
  c10::DataPtr dataPtr;
};
// Return {y}
at::Tensor nhwc_bn_fwd_train(
                       const at::Tensor& x,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
                       const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       const bool fuse_relu,
                       void * my_data,
                       void * pair_data,
                       void * pair_data2,
                       void * pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
                       const int occupancy,
                       const int grid_dim_x,
                       const bool coop) {

  auto memory_format = x.suggest_memory_format();
  const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
  const int N = x.size(0);
  const int H = check_channels_last ? x.size(2) : x.size(1);
  const int W = check_channels_last ? x.size(3) : x.size(2);
  const int C = check_channels_last ? x.size(1) : x.size(3);

  // generate a new magic number and use it for sync
  int* magic = magic_tensor.DATA_PTR<int>();
  *magic = (*magic + 1) & 0xff;

  // Allocate output tensor
  at::Tensor y = check_channels_last
      ? at::empty({N, C, H, W}, x.options().memory_format(memory_format))
      : at::empty({N, H, W, C}, x.options());

  // Create wrapper
  NhwcBatchNorm* bn = new NhwcBatchNorm();

  bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
                             nullptr,
                             y.contiguous(memory_format).DATA_PTR<at::Half>(),
                             nullptr);

  bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
                         bias.contiguous().DATA_PTR<float>()},
                        {nullptr, nullptr});
  bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
                            running_inv_var.DATA_PTR<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
  workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[2];
  void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index - 3];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  // Don't fuse in ReLU for now at least
  bn->fwd(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3,
          bn_group, *magic, occupancy, grid_dim_x, coop);

  return y.contiguous(memory_format);
}
at::Tensor nhwc_bn_fwd_eval(
                       const at::Tensor& x,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& ret_cta,
                       const int bn_group,
                       const float momentum,
                       const float epsilon,
                       const bool fuse_relu) {

  const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
  auto memory_format = x.suggest_memory_format();
  const int N = x.size(0);
  const int H = check_channels_last ? x.size(2) : x.size(1);
  const int W = check_channels_last ? x.size(3) : x.size(2);
  const int C = check_channels_last ? x.size(1) : x.size(3);

  // Allocate output tensor
  at::Tensor y = check_channels_last
      ? at::empty({N, C, H, W}, x.options().memory_format(memory_format))
      : at::empty({N, H, W, C}, x.options());

  // Create wrapper
  NhwcBatchNorm* bn = new NhwcBatchNorm();

  bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
                             nullptr,
                             y.contiguous(memory_format).DATA_PTR<at::Half>(),
                             nullptr);

  bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
                         bias.contiguous().DATA_PTR<float>()},
                        {nullptr, nullptr});
  bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
                            running_inv_var.contiguous().DATA_PTR<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(nullptr);
  workspace.push_back(nullptr);

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[2];
  void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index - 3];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  // Don't fuse in ReLU for now at least
  bn->fwdInference(stream, fuse_relu);

  return y.contiguous(memory_format);
}
std::vector<at::Tensor> nhwc_bn_bwd(
                       const at::Tensor& x,
                       const at::Tensor& dy,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
                       const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       const bool fuse_relu,
                       void * my_data,
                       void * pair_data,
                       void * pair_data2,
                       void * pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
                       const int occupancy,
                       const int grid_dim_x,
                       const bool coop) {
  // shape
  const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
  auto memory_format = x.suggest_memory_format();
  const int N = x.size(0);
  const int H = check_channels_last ? x.size(2) : x.size(1);
  const int W = check_channels_last ? x.size(3) : x.size(2);
  const int C = check_channels_last ? x.size(1) : x.size(3);

  // generate a new magic number and use it for sync
  int* magic = magic_tensor.DATA_PTR<int>();
  *magic = (*magic + 1) & 0xff;

  // outputs
  at::Tensor x_grad, scale_grad, bias_grad;

  // Allocate outputs
  x_grad = check_channels_last
      ? at::empty({N, C, H, W}, dy.options().memory_format(memory_format))
      : at::empty_like(x);
  scale_grad = at::empty_like(scale);
  bias_grad = at::empty_like(bias);

  // Create wrapper
  NhwcBatchNorm* bn = new NhwcBatchNorm();

  bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
                             x_grad.contiguous(memory_format).DATA_PTR<at::Half>(),
                             nullptr,
                             dy.contiguous(memory_format).DATA_PTR<at::Half>());

  bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
                         bias.contiguous().DATA_PTR<float>()},
                        {scale_grad.DATA_PTR<float>(),
                         bias_grad.DATA_PTR<float>()});
  bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
                            running_inv_var.contiguous().DATA_PTR<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
  workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[2];
  void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 3; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index - 3];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3,
            bn_group, *magic, occupancy, grid_dim_x, coop);

  return std::vector<at::Tensor>{x_grad.contiguous(memory_format), scale_grad, bias_grad};
}
int nhwc_bn_fwd_occupancy() {
  int device_id = -1;
  cudaGetDevice(&device_id);

  // max occupancy supported by the code is 2
  return NhwcBatchNorm::smem_driven_fwd_occupancy(device_id, 2);
}

int nhwc_bn_bwd_occupancy() {
  int device_id = -1;
  cudaGetDevice(&device_id);

  // max occupancy supported by the code is 2
  return NhwcBatchNorm::smem_driven_bwd_occupancy(device_id, 2);
}
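A minimal standalone sketch (not part of the deleted file) of the workspace packing used by the functions above: sizes past the first few entries of numWorkspaceBytes() are packed into one allocation with each buffer rounded up to a 512-byte boundary, and each pointer is later recovered as base + offset. The byte counts below are hypothetical.

#include <cstdio>
#include <vector>

static size_t round_up_to_multiple(size_t x, int multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

int main() {
  // hypothetical per-buffer byte counts for the tail of numWorkspaceBytes()
  std::vector<size_t> workspace_bytes = {1000, 300, 4096};
  std::vector<size_t> offsets;
  size_t total = 0;
  for (size_t bytes : workspace_bytes) {
    total = round_up_to_multiple(total, 512);  // 512-byte alignment between buffers
    offsets.push_back(total);
    total += bytes;
  }
  // offsets: 0, 1024, 1536; total allocation: 5632 bytes
  for (size_t off : offsets) std::printf("offset=%zu\n", off);
  std::printf("total=%zu\n", total);
  return 0;
}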
apex/contrib/csrc/groupbn/batch_norm.h deleted 100644 → 0 View file @ 2a4864d5
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm.h
* \brief CUDA NHWC Batch Normalization code
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#include "dnn.h"
#include <algorithm>
#include <vector>
#include <string>
#include <iostream>
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#include "c10/macros/Macros.h"
#define VERBOSE_DEFAULT false
class NhwcBatchNorm {
 public:
  NhwcBatchNorm() {
    name_ = "nhwc_batchnorm";
    createTensorDescriptor(&X_tensor_desc_);
    createTensorDescriptor(&Y_tensor_desc_);
  }

  ~NhwcBatchNorm() {
    destroyTensorDescriptor(X_tensor_desc_);
    destroyTensorDescriptor(Y_tensor_desc_);
  }

  void die() {
    std::cerr << "batchnorm not initialized" << std::endl;
    exit(-1);
  }

  void fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data,
           void* pair_data2, void* pair_data3, const int bn_group, const int magic,
           const int occupancy, const int grid_dim_x, const bool coop);
  void dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data,
             void* pair_data2, void* pair_data3, const int bn_group, const int magic,
             const int occupancy, const int grid_dim_x, const bool coop);
  void fwdInference(cudaStream_t stream, bool use_relu);
  dim3 calc_fwd_grid(int* loop, const int grid_dim_x);
  dim3 calc_bwd_grid(int* loop, const int grid_dim_x);

  void setInputDescriptor(const dnnTensorFormat_t format, const dnnDataType_t data_type,
                          int n, int c, int h, int w, int bn_group) {
    m_ = n * h * w;
    int m_bn_adjusted = m_ * bn_group;
    c_ = c;
    // factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
    svar_inv_count_ = 1.f / m_bn_adjusted;
    // factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1).
    int divisor = m_bn_adjusted - 1;
    // nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.
    rvar_inv_count_ = divisor == 0 ? 1.f : 1.f / divisor;
    setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);
  }

  void setOutputDescriptor(const dnnTensorFormat_t format, const dnnDataType_t data_type,
                           int n, int c, int h, int w) {
    setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);
  }

  const std::vector<size_t> numWorkspaceBytes() const;

  void setWorkspacePointers(const std::vector<void*>& workspace,
                            const std::vector<size_t>& num_workspace_bytes);

  void setInputOutputPointers(void* X, void* dX, void* Y, void* dY) {
    X_ = X;
    dX_ = dX;
    Y_ = Y;
    dY_ = dY;
  }

  // Sets the pointers for the scale and weight (in that order) data and derivative buffers.
  void setWeightPointers(const std::vector<void*>& weight_pointers,
                         const std::vector<void*>& deriv_pointers) {
    assert(weight_pointers.size() == 2);
    assert(deriv_pointers.size() == 2);
    scale_ = static_cast<float*>(weight_pointers[0]);
    bias_ = static_cast<float*>(weight_pointers[1]);
    dscale_ = static_cast<float*>(deriv_pointers[0]);
    dbias_ = static_cast<float*>(deriv_pointers[1]);
  }

  // Sets the pointers for the population mean and variance buffers, in that order.
  void setParameterPointers(const std::vector<void*>& param_pointers) {
    assert(param_pointers.size() == 2);
    population_mean_ = static_cast<float*>(param_pointers[0]);
    population_variance_ = static_cast<float*>(param_pointers[1]);
  }

  void setConstants(const double exp_avg_factor, const double eps) {
    exp_avg_factor_ = exp_avg_factor;
    eps_ = eps;
  }

  void processCudnnStatus(const dnnStatus_t& status,
                          const std::string& string = std::string(),
                          bool verbose = VERBOSE_DEFAULT) {
#ifdef __HIP_PLATFORM_HCC__
    if (status != DNN_STATUS_SUCCESS)
      LOG(FATAL) << string << " " << miopenGetErrorString(status);
    else if (verbose)
      LOG(INFO) << string << " " << miopenGetErrorString(status);
#else
    if (status != DNN_STATUS_SUCCESS)
      LOG(FATAL) << string << " " << cudnnGetErrorString(status);
    else if (verbose)
      LOG(INFO) << string << " " << cudnnGetErrorString(status);
#endif
  }

  void checkCudaStatus(const std::string& string = std::string(),
                       bool verbose = VERBOSE_DEFAULT) {
    cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess)
      LOG(FATAL) << string << " " << cudaGetErrorString(status);
    else if (verbose)
      LOG(INFO) << string << " " << cudaGetErrorString(status);
  }

  size_t size_retired_ctas(int grid_y) const {
    // Note that the value of max_grid_y to handle known GPUs is about 160.
    const int max_grid_y = 1024;
    if (grid_y > max_grid_y)
      LOG(INFO) << "GPU capabilities exceeds assumptions.";
    const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);
    // Since the region will be initialized once and used for many kernels,
    // the idea is to return an ample size that will cover all uses.
    return retired_cta_bytes;
  }

  dnnTensorDescriptor_t X_tensor_desc_ = nullptr;
  dnnTensorDescriptor_t Y_tensor_desc_ = nullptr;

  void* X_ = nullptr;
  void* dX_ = nullptr;
  void* Y_ = nullptr;
  void* dY_ = nullptr;

  // Learned scale and bias weights.
  float* scale_ = nullptr;
  float* dscale_ = nullptr;
  float* bias_ = nullptr;
  float* dbias_ = nullptr;

  // Computed population mean and variance parameters.
  float* population_mean_ = nullptr;
  float* population_variance_ = nullptr;

  // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).
  float* minibatch_mean_ = nullptr;
  float* minibatch_variance_ = nullptr;

  int m_ = 0;  // Number of values per channel that BN is normalizing.
  int c_ = 0;  // Number of channels over which BN is normalizing.

  float svar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get saved variance
  float rvar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get running variance

  double exp_avg_factor_ = 0.;
  double eps_ = 0.;
  std::string name_;

 private:
  void setTensorDescriptor(dnnTensorDescriptor_t descriptor, dnnTensorFormat_t format,
                           dnnDataType_t data_type, int n, int c, int h, int w) {
    dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
    status = miopenSet4dTensorDescriptor(descriptor, data_type, n, c, h, w);
#else
    status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);
#endif
    processCudnnStatus(status, "set tensor descriptor");
  }

  void createTensorDescriptor(dnnTensorDescriptor_t* descriptor) {
    dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
    status = miopenCreateTensorDescriptor(descriptor);
#else
    status = cudnnCreateTensorDescriptor(descriptor);
#endif
    processCudnnStatus(status, "create tensor_descriptor");
  }

  void destroyTensorDescriptor(dnnTensorDescriptor_t descriptor) {
    dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
    status = miopenDestroyTensorDescriptor(descriptor);
#else
    status = cudnnDestroyTensorDescriptor(descriptor);
#endif
    processCudnnStatus(status, "destroy tensor_descriptor");
  }

 protected:
  float* partial_sums_ = nullptr;
  int* partial_counts_ = nullptr;
  int* retired_ctas_ = nullptr;

  void _setFwdParams(NhwcBatchNormFwdParams* params) const;
  void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams* params) const;
  void _setBwdParams(NhwcBatchNormBwdParams* params) const;

  // @todo: ability to configure these?
  // Kernel params
  static const int USE_ONLINE_APPROACH = 1;
  static const int THREADS_PER_CTA = 512;
  static const int THREADS_PER_PIXEL = 32;
  static const int C_ELEMENTS_PER_CTA = 128;
  static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;
  static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;

  typedef uint16_t StorageType;
  // typedef float StorageType;
  // increasing this to 6 causes spills in fwd kernel!
  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 1;
  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 1;
  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 0;
  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 0;

  static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \
      PIXELS_PER_THREAD_IN_SMEM_FWD;
  static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \
      PIXELS_PER_THREAD_IN_SMEM_BWD;
  static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;

  // Derived params
  static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD * THREADS_PER_CTA * \
      ELEMENTS_PER_LDG * sizeof(StorageType);
  static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD * THREADS_PER_CTA * \
      ELEMENTS_PER_LDG * 2 * sizeof(StorageType);
  static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
  static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA / THREADS_PER_PIXEL * \
      PIXELS_PER_THREAD_FWD;
  static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA / THREADS_PER_PIXEL * \
      PIXELS_PER_THREAD_BWD;
  static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA / THREADS_PER_PIXEL * \
      PIXELS_PER_THREAD_FWD_INFERENCE;

  // max grid.y in case of group bn is limited by exchange buffer size
  static const int MAX_GBN_BLOCK_Y = 256;

  // Helper function to launch the forward kernel.

  // We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel
  // version that was compiled with that occupancy in its launch bounds. This way, we avoid
  // needless register spills.
  void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
                          dim3 grid_dim, int outer_loops, bool use_relu,
                          const int occupancy, const bool coop) {
#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) fwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
hipLaunchKernel((void *) fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
#else
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
cudaLaunchKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
#endif
    // Don't try for an occupancy > 2 as this will squeeze register use and create spills.
    if (outer_loops == 1 && use_relu) {
      if (occupancy >= 2)
        LAUNCH_FWD_KERNEL(1, true, false, 2, coop);
      else
        LAUNCH_FWD_KERNEL(1, true, false, 1, coop);
    } else if (outer_loops == 1 && !use_relu) {
      if (occupancy >= 2)
        LAUNCH_FWD_KERNEL(1, false, false, 2, coop);
      else
        LAUNCH_FWD_KERNEL(1, false, false, 1, coop);
    } else if (use_relu) {
      if (occupancy >= 2)
        LAUNCH_FWD_KERNEL(0, true, false, 2, coop);
      else
        LAUNCH_FWD_KERNEL(0, true, false, 1, coop);
    } else {
      if (occupancy >= 2)
        LAUNCH_FWD_KERNEL(0, false, false, 2, coop);
      else
        LAUNCH_FWD_KERNEL(0, false, false, 1, coop);
    }
#undef LAUNCH_FWD_KERNEL
  }

  // Helper function to launch the backward kernel.
  void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
                          dim3 grid_dim, int outer_loops, bool use_relu,
                          const int occupancy, const bool coop) {
#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_func = nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) bwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using BWD_FUNC = decltype(nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
hipLaunchKernel((void *) bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd coop serial kernel"); \
} while (0)
#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_relu_func = nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) bwd_relu_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using BWD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
hipLaunchKernel((void *) bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \
} while (0)
#else
#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_func = nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using BWD_FUNC = decltype(nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd coop serial kernel"); \
} while (0)
#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_relu_func = nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_relu_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using BWD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \
} while (0)
#endif
    // Don't try for an occupancy > 2 as this will squeeze register use and create spills.
    if (outer_loops == 1 && use_relu) {
      if (occupancy >= 2)
        LAUNCH_BWD_RELU_KERNEL(1, 2, coop);
      else
        LAUNCH_BWD_RELU_KERNEL(1, 1, coop);
    } else if (outer_loops == 1 && !use_relu) {
      if (occupancy >= 2)
        LAUNCH_BWD_KERNEL(1, 2, coop);
      else
        LAUNCH_BWD_KERNEL(1, 1, coop);
    } else if (use_relu) {
      if (occupancy >= 2)
        LAUNCH_BWD_RELU_KERNEL(0, 2, coop);
      else
        LAUNCH_BWD_RELU_KERNEL(0, 1, coop);
    } else {
      if (occupancy >= 2)
        LAUNCH_BWD_KERNEL(0, 2, coop);
      else
        LAUNCH_BWD_KERNEL(0, 1, coop);
    }
#undef LAUNCH_BWD_KERNEL
  }

 public:
  // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
  static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
    using namespace at::cuda::utils;
    int fwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / C10_WARP_SIZE) *
        ELEMENTS_PER_LDG * sizeof(float);
    int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
    return std::min(max_cta_per_sm, occupancy);
  }

  // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
  static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
    using namespace at::cuda::utils;
    int bwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / C10_WARP_SIZE) *
        ELEMENTS_PER_LDG * sizeof(float);
    int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
    return std::min(max_cta_per_sm, occupancy);
  }
};

const std::vector<size_t> NhwcBatchNorm::numWorkspaceBytes() const {
  assert(c_ > 0);

  // choose the max memory required between fwd/bwd passes
  int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);
  int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);
  int grid_x = max(grid_x_fwd, grid_x_bwd);
  int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);

  const size_t num_mean_bytes = c_ * sizeof(float);
  const size_t num_variance_bytes = num_mean_bytes;
  const size_t size_sums = grid_y * grid_x * THREADS_PER_PIXEL * \
      ELEMENTS_PER_LDG * 2 * sizeof(float);
  const size_t size_counts = grid_y * grid_x * sizeof(int);

  return {num_mean_bytes, num_variance_bytes,
          size_retired_ctas(grid_y), size_sums, size_counts};
}
void NhwcBatchNorm::setWorkspacePointers(
    const std::vector<void*>& workspace,
    const std::vector<size_t>& num_workspace_bytes) {
  assert(workspace.size() == 5);
  assert(num_workspace_bytes.size() == 5);

  minibatch_mean_ = static_cast<float*>(workspace[0]);
  minibatch_variance_ = static_cast<float*>(workspace[1]);
  retired_ctas_ = static_cast<int*>(workspace[2]);
  partial_sums_ = static_cast<float*>(workspace[3]);
  partial_counts_ = static_cast<int*>(workspace[4]);
}

void NhwcBatchNorm::_setFwdParams(NhwcBatchNormFwdParams* params) const {
  params->gmem_src = static_cast<uint16_t*>(X_);
  params->gmem_dst = static_cast<uint16_t*>(Y_);
  params->gmem_src1 = nullptr;
  params->gmem_bias = bias_;
  params->gmem_scale = scale_;
  params->gmem_running_mean = population_mean_;
  params->gmem_running_var = population_variance_;
  params->gmem_saved_mean = minibatch_mean_;
  params->gmem_saved_var = minibatch_variance_;
  params->gmem_relu_bitmask = nullptr;
  params->nhw = m_;
  params->c = c_;
  params->svar_inv_count = svar_inv_count_;
  params->rvar_inv_count = rvar_inv_count_;
  params->gmem_sums = partial_sums_;
  params->gmem_counts = partial_counts_;
  params->gmem_retired_ctas = retired_ctas_;
  params->var_eps = eps_;
  params->outer_loops = 0;
  params->exp_avg_factor = static_cast<float>(exp_avg_factor_);
  params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
}

void NhwcBatchNorm::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams* params) const {
  params->gmem_src = static_cast<uint16_t*>(X_);
  params->gmem_dst = static_cast<uint16_t*>(Y_);
  params->gmem_src1 = nullptr;
  params->gmem_bias = bias_;
  params->gmem_scale = scale_;
  params->gmem_mean = population_mean_;
  params->gmem_var = population_variance_;
  params->nhw = m_;
  params->c = c_;
  params->var_eps = eps_;
}

void NhwcBatchNorm::_setBwdParams(NhwcBatchNormBwdParams* params) const {
  params->gmem_src = static_cast<uint16_t*>(X_);
  params->gmem_dy = static_cast<uint16_t*>(dY_);
  params->gmem_dst = static_cast<uint16_t*>(dX_);
  params->gmem_dst1 = nullptr;
  params->gmem_relu_bitmask = nullptr;
  params->gmem_dscale = dscale_;
  params->gmem_dbias = dbias_;
  params->gmem_scale = scale_;
  params->gmem_bias = bias_;
  params->gmem_saved_mean = minibatch_mean_;
  params->gmem_saved_var = minibatch_variance_;
  params->nhw = m_;
  params->c = c_;
  params->svar_inv_count = svar_inv_count_;
  params->gmem_sums = partial_sums_;
  params->gmem_retired_ctas = retired_ctas_;
  params->outer_loops = 0;
  params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
}

void NhwcBatchNorm::fwdInference(cudaStream_t stream, bool use_relu) {
  bool ptrs_are_set =
      X_tensor_desc_ != nullptr
      && Y_tensor_desc_ != nullptr
      && scale_ != nullptr
      && bias_ != nullptr
      // && minibatch_mean_ != nullptr
      // && minibatch_variance_ != nullptr
      && population_mean_ != nullptr
      && population_variance_ != nullptr
      && X_ != nullptr
      // && dX_ != nullptr
      && Y_ != nullptr
      // && dY_ != nullptr
      // && dscale_ != nullptr
      // && dbias_ != nullptr
      && partial_sums_ != nullptr
      && partial_counts_ != nullptr;

  if (!ptrs_are_set)
    die();

  dim3 grid_dim;
  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE);
  grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);

  // @todo: maybe just move this inside initialize routine?
  NhwcBatchNormFwdInferenceParams params;
  _setFwdInferenceParams(&params);

  if (use_relu) {
    nhwc_batch_norm_fwd_inference
        <StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, true, false>
        <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
    checkCudaStatus(name_ + " fwd_inference-relu kernel");
  } else {
    nhwc_batch_norm_fwd_inference
        <StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, false>
        <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
    checkCudaStatus(name_ + " fwd_inference kernel");
  }
}

dim3 NhwcBatchNorm::calc_fwd_grid(int* loop, const int grid_dim_x) {
  dim3 grid_dim;
  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);
  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
  unsigned int max_grid_x = grid_dim_x;
  if (grid_dim.x <= max_grid_x) {
    *loop = 1;
    if (max_grid_x / grid_dim.x > 1) {
      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
      assert(grid_dim.y < MAX_GBN_BLOCK_Y);  // FIXME: turn into a loop
    } else {
      grid_dim.y = 1;
    }
  } else {
    grid_dim.x = max_grid_x;
    grid_dim.y = 1;
    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD * PIXELS_PER_LDG * grid_dim.x;
    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD * PIXELS_PER_LDG * grid_dim.x;
    *loop = div_up(nhw_in_regs, pixels_per_iteration);
  }
  return grid_dim;
}

dim3 NhwcBatchNorm::calc_bwd_grid(int* loop, const int grid_dim_x) {
  dim3 grid_dim;
  grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);
  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
  unsigned int max_grid_x = grid_dim_x;
  if (grid_dim.x <= max_grid_x) {
    *loop = 1;
    if (max_grid_x / grid_dim.x > 1) {
      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
      assert(grid_dim.y < MAX_GBN_BLOCK_Y);  // FIXME: turn into a loop
    } else {
      grid_dim.y = 1;
    }
  } else {
    grid_dim.x = max_grid_x;
    grid_dim.y = 1;
    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD * PIXELS_PER_LDG * grid_dim.x;
    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD * PIXELS_PER_LDG * grid_dim.x;
    *loop = div_up(nhw_in_regs, pixels_per_iteration);
  }
  return grid_dim;
}
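A minimal standalone sketch (not part of the deleted header) of the grid-sizing logic in calc_fwd_grid above, under the constants this class uses (THREADS_PER_CTA = 512, THREADS_PER_PIXEL = 32, PIXELS_PER_THREAD_FWD = 1, so PIXELS_PER_CTA_FWD = 16). The nhw, channel count, and grid budget below are hypothetical.

#include <algorithm>
#include <cstdio>

static int div_up(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int pixels_per_cta_fwd = 16;   // 512 / 32 * 1
  const int c_elements_per_cta = 128;
  const int nhw = 50176, c = 256;      // hypothetical N*H*W and channel count
  const int max_grid_x = 160;          // hypothetical grid_dim_x budget

  int grid_x = div_up(nhw, pixels_per_cta_fwd);   // 3136 CTAs needed to cover nhw
  int c_blks = div_up(c, c_elements_per_cta);     // 2 channel blocks
  int loop = 1, grid_y = 1;
  if (grid_x <= max_grid_x) {
    if (max_grid_x / grid_x > 1)
      grid_y = std::min(c_blks, max_grid_x / grid_x);
  } else {
    // not enough CTAs: clamp grid.x and loop over the remaining pixels serially
    loop = div_up(nhw, 1 /* pixels in registers */ * (512 / 32) * max_grid_x);
    grid_x = max_grid_x;
  }
  std::printf("grid=(%d,%d) loop=%d\n", grid_x, grid_y, loop);  // grid=(160,1) loop=20
  return 0;
}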
void NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, void* my_data,
                        void* pair_data, void* pair_data2, void* pair_data3,
                        const int bn_group, const int magic, const int occupancy,
                        const int grid_dim_x, const bool coop) {
  bool ptrs_are_set =
      X_tensor_desc_ != nullptr
      && Y_tensor_desc_ != nullptr
      && scale_ != nullptr
      && bias_ != nullptr
      && minibatch_mean_ != nullptr
      && minibatch_variance_ != nullptr
      && population_mean_ != nullptr
      && population_variance_ != nullptr
      && X_ != nullptr
      // && dX_ != nullptr
      && Y_ != nullptr
      // && dY_ != nullptr
      // && dscale_ != nullptr
      // && dbias_ != nullptr
      && partial_sums_ != nullptr
      && partial_counts_ != nullptr
      && retired_ctas_ != nullptr;

  if (!ptrs_are_set)
    die();

  // reset of retired_cta_count no longer needed

  NhwcBatchNormFwdParams params;
  _setFwdParams(&params);
  params.my_data = my_data;
  params.pair_datas[0] = pair_data;
  params.pair_datas[1] = pair_data2;
  params.pair_datas[2] = pair_data3;
  params.magic = magic;
  params.sync_iters = (bn_group == 8) ? 3 : (bn_group >> 1);

  dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);
  _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);
}

void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_data,
                          void* pair_data, void* pair_data2, void* pair_data3,
                          const int bn_group, const int magic, const int occupancy,
                          const int grid_dim_x, const bool coop) {
  bool ptrs_are_set =
      X_tensor_desc_ != nullptr
      && Y_tensor_desc_ != nullptr
      && scale_ != nullptr
      && (bias_ != nullptr || !use_relu)
      && minibatch_mean_ != nullptr
      && minibatch_variance_ != nullptr
      // && population_mean_ != nullptr
      // && population_variance_ != nullptr
      && X_ != nullptr
      && dX_ != nullptr
      // && Y_ != nullptr
      && dY_ != nullptr
      && dscale_ != nullptr
      && dbias_ != nullptr;

  if (!ptrs_are_set)
    die();

  // reset of retired_cta_count no longer needed

  NhwcBatchNormBwdParams params;
  _setBwdParams(&params);
  params.my_data = my_data;
  params.pair_datas[0] = pair_data;
  params.pair_datas[1] = pair_data2;
  params.pair_datas[2] = pair_data3;
  params.magic = magic;
  params.sync_iters = (bn_group == 8) ? 3 : (bn_group >> 1);
  params.wgrad_coeff = 1.0 / bn_group;

  dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);
  _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);
}

#endif  // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
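A minimal standalone sketch (not part of the deleted header) of the smem_driven_*_occupancy arithmetic above: with PIXELS_PER_THREAD_IN_SMEM_* = 0 the per-CTA shared memory is just the reduction buffer, THREADS_PER_PIXEL * (THREADS_PER_CTA / warp_size) * ELEMENTS_PER_LDG * sizeof(float), and the occupancy is the SM's shared memory divided by that, clamped to max_cta_per_sm = 2. The 96 KB shared-memory figure is a hypothetical device value.

#include <algorithm>
#include <cstdio>

int main() {
  const int threads_per_pixel = 32, threads_per_cta = 512, elements_per_ldg = 4;
  const int warp_size = 32;
  const int smem_size_fwd = 0;  // PIXELS_PER_THREAD_IN_SMEM_FWD == 0
  int reduction_bytes = threads_per_pixel * (threads_per_cta / warp_size) *
                        elements_per_ldg * int(sizeof(float));    // 32*16*4*4 = 8192
  int smem_bytes = smem_size_fwd + reduction_bytes;
  int smem_per_sm = 96 * 1024;                                    // hypothetical per-SM smem
  int occupancy = std::min(2, smem_per_sm / smem_bytes);          // min(2, 12) = 2
  std::printf("occupancy=%d\n", occupancy);
  return 0;
}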
apex/contrib/csrc/groupbn/batch_norm_add_relu.cu deleted 100644 → 0 View file @ 2a4864d5
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include "batch_norm_add_relu.h"
#include <cuda.h>
#include "compat.h"
//FIXME move the common stuff to common h file
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
static size_t round_up_to_multiple(size_t x, int multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

struct Workspace {
  Workspace(size_t size) : size(size), data(NULL) {
    auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
    dataPtr = allocator.allocate(size);
    data = dataPtr.get();
  }
  Workspace(const Workspace&) = delete;
  Workspace(Workspace&&) = default;
  Workspace& operator=(Workspace&&) = default;
  ~Workspace() = default;

  size_t size;
  void* data;
  c10::DataPtr dataPtr;
};
// Return {y}
at::Tensor nhwc_bn_addrelu_fwd_train(
                       const at::Tensor& x,
                       const at::Tensor& z,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var,
                       const at::Tensor& bitmask,
                       const at::Tensor& ret_cta,
                       const float momentum,
                       const float epsilon,
                       void * my_data,
                       void * pair_data,
                       void * pair_data2,
                       void * pair_data3,
                       const int bn_group,
                       const at::Tensor& magic_tensor,
                       const int occupancy,
                       const int grid_dim_x,
                       const bool coop) {

  auto memory_format = x.suggest_memory_format();
  const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
  const int N = x.size(0);
  const int H = check_channels_last ? x.size(2) : x.size(1);
  const int W = check_channels_last ? x.size(3) : x.size(2);
  const int C = check_channels_last ? x.size(1) : x.size(3);

  // generate a new magic number and use it for sync
  int* magic = magic_tensor.DATA_PTR<int>();
  *magic = (*magic + 1) & 0xff;

  // Allocate output tensor
  at::Tensor y = check_channels_last
      ? at::empty({N, C, H, W}, x.options().memory_format(memory_format))
      : at::empty({N, H, W, C}, x.options());

  // Create wrapper
  NhwcBatchNormAddRelu* bn = new NhwcBatchNormAddRelu();

  bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
                             nullptr,
                             y.contiguous(memory_format).DATA_PTR<at::Half>(),
                             nullptr,
                             z.contiguous(memory_format).DATA_PTR<at::Half>(),
                             nullptr);

  bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
                         bias.contiguous().DATA_PTR<float>()},
                        {nullptr, nullptr});
  bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
                            running_inv_var.contiguous().DATA_PTR<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
  workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());
  workspace.push_back(bitmask.contiguous().DATA_PTR<bitmask_pyt_t>());

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[3];
  void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index - 4];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  // Don't fuse in ReLU for now at least
  bn->fwd(stream, my_data, pair_data, pair_data2, pair_data3,
          bn_group, *magic, occupancy, grid_dim_x, coop);

  return y.contiguous(memory_format);
}
at::Tensor nhwc_bn_addrelu_fwd_eval(
                       const at::Tensor& x,
                       const at::Tensor& z,
                       const at::Tensor& scale,
                       const at::Tensor& bias,
                       const at::Tensor& running_mean,
                       const at::Tensor& running_inv_var,
                       const at::Tensor& ret_cta,
                       const int bn_group,
                       const float momentum,
                       const float epsilon) {

  auto memory_format = x.suggest_memory_format();
  const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
  const int N = x.size(0);
  const int H = check_channels_last ? x.size(2) : x.size(1);
  const int W = check_channels_last ? x.size(3) : x.size(2);
  const int C = check_channels_last ? x.size(1) : x.size(3);

  // Allocate output tensor
  at::Tensor y = check_channels_last
      ? at::empty({N, C, H, W}, x.options().memory_format(memory_format))
      : at::empty({N, H, W, C}, x.options());

  // Create wrapper
  NhwcBatchNormAddRelu* bn = new NhwcBatchNormAddRelu();

  bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
                             nullptr,
                             y.contiguous(memory_format).DATA_PTR<at::Half>(),
                             nullptr,
                             z.contiguous(memory_format).DATA_PTR<at::Half>(),
                             nullptr);

  bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
                         bias.contiguous().DATA_PTR<float>()},
                        {nullptr, nullptr});
  bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
                            running_inv_var.contiguous().DATA_PTR<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(nullptr);
  workspace.push_back(nullptr);
  workspace.push_back(nullptr);

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[3];
  void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index - 4];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  // Don't fuse in ReLU for now at least
  bn->fwdInference(stream);

  return y.contiguous(memory_format);
}
std
::
vector
<
at
::
Tensor
>
nhwc_bn_addrelu_bwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
dy
,
const
at
::
Tensor
&
scale
,
const
at
::
Tensor
&
bias
,
const
at
::
Tensor
&
running_mean
,
const
at
::
Tensor
&
running_inv_var
,
const
at
::
Tensor
&
minibatch_mean
,
const
at
::
Tensor
&
minibatch_inv_var
,
const
at
::
Tensor
&
bitmask
,
const
at
::
Tensor
&
ret_cta
,
const
float
momentum
,
const
float
epsilon
,
void
*
my_data
                                       , void *pair_data,
                                       void *pair_data2,
                                       void *pair_data3,
                                       const int bn_group,
                                       const at::Tensor& magic_tensor,
                                       const int occupancy,
                                       const int grid_dim_x,
                                       const bool coop) {
  // shape
  auto memory_format = x.suggest_memory_format();
  const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
  const int N = x.size(0);
  const int H = check_channels_last ? x.size(2) : x.size(1);
  const int W = check_channels_last ? x.size(3) : x.size(2);
  const int C = check_channels_last ? x.size(1) : x.size(3);

  // generating new magic number and use that for sync
  int* magic = magic_tensor.DATA_PTR<int>();
  *magic = (*magic + 1) & 0xff;

  // outputs
  at::Tensor x_grad, z_grad, scale_grad, bias_grad;

  // Allocate outputs
  x_grad = check_channels_last ? at::empty({N, C, H, W}, dy.options().memory_format(memory_format)) : at::empty_like(x);
  z_grad = check_channels_last ? at::empty({N, C, H, W}, dy.options().memory_format(memory_format)) : at::empty_like(x);
  scale_grad = at::empty_like(scale);
  bias_grad = at::empty_like(bias);

  // Create wrapper
  NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
  bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
  bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);

  bn->setConstants(momentum, epsilon);

  // set pointers within the wrapper
  bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
                             x_grad.contiguous(memory_format).DATA_PTR<at::Half>(),
                             nullptr,
                             dy.contiguous(memory_format).DATA_PTR<at::Half>(),
                             nullptr,
                             z_grad.contiguous(memory_format).DATA_PTR<at::Half>());

  bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(), bias.contiguous().DATA_PTR<float>()},
                        {scale_grad.DATA_PTR<float>(), bias_grad.DATA_PTR<float>()});
  bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(), running_inv_var.contiguous().DATA_PTR<float>()});

  // deal with workspace(s)
  auto workspace_bytes = bn->numWorkspaceBytes();
  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
  // an allocated workspace for the others
  size_t total_workspace_bytes = 0;
  std::vector<size_t> workspace_offsets;

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
    workspace_offsets.push_back(total_workspace_bytes);

    auto alloc_bytes = workspace_bytes[index];
    total_workspace_bytes += alloc_bytes;
  }

  // Allocate the workspace
  Workspace ws(total_workspace_bytes);

  std::vector<void *> workspace;
  workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
  workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());
  workspace.push_back(bitmask.contiguous().DATA_PTR<bitmask_pyt_t>());

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const int retired_cta_bytes = workspace_bytes[3];
  void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
  assert(ret_cta.size(0) >= retired_cta_bytes);
  workspace.push_back(retired_ctas);

  for (auto index = 4; index < workspace_bytes.size(); ++index) {
    void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index - 4];
    workspace.push_back(ptr);
  }

  bn->setWorkspacePointers(workspace, workspace_bytes);

  bn->dgrad(stream, my_data, pair_data, pair_data2, pair_data3,
            bn_group, *magic, occupancy, grid_dim_x, coop);

  return std::vector<at::Tensor>{x_grad.contiguous(memory_format),
                                 z_grad.contiguous(memory_format),
                                 scale_grad,
                                 bias_grad};
}
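// Worked sketch (not part of the original sources): the loop above packs every workspace region
// after index 3 into one allocation, aligning each region start to 512 bytes. A standalone version
// of that packing, with hypothetical byte sizes and a helper that mirrors round_up_to_multiple(x, 512):
#include <cstddef>
#include <vector>

static size_t round_up_to_512(size_t x) { return ((x + 511) / 512) * 512; }

static std::vector<size_t> pack_regions(const std::vector<size_t>& region_bytes, size_t* total) {
  std::vector<size_t> offsets;
  *total = 0;
  for (size_t bytes : region_bytes) {
    *total = round_up_to_512(*total);   // align the start of this region
    offsets.push_back(*total);
    *total += bytes;                    // then reserve its payload
  }
  return offsets;  // e.g. {1000, 4096, 12} -> offsets {0, 1024, 5120}, *total == 5132
}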
int nhwc_bn_addrelu_fwd_occupancy() {
  int device_id = -1;
  cudaGetDevice(&device_id);

  //max occupancy supported by the code is 2
  return NhwcBatchNormAddRelu::smem_driven_fwd_occupancy(device_id, 2);
}

int nhwc_bn_addrelu_bwd_occupancy() {
  int device_id = -1;
  cudaGetDevice(&device_id);

  //max occupancy supported by the code is 2
  return NhwcBatchNormAddRelu::smem_driven_bwd_occupancy(device_id, 2);
}
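// Sketch (assumption, not from the original sources): the two helpers above cap occupancy at 2
// CTAs/SM and divide the SM's shared memory by the kernel's per-CTA footprint. Equivalent
// arithmetic, querying the device attribute directly instead of the at::cuda::utils helper:
#include <algorithm>
#include <cuda_runtime.h>

static int smem_driven_occupancy_sketch(int device_id, int max_cta_per_sm, int smem_bytes_per_cta) {
  int smem_per_sm = 0;
  cudaDeviceGetAttribute(&smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id);
  return std::min(max_cta_per_sm, smem_per_sm / smem_bytes_per_cta);
}
// e.g. a hypothetical 8 KB per-CTA reduction buffer on a 96 KB/SM device gives min(2, 12) = 2.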
apex/contrib/csrc/groupbn/batch_norm_add_relu.h deleted 100644 → 0
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm_add_relu.h
* \brief CUDA NHWC Batch Normalization code with fused addition
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#include "dnn.h"
#include <algorithm>
#include <vector>
#include <string>
#include <iostream>
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#include "c10/macros/Macros.h"
#ifdef __HIP_PLATFORM_HCC__
using bitmask_t = uint64_t;
using bitmask_pyt_t = int64_t;
#else
using bitmask_t = unsigned int;
using bitmask_pyt_t = int32_t;
#endif
#define VERBOSE_DEFAULT false
class NhwcBatchNormAddRelu {
 public:
  NhwcBatchNormAddRelu() {
    name_ = "nhwc_batchnormaddrelu";
    createTensorDescriptor(&X_tensor_desc_);
    createTensorDescriptor(&Y_tensor_desc_);
  }

  ~NhwcBatchNormAddRelu() {
    destroyTensorDescriptor(X_tensor_desc_);
    destroyTensorDescriptor(Y_tensor_desc_);
  }

  void die() {
    std::cerr << "batchnormaddrelu not initialized" << std::endl;
    exit(-1);
  }

  void fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
           const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
  void dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
             const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
  void fwdInference(cudaStream_t stream);
  dim3 calc_fwd_grid(int* loop, const int grid_dim_x);
  dim3 calc_bwd_grid(int* loop, const int grid_dim_x);

  void setInputDescriptor(const dnnTensorFormat_t format,
                          const dnnDataType_t data_type,
                          int n, int c, int h, int w, int bn_group) {
    m_ = n * h * w;
    int m_bn_adjusted = m_ * bn_group;
    c_ = c;
    // factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
    svar_inv_count_ = 1.f / m_bn_adjusted;
    // factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1).
    int divisor = m_bn_adjusted - 1;
    // nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.
    rvar_inv_count_ = divisor == 0 ? 1.f : 1.f / divisor;
    setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);
  }

  void setOutputDescriptor(const dnnTensorFormat_t format,
                           const dnnDataType_t data_type,
                           int n, int c, int h, int w) {
    setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);
  }
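  // Worked example (hypothetical numbers, not from the original sources) of the two scale
  // factors set in setInputDescriptor() above:
  //   int   n = 32, h = 7, w = 7, bn_group = 2;
  //   int   m             = n * h * w;            // 1568 pixels on this GPU
  //   int   m_bn_adjusted = m * bn_group;         // 3136 pixels across the group
  //   float svar_inv_count = 1.f / m_bn_adjusted;        // 1/3136, biased (saved) variance
  //   float rvar_inv_count = 1.f / (m_bn_adjusted - 1);  // 1/3135, unbiased (running) variance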
const
std
::
vector
<
size_t
>
numWorkspaceBytes
()
const
;
void
setWorkspacePointers
(
const
std
::
vector
<
void
*>&
workspace
,
const
std
::
vector
<
size_t
>&
num_workspace_bytes
);
void
setInputOutputPointers
(
void
*
X
,
void
*
dX
,
void
*
Y
,
void
*
dY
,
void
*
addend
,
void
*
dAddend
)
{
X_
=
X
;
dX_
=
dX
;
Y_
=
Y
;
dY_
=
dY
;
addend_
=
addend
;
dAddend_
=
dAddend
;
}
// Sets the pointers for the scale and weight (in that order) data and derivative buffers.
void
setWeightPointers
(
const
std
::
vector
<
void
*>&
weight_pointers
,
const
std
::
vector
<
void
*>&
deriv_pointers
)
{
assert
(
weight_pointers
.
size
()
==
2
);
assert
(
deriv_pointers
.
size
()
==
2
);
scale_
=
static_cast
<
float
*>
(
weight_pointers
[
0
]);
bias_
=
static_cast
<
float
*>
(
weight_pointers
[
1
]);
dscale_
=
static_cast
<
float
*>
(
deriv_pointers
[
0
]);
dbias_
=
static_cast
<
float
*>
(
deriv_pointers
[
1
]);
}
// Sets the pointers for the population mean and variance buffers, in that order.
void
setParameterPointers
(
const
std
::
vector
<
void
*>&
param_pointers
)
{
assert
(
param_pointers
.
size
()
==
2
);
population_mean_
=
static_cast
<
float
*>
(
param_pointers
[
0
]);
population_variance_
=
static_cast
<
float
*>
(
param_pointers
[
1
]);
}
void
setConstants
(
const
double
exp_avg_factor
,
const
double
eps
)
{
exp_avg_factor_
=
exp_avg_factor
;
eps_
=
eps
;
}
void
processCudnnStatus
(
const
dnnStatus_t
&
status
,
const
std
::
string
&
string
=
std
::
string
(),
bool
verbose
=
VERBOSE_DEFAULT
)
{
#ifdef __HIP_PLATFORM_HCC__
if
(
status
!=
DNN_STATUS_SUCCESS
)
LOG
(
FATAL
)
<<
string
<<
" "
<<
miopenGetErrorString
(
status
);
else
if
(
verbose
)
LOG
(
INFO
)
<<
string
<<
" "
<<
miopenGetErrorString
(
status
);
#else
if
(
status
!=
DNN_STATUS_SUCCESS
)
LOG
(
FATAL
)
<<
string
<<
" "
<<
cudnnGetErrorString
(
status
);
else
if
(
verbose
)
LOG
(
INFO
)
<<
string
<<
" "
<<
cudnnGetErrorString
(
status
);
#endif
}
void
checkCudaStatus
(
const
std
::
string
&
string
=
std
::
string
(),
bool
verbose
=
VERBOSE_DEFAULT
)
{
cudaError_t
status
=
cudaGetLastError
();
if
(
status
!=
cudaSuccess
)
LOG
(
FATAL
)
<<
string
<<
" "
<<
cudaGetErrorString
(
status
);
else
if
(
verbose
)
LOG
(
INFO
)
<<
string
<<
" "
<<
cudaGetErrorString
(
status
);
}
size_t
size_retired_ctas
(
int
grid_y
)
const
{
// Note that the value of max_grid_y to handle known GPUs is about 160.
const
int
max_grid_y
=
1024
;
if
(
grid_y
>
max_grid_y
)
LOG
(
INFO
)
<<
"GPU capabilities exceeds assumptions."
;
const
int
retired_cta_bytes
=
max_grid_y
*
2
*
sizeof
(
int
);
// Since the region will be initialized once and used for many kernels,
// the idea is to return an ample size that will cover all uses.
return
retired_cta_bytes
;
}
dnnTensorDescriptor_t
X_tensor_desc_
=
nullptr
;
dnnTensorDescriptor_t
Y_tensor_desc_
=
nullptr
;
void
*
X_
=
nullptr
;
void
*
dX_
=
nullptr
;
void
*
Y_
=
nullptr
;
void
*
dY_
=
nullptr
;
void
*
addend_
=
nullptr
;
void
*
dAddend_
=
nullptr
;
// Learned scale and bias weights.
float
*
scale_
=
nullptr
;
float
*
dscale_
=
nullptr
;
float
*
bias_
=
nullptr
;
float
*
dbias_
=
nullptr
;
// Computed population mean and variance parameters.
float
*
population_mean_
=
nullptr
;
float
*
population_variance_
=
nullptr
;
// Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).
float
*
minibatch_mean_
=
nullptr
;
float
*
minibatch_variance_
=
nullptr
;
int
m_
=
0
;
// Number of values per channel that BN is normalizing.
int
c_
=
0
;
// Number of channels over which BN is normalizing.
float
svar_inv_count_
=
0.
f
;
// factor to scale sum of squared errors to get saved variance
float
rvar_inv_count_
=
0.
f
;
// factor to scale sum of squared errors to get running variance
double
exp_avg_factor_
=
0.
;
double
eps_
=
0.
;
std
::
string
name_
;
private:
void
setTensorDescriptor
(
dnnTensorDescriptor_t
descriptor
,
dnnTensorFormat_t
format
,
dnnDataType_t
data_type
,
int
n
,
int
c
,
int
h
,
int
w
)
{
dnnStatus_t
status
=
DNN_STATUS_SUCCESS
;
#ifdef __HIP_PLATFORM_HCC__
status
=
miopenSet4dTensorDescriptor
(
descriptor
,
data_type
,
n
,
c
,
h
,
w
);
#else
status
=
cudnnSetTensor4dDescriptor
(
descriptor
,
format
,
data_type
,
n
,
c
,
h
,
w
);
#endif
processCudnnStatus
(
status
,
"set tensor descriptor"
);
}
void
createTensorDescriptor
(
dnnTensorDescriptor_t
*
descriptor
)
{
dnnStatus_t
status
=
DNN_STATUS_SUCCESS
;
#ifdef __HIP_PLATFORM_HCC__
status
=
miopenCreateTensorDescriptor
(
descriptor
);
#else
status
=
cudnnCreateTensorDescriptor
(
descriptor
);
#endif
processCudnnStatus
(
status
,
"create tensor_descriptor"
);
}
void
destroyTensorDescriptor
(
dnnTensorDescriptor_t
descriptor
)
{
dnnStatus_t
status
=
DNN_STATUS_SUCCESS
;
#ifdef __HIP_PLATFORM_HCC__
status
=
miopenDestroyTensorDescriptor
(
descriptor
);
#else
status
=
cudnnDestroyTensorDescriptor
(
descriptor
);
#endif
processCudnnStatus
(
status
,
"destroy tensor_descriptor"
);
}
protected:
float
*
partial_sums_
=
nullptr
;
int
*
partial_counts_
=
nullptr
;
int
*
retired_ctas_
=
nullptr
;
bitmask_t
*
relu_bitmask_
=
nullptr
;
void
_setFwdParams
(
NhwcBatchNormFwdParams
*
params
)
const
;
void
_setFwdInferenceParams
(
NhwcBatchNormFwdInferenceParams
*
params
)
const
;
void
_setBwdParams
(
NhwcBatchNormBwdParams
*
params
)
const
;
// @todo: ability to configure these?
// Kernel params
static
const
int
USE_ONLINE_APPROACH
=
1
;
static
const
int
THREADS_PER_CTA
=
512
;
static
const
int
THREADS_PER_PIXEL
=
32
;
static
const
int
C_ELEMENTS_PER_CTA
=
128
;
static
const
int
ELEMENTS_PER_LDG
=
C_ELEMENTS_PER_CTA
/
THREADS_PER_PIXEL
;
static
const
int
MAX_SMEM_WITHOUT_OPT_IN
=
48
*
1024
;
typedef
uint16_t
StorageType
;
// increasing this to 6 causes spills in fwd kernel!
static
const
int
PIXELS_PER_THREAD_IN_REGISTERS_FWD
=
1
;
static
const
int
PIXELS_PER_THREAD_IN_REGISTERS_BWD
=
1
;
static
const
int
PIXELS_PER_THREAD_IN_SMEM_FWD
=
0
;
static
const
int
PIXELS_PER_THREAD_IN_SMEM_BWD
=
0
;
static
const
int
PIXELS_PER_THREAD_FWD
=
PIXELS_PER_THREAD_IN_REGISTERS_FWD
+
\
PIXELS_PER_THREAD_IN_SMEM_FWD
;
static
const
int
PIXELS_PER_THREAD_BWD
=
PIXELS_PER_THREAD_IN_REGISTERS_BWD
+
\
PIXELS_PER_THREAD_IN_SMEM_BWD
;
static
const
int
PIXELS_PER_THREAD_FWD_INFERENCE
=
4
;
// Derived params
static
const
size_t
SMEM_SIZE_FWD
=
PIXELS_PER_THREAD_IN_SMEM_FWD
*
THREADS_PER_CTA
*
\
ELEMENTS_PER_LDG
*
sizeof
(
StorageType
);
static
const
size_t
SMEM_SIZE_BWD
=
PIXELS_PER_THREAD_IN_SMEM_BWD
*
THREADS_PER_CTA
*
\
ELEMENTS_PER_LDG
*
2
*
sizeof
(
StorageType
);
static
const
int
PIXELS_PER_LDG
=
THREADS_PER_CTA
/
THREADS_PER_PIXEL
;
static
const
int
PIXELS_PER_CTA_FWD
=
THREADS_PER_CTA
/
THREADS_PER_PIXEL
*
\
PIXELS_PER_THREAD_FWD
;
static
const
int
PIXELS_PER_CTA_BWD
=
THREADS_PER_CTA
/
THREADS_PER_PIXEL
*
\
PIXELS_PER_THREAD_BWD
;
static
const
int
PIXELS_PER_CTA_FWD_INFERENCE
=
THREADS_PER_CTA
/
THREADS_PER_PIXEL
*
\
PIXELS_PER_THREAD_FWD_INFERENCE
;
// max grid.y in case of group bn is limited by exchange buffer size
static
const
int
MAX_GBN_BLOCK_Y
=
256
;
// Helper function to launch the forward kernel.
// We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel
// version that was compiled with that occupancy in its launch bounds. This way, we avoid
// needless register spills.
void
_fwdKernelLauncher
(
cudaStream_t
stream
,
NhwcBatchNormFwdParams
params
,
dim3
grid_dim
,
int
outer_loops
,
const
int
occupancy
,
const
bool
coop
)
{
#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) fwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
hipLaunchKernel((void *) fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
#else
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
cudaLaunchKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
#endif
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if
(
outer_loops
==
1
)
{
if
(
occupancy
>=
2
)
LAUNCH_FWD_KERNEL
(
1
,
false
,
true
,
2
,
coop
);
else
LAUNCH_FWD_KERNEL
(
1
,
false
,
true
,
1
,
coop
);
}
else
{
if
(
occupancy
>=
2
)
LAUNCH_FWD_KERNEL
(
0
,
false
,
true
,
2
,
coop
);
else
LAUNCH_FWD_KERNEL
(
0
,
false
,
true
,
1
,
coop
);
}
#undef LAUNCH_FWD_KERNEL
}
// Helper function to launch the backward kernel.
void
_bwdKernelLauncher
(
cudaStream_t
stream
,
NhwcBatchNormBwdParams
params
,
dim3
grid_dim
,
int
outer_loops
,
const
int
occupancy
,
const
bool
coop
)
{
#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto bwd_add_relu_func = nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) bwd_add_relu_func, \
hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + \
" bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(¶ms); \
using BWD_ADD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
hipLaunchKernel((void *) bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
¶ms_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \
} while (0)
#else
do
{
\
CHECK
(
SMEM_SIZE_BWD
<=
MAX_SMEM_WITHOUT_OPT_IN
)
<<
\
"Nhwc batchnormaddrelu kernel smem too big."
;
\
auto
bwd_add_relu_func
=
nhwc_batch_norm_bwd_add_relu
<
\
StorageType
,
\
THREADS_PER_CTA
,
\
THREADS_PER_PIXEL
,
\
PIXELS_PER_THREAD_IN_REGISTERS_BWD
,
\
PIXELS_PER_THREAD_IN_SMEM_BWD
,
\
ELEMENTS_PER_LDG
,
\
USE_ONLINE_APPROACH
,
\
OUTER_LOOPS
,
\
COMPILED_FOR_OCCUPANCY
>
;
\
if
(
COMPILED_FOR_OCCUPANCY
>
1
)
{
\
cudaFuncSetAttribute
(
bwd_add_relu_func
,
\
cudaFuncAttributePreferredSharedMemoryCarveout
,
100
);
\
checkCudaStatus
(
name_
+
\
" bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)"
);
\
}
\
void
*
params_ptr
=
static_cast
<
void
*>
(
&
params
);
\
using
BWD_ADD_RELU_FUNC
=
decltype
(
nhwc_batch_norm_bwd_add_relu
<
\
StorageType
,
\
THREADS_PER_CTA
,
\
THREADS_PER_PIXEL
,
\
PIXELS_PER_THREAD_IN_REGISTERS_BWD
,
\
PIXELS_PER_THREAD_IN_SMEM_BWD
,
\
ELEMENTS_PER_LDG
,
\
USE_ONLINE_APPROACH
,
\
OUTER_LOOPS
,
\
COMPILED_FOR_OCCUPANCY
>
);
\
if
(
COOP
)
{
\
cudaLaunchCooperativeKernel
<
BWD_ADD_RELU_FUNC
>
(
bwd_add_relu_func
,
\
grid_dim
,
\
THREADS_PER_CTA
,
\
&
params_ptr
,
\
SMEM_SIZE_BWD
,
\
stream
);
\
}
else
{
\
cudaLaunchKernel
<
BWD_ADD_RELU_FUNC
>
(
bwd_add_relu_func
,
\
grid_dim
,
\
THREADS_PER_CTA
,
\
&
params_ptr
,
\
SMEM_SIZE_BWD
,
\
stream
);
\
}
\
checkCudaStatus
(
name_
+
" bwd-add-relu coop serial kernel"
);
\
}
while
(
0
)
#endif
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if
(
outer_loops
==
1
)
{
if
(
occupancy
>=
2
)
LAUNCH_BWD_ADD_RELU_KERNEL
(
1
,
2
,
coop
);
else
LAUNCH_BWD_ADD_RELU_KERNEL
(
1
,
1
,
coop
);
}
else
{
if
(
occupancy
>=
2
)
LAUNCH_BWD_ADD_RELU_KERNEL
(
0
,
2
,
coop
);
else
LAUNCH_BWD_ADD_RELU_KERNEL
(
0
,
1
,
coop
);
}
#undef LAUNCH_BWD_KERNEL
}
public:
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
static
int
smem_driven_fwd_occupancy
(
int
device_id
,
const
int
max_cta_per_sm
)
{
using
namespace
at
::
cuda
::
utils
;
int
fwd_reduction_bytes
=
THREADS_PER_PIXEL
*
(
THREADS_PER_CTA
/
C10_WARP_SIZE
)
*
ELEMENTS_PER_LDG
*
sizeof
(
float
);
int
fwd_smem_bytes
=
SMEM_SIZE_FWD
+
fwd_reduction_bytes
;
int
occupancy
=
MaxSharedMemoryPerMultiprocessor
(
device_id
)
/
fwd_smem_bytes
;
return
std
::
min
(
max_cta_per_sm
,
occupancy
);
}
// Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
static
int
smem_driven_bwd_occupancy
(
int
device_id
,
const
int
max_cta_per_sm
)
{
using
namespace
at
::
cuda
::
utils
;
int
bwd_reduction_bytes
=
THREADS_PER_PIXEL
*
(
THREADS_PER_CTA
/
C10_WARP_SIZE
)
*
ELEMENTS_PER_LDG
*
sizeof
(
float
);
int
bwd_smem_bytes
=
SMEM_SIZE_BWD
+
bwd_reduction_bytes
;
int
occupancy
=
MaxSharedMemoryPerMultiprocessor
(
device_id
)
/
bwd_smem_bytes
;
return
std
::
min
(
max_cta_per_sm
,
occupancy
);
}
};
const
std
::
vector
<
size_t
>
NhwcBatchNormAddRelu
::
numWorkspaceBytes
()
const
{
assert
(
c_
>
0
);
// choose the max memory required between fwd/bwd passes
int
grid_x_fwd
=
div_up
(
m_
,
PIXELS_PER_CTA_FWD
);
int
grid_x_bwd
=
div_up
(
m_
,
PIXELS_PER_CTA_BWD
);
int
grid_x
=
max
(
grid_x_fwd
,
grid_x_bwd
);
int
grid_y
=
div_up
(
c_
,
C_ELEMENTS_PER_CTA
);
const
size_t
num_mean_bytes
=
c_
*
sizeof
(
float
);
const
size_t
num_variance_bytes
=
num_mean_bytes
;
#ifdef __HIP_PLATFORM_HCC__
int
elems_per_group
=
((
m_
+
3
)
&
~
3
)
*
2
;
#else
int
elems_per_group
=
((
m_
+
31
)
&
~
31
)
*
2
;
#endif
int
group_count
=
div_up
(
c_
,
C_ELEMENTS_PER_CTA
);
const
size_t
bitmask_bytes
=
elems_per_group
*
group_count
*
sizeof
(
bitmask_t
);
const
size_t
size_sums
=
grid_y
*
grid_x
*
THREADS_PER_PIXEL
*
\
ELEMENTS_PER_LDG
*
2
*
sizeof
(
float
);
const
size_t
size_counts
=
grid_y
*
grid_x
*
sizeof
(
int
);
return
{
num_mean_bytes
,
num_variance_bytes
,
bitmask_bytes
,
size_retired_ctas
(
grid_y
),
size_sums
,
size_counts
};
}
void
NhwcBatchNormAddRelu
::
setWorkspacePointers
(
const
std
::
vector
<
void
*>&
workspace
,
const
std
::
vector
<
size_t
>&
num_workspace_bytes
)
{
assert
(
workspace
.
size
()
==
6
);
assert
(
num_workspace_bytes
.
size
()
==
6
);
minibatch_mean_
=
static_cast
<
float
*>
(
workspace
[
0
]);
minibatch_variance_
=
static_cast
<
float
*>
(
workspace
[
1
]);
relu_bitmask_
=
static_cast
<
bitmask_t
*>
(
workspace
[
2
]);
retired_ctas_
=
static_cast
<
int
*>
(
workspace
[
3
]);
partial_sums_
=
static_cast
<
float
*>
(
workspace
[
4
]);
partial_counts_
=
static_cast
<
int
*>
(
workspace
[
5
]);
}
void
NhwcBatchNormAddRelu
::
_setFwdParams
(
NhwcBatchNormFwdParams
*
params
)
const
{
params
->
gmem_src
=
static_cast
<
uint16_t
*>
(
X_
);
params
->
gmem_dst
=
static_cast
<
uint16_t
*>
(
Y_
);
params
->
gmem_src1
=
static_cast
<
uint16_t
*>
(
addend_
);
params
->
gmem_bias
=
bias_
;
params
->
gmem_scale
=
scale_
;
params
->
gmem_running_mean
=
population_mean_
;
params
->
gmem_running_var
=
population_variance_
;
params
->
gmem_saved_mean
=
minibatch_mean_
;
params
->
gmem_saved_var
=
minibatch_variance_
;
params
->
gmem_relu_bitmask
=
relu_bitmask_
;
params
->
nhw
=
m_
;
params
->
c
=
c_
;
params
->
svar_inv_count
=
svar_inv_count_
;
params
->
rvar_inv_count
=
rvar_inv_count_
;
params
->
gmem_sums
=
partial_sums_
;
params
->
gmem_counts
=
partial_counts_
;
params
->
gmem_retired_ctas
=
retired_ctas_
;
params
->
var_eps
=
eps_
;
params
->
outer_loops
=
0
;
params
->
exp_avg_factor
=
static_cast
<
float
>
(
exp_avg_factor_
);
params
->
c_blks
=
div_up
(
c_
,
C_ELEMENTS_PER_CTA
);
}
void
NhwcBatchNormAddRelu
::
_setFwdInferenceParams
(
NhwcBatchNormFwdInferenceParams
*
params
)
const
{
params
->
gmem_src
=
static_cast
<
uint16_t
*>
(
X_
);
params
->
gmem_dst
=
static_cast
<
uint16_t
*>
(
Y_
);
params
->
gmem_src1
=
static_cast
<
uint16_t
*>
(
addend_
);
params
->
gmem_bias
=
bias_
;
params
->
gmem_scale
=
scale_
;
params
->
gmem_mean
=
population_mean_
;
params
->
gmem_var
=
population_variance_
;
params
->
nhw
=
m_
;
params
->
c
=
c_
;
params
->
var_eps
=
eps_
;
}
void
NhwcBatchNormAddRelu
::
_setBwdParams
(
NhwcBatchNormBwdParams
*
params
)
const
{
params
->
gmem_src
=
static_cast
<
uint16_t
*>
(
X_
);
params
->
gmem_dy
=
static_cast
<
uint16_t
*>
(
dY_
);
params
->
gmem_dst
=
static_cast
<
uint16_t
*>
(
dX_
);
params
->
gmem_dst1
=
static_cast
<
uint16_t
*>
(
dAddend_
);
params
->
gmem_relu_bitmask
=
relu_bitmask_
;
params
->
gmem_dscale
=
dscale_
;
params
->
gmem_dbias
=
dbias_
;
params
->
gmem_scale
=
scale_
;
params
->
gmem_bias
=
bias_
;
params
->
gmem_saved_mean
=
minibatch_mean_
;
params
->
gmem_saved_var
=
minibatch_variance_
;
params
->
nhw
=
m_
;
params
->
c
=
c_
;
params
->
svar_inv_count
=
svar_inv_count_
;
params
->
gmem_sums
=
partial_sums_
;
params
->
gmem_retired_ctas
=
retired_ctas_
;
params
->
outer_loops
=
0
;
params
->
c_blks
=
div_up
(
c_
,
C_ELEMENTS_PER_CTA
);
}
void
NhwcBatchNormAddRelu
::
fwdInference
(
cudaStream_t
stream
)
{
bool
ptrs_are_set
=
X_tensor_desc_
!=
nullptr
&&
Y_tensor_desc_
!=
nullptr
&&
scale_
!=
nullptr
&&
bias_
!=
nullptr
// && minibatch_mean_ != nullptr
// && minibatch_variance_ != nullptr
&&
population_mean_
!=
nullptr
&&
population_variance_
!=
nullptr
&&
X_
!=
nullptr
// && dX_ != nullptr
&&
Y_
!=
nullptr
&&
addend_
!=
nullptr
// && dY_ != nullptr
// && dscale_ != nullptr
// && dbias_ != nullptr
&&
partial_sums_
!=
nullptr
&&
partial_counts_
!=
nullptr
;
if
(
!
ptrs_are_set
)
die
();
dim3
grid_dim
;
grid_dim
.
x
=
div_up
(
m_
,
PIXELS_PER_CTA_FWD_INFERENCE
);
grid_dim
.
y
=
div_up
(
c_
,
C_ELEMENTS_PER_CTA
);
// @todo: maybe just move this inside initialize routine?
NhwcBatchNormFwdInferenceParams
params
;
_setFwdInferenceParams
(
&
params
);
nhwc_batch_norm_fwd_inference
<
StorageType
,
THREADS_PER_CTA
,
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
,
false
,
true
>
<<<
grid_dim
,
THREADS_PER_CTA
,
0
,
stream
>>>
(
params
);
checkCudaStatus
(
name_
+
" fwd_inference-relu kernel"
);
}
dim3
NhwcBatchNormAddRelu
::
calc_fwd_grid
(
int
*
loop
,
const
int
grid_dim_x
)
{
dim3
grid_dim
;
grid_dim
.
x
=
div_up
(
m_
,
PIXELS_PER_CTA_FWD
);
int
c_blks
=
div_up
(
c_
,
C_ELEMENTS_PER_CTA
);
unsigned
int
max_grid_x
=
grid_dim_x
;
if
(
grid_dim
.
x
<=
max_grid_x
)
{
*
loop
=
1
;
if
(
max_grid_x
/
grid_dim
.
x
>
1
)
{
grid_dim
.
y
=
std
::
min
(
c_blks
,
static_cast
<
int
>
(
max_grid_x
/
grid_dim
.
x
));
assert
(
grid_dim
.
y
<
MAX_GBN_BLOCK_Y
);
//FIXME: turn into a loop
}
else
{
grid_dim
.
y
=
1
;
}
}
else
{
grid_dim
.
x
=
max_grid_x
;
grid_dim
.
y
=
1
;
int
nhw_in_regs
=
m_
-
PIXELS_PER_THREAD_IN_SMEM_FWD
*
PIXELS_PER_LDG
*
grid_dim
.
x
;
int
pixels_per_iteration
=
PIXELS_PER_THREAD_IN_REGISTERS_FWD
*
PIXELS_PER_LDG
*
grid_dim
.
x
;
*
loop
=
div_up
(
nhw_in_regs
,
pixels_per_iteration
);
}
return
grid_dim
;
}
dim3
NhwcBatchNormAddRelu
::
calc_bwd_grid
(
int
*
loop
,
const
int
grid_dim_x
)
{
dim3
grid_dim
;
grid_dim
.
x
=
div_up
(
m_
,
PIXELS_PER_CTA_BWD
);
int
c_blks
=
div_up
(
c_
,
C_ELEMENTS_PER_CTA
);
unsigned
int
max_grid_x
=
grid_dim_x
;
if
(
grid_dim
.
x
<=
max_grid_x
)
{
*
loop
=
1
;
if
(
max_grid_x
/
grid_dim
.
x
>
1
)
{
grid_dim
.
y
=
std
::
min
(
c_blks
,
static_cast
<
int
>
(
max_grid_x
/
grid_dim
.
x
));
assert
(
grid_dim
.
y
<
MAX_GBN_BLOCK_Y
);
//FIXME: turn into a loop
}
else
{
grid_dim
.
y
=
1
;
}
}
else
{
grid_dim
.
x
=
max_grid_x
;
grid_dim
.
y
=
1
;
int
nhw_in_regs
=
m_
-
PIXELS_PER_THREAD_IN_SMEM_BWD
*
PIXELS_PER_LDG
*
grid_dim
.
x
;
int
pixels_per_iteration
=
PIXELS_PER_THREAD_IN_REGISTERS_BWD
*
PIXELS_PER_LDG
*
grid_dim
.
x
;
*
loop
=
div_up
(
nhw_in_regs
,
pixels_per_iteration
);
}
return
grid_dim
;
}
void
NhwcBatchNormAddRelu
::
fwd
(
cudaStream_t
stream
,
void
*
my_data
,
void
*
pair_data
,
void
*
pair_data2
,
void
*
pair_data3
,
const
int
bn_group
,
const
int
magic
,
const
int
occupancy
,
const
int
grid_dim_x
,
const
bool
coop
)
{
bool
ptrs_are_set
=
X_tensor_desc_
!=
nullptr
&&
Y_tensor_desc_
!=
nullptr
&&
scale_
!=
nullptr
&&
bias_
!=
nullptr
&&
minibatch_mean_
!=
nullptr
&&
minibatch_variance_
!=
nullptr
&&
relu_bitmask_
!=
nullptr
&&
population_mean_
!=
nullptr
&&
population_variance_
!=
nullptr
&&
X_
!=
nullptr
// && dX_ != nullptr
&&
Y_
!=
nullptr
&&
addend_
!=
nullptr
// && dY_ != nullptr
// && dscale_ != nullptr
// && dbias_ != nullptr
&&
partial_sums_
!=
nullptr
&&
partial_counts_
!=
nullptr
&&
retired_ctas_
!=
nullptr
;
if
(
!
ptrs_are_set
)
die
();
// reset of retired_cta_count no longer needed
NhwcBatchNormFwdParams
params
;
_setFwdParams
(
&
params
);
params
.
my_data
=
my_data
;
params
.
pair_datas
[
0
]
=
pair_data
;
params
.
pair_datas
[
1
]
=
pair_data2
;
params
.
pair_datas
[
2
]
=
pair_data3
;
params
.
magic
=
magic
;
params
.
sync_iters
=
(
bn_group
==
8
)
?
3
:
(
bn_group
>>
1
);
dim3
grid_dim
=
calc_fwd_grid
(
&
params
.
outer_loops
,
grid_dim_x
);
_fwdKernelLauncher
(
stream
,
params
,
grid_dim
,
params
.
outer_loops
,
occupancy
,
coop
);
}
void
NhwcBatchNormAddRelu
::
dgrad
(
cudaStream_t
stream
,
void
*
my_data
,
void
*
pair_data
,
void
*
pair_data2
,
void
*
pair_data3
,
const
int
bn_group
,
const
int
magic
,
const
int
occupancy
,
const
int
grid_dim_x
,
const
bool
coop
)
{
bool
ptrs_are_set
=
X_tensor_desc_
!=
nullptr
&&
Y_tensor_desc_
!=
nullptr
&&
scale_
!=
nullptr
&&
bias_
!=
nullptr
&&
minibatch_mean_
!=
nullptr
&&
minibatch_variance_
!=
nullptr
&&
relu_bitmask_
!=
nullptr
// && population_mean_ != nullptr
// && population_variance_ != nullptr
&&
X_
!=
nullptr
&&
dX_
!=
nullptr
// && Y_ != nullptr
&&
dY_
!=
nullptr
&&
dAddend_
!=
nullptr
&&
dscale_
!=
nullptr
&&
dbias_
!=
nullptr
&&
retired_ctas_
!=
nullptr
;
if
(
!
ptrs_are_set
)
die
();
// reset of retired_cta_count no longer needed
NhwcBatchNormBwdParams
params
;
_setBwdParams
(
&
params
);
params
.
my_data
=
my_data
;
params
.
pair_datas
[
0
]
=
pair_data
;
params
.
pair_datas
[
1
]
=
pair_data2
;
params
.
pair_datas
[
2
]
=
pair_data3
;
params
.
magic
=
magic
;
params
.
sync_iters
=
(
bn_group
==
8
)
?
3
:
(
bn_group
>>
1
);
params
.
wgrad_coeff
=
1.0
/
bn_group
;
dim3
grid_dim
=
calc_bwd_grid
(
&
params
.
outer_loops
,
grid_dim_x
);
_bwdKernelLauncher
(
stream
,
params
,
grid_dim
,
params
.
outer_loops
,
occupancy
,
coop
);
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
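// Sketch (assumption, not from the original sources): the six workspace slots that
// NhwcBatchNormAddRelu::setWorkspacePointers() expects, in the same order in which
// numWorkspaceBytes() reports their sizes. The helper name bind_workspace is hypothetical.
#include <vector>

static void bind_workspace(NhwcBatchNormAddRelu* bn,
                           void* minibatch_mean, void* minibatch_variance,
                           void* relu_bitmask, void* retired_ctas,
                           void* partial_sums, void* partial_counts) {
  std::vector<void*> workspace = {
    minibatch_mean,      // [0] saved mean, c floats
    minibatch_variance,  // [1] saved variance, c floats
    relu_bitmask,        // [2] packed ReLU sign bits reused by dgrad
    retired_ctas,        // [3] inter-CTA completion counters
    partial_sums,        // [4] per-CTA partial reductions
    partial_counts,      // [5] per-CTA pixel counts
  };
  bn->setWorkspacePointers(workspace, bn->numWorkspaceBytes());
}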
apex/contrib/csrc/groupbn/cuda_utils.h deleted 100644 → 0
#ifdef __HIP_PLATFORM_HCC__
#include <ATen/hip/HIPContext.h>
#else
#include <ATen/cuda/CUDAContext.h>
#endif

#ifndef CUDA_UTILS_H
#define CUDA_UTILS_H

namespace at {
namespace cuda {

namespace utils {

static inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
#ifdef __HIP_PLATFORM_HCC__
  return getDeviceProperties(device_id)->maxSharedMemoryPerMultiProcessor;
#else
  return getDeviceProperties(device_id)->sharedMemPerMultiprocessor;
#endif
}

}
}
}

#endif
apex/contrib/csrc/groupbn/dnn.h deleted 100644 → 0
#ifndef DNN_H
#define DNN_H

#ifdef __HIP_PLATFORM_HCC__

#include <miopen/miopen.h>

#define DNN_STATUS_SUCCESS miopenStatusSuccess
#define DNN_DATA_HALF miopenHalf
#define DNN_TENSOR_FORMAT 0

using dnnTensorFormat_t = int;
using dnnDataType_t = miopenDataType_t;
using dnnStatus_t = miopenStatus_t;
using dnnTensorDescriptor_t = miopenTensorDescriptor_t;

#else

#include <cudnn.h>

#define DNN_STATUS_SUCCESS CUDNN_STATUS_SUCCESS
#define DNN_DATA_HALF CUDNN_DATA_HALF
#define DNN_TENSOR_FORMAT CUDNN_TENSOR_NHWC

using dnnTensorFormat_t = cudnnTensorFormat_t;
using dnnDataType_t = cudnnDataType_t;
using dnnStatus_t = cudnnStatus_t;
using dnnTensorDescriptor_t = cudnnTensorDescriptor_t;

#endif

#endif // DNN_H
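// Sketch (assumption, not from the original sources): with the aliases above, descriptor setup is
// written once and compiles against either cuDNN or MIOpen. The helper name is hypothetical; the
// backend calls mirror the ones used by the NhwcBatchNorm* wrappers.
static dnnStatus_t make_nhwc_half_desc(dnnTensorDescriptor_t* desc, int n, int c, int h, int w) {
#ifdef __HIP_PLATFORM_HCC__
  miopenCreateTensorDescriptor(desc);
  return miopenSet4dTensorDescriptor(*desc, DNN_DATA_HALF, n, c, h, w);
#else
  cudnnCreateTensorDescriptor(desc);
  return cudnnSetTensor4dDescriptor(*desc, DNN_TENSOR_FORMAT, DNN_DATA_HALF, n, c, h, w);
#endif
}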
apex/contrib/csrc/groupbn/interface.cpp deleted 100644 → 0
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/ArrayRef.h>
#include <ATen/ScalarType.h>
#include "ATen/Scalar.h"
#ifndef VERSION_GE_1_1
#include "ATen/Type.h"
#endif
#include "ATen/Tensor.h"
#include "ATen/Storage.h"
#include "ATen/Generator.h"
namespace py = pybind11;

int64_t get_buffer_size(const int bn_sync_steps);
void* get_data_ptr(const at::Tensor& data);
void* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset);
void close_remote_data(const at::Tensor& handle);
at
::
Tensor
nhwc_bn_fwd_train
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
scale
,
const
at
::
Tensor
&
bias
,
const
at
::
Tensor
&
running_mean
,
const
at
::
Tensor
&
running_inv_var
,
const
at
::
Tensor
&
minibatch_mean
,
const
at
::
Tensor
&
minibatch_inv_var
,
const
at
::
Tensor
&
ret_cta
,
const
float
momentum
,
const
float
epsilon
,
const
bool
fuse_relu
,
void
*
my_data
,
void
*
pair_data
,
void
*
pair_data2
,
void
*
pair_data3
,
const
int
bn_group
,
const
at
::
Tensor
&
magic_tensor
,
const
int
occupancy
,
const
int
grid_dim_x
,
const
bool
coop
);
at
::
Tensor
nhwc_bn_fwd_eval
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
scale
,
const
at
::
Tensor
&
bias
,
const
at
::
Tensor
&
running_mean
,
const
at
::
Tensor
&
running_inv_var
,
const
at
::
Tensor
&
ret_cta
,
const
int
bn_group
,
const
float
momentum
,
const
float
epsilon
,
const
bool
fuse_relu
);
std
::
vector
<
at
::
Tensor
>
nhwc_bn_bwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
dy
,
const
at
::
Tensor
&
scale
,
const
at
::
Tensor
&
bias
,
const
at
::
Tensor
&
running_mean
,
const
at
::
Tensor
&
running_inv_var
,
const
at
::
Tensor
&
minibatch_mean
,
const
at
::
Tensor
&
minibatch_inv_var
,
const
at
::
Tensor
&
ret_cta
,
const
float
momentum
,
const
float
epsilon
,
const
bool
fuse_relu
,
void
*
my_data
,
void
*
pair_data
,
void
*
pair_data2
,
void
*
pair_data3
,
const
int
bn_group
,
const
at
::
Tensor
&
magic_tensor
,
const
int
occupancy
,
const
int
grid_dim_x
,
const
bool
coop
);
at
::
Tensor
nhwc_bn_addrelu_fwd_train
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
z
,
const
at
::
Tensor
&
scale
,
const
at
::
Tensor
&
bias
,
const
at
::
Tensor
&
running_mean
,
const
at
::
Tensor
&
running_inv_var
,
const
at
::
Tensor
&
minibatch_mean
,
const
at
::
Tensor
&
minibatch_inv_var
,
const
at
::
Tensor
&
bitmask
,
const
at
::
Tensor
&
ret_cta
,
const
float
momentum
,
const
float
epsilon
,
void
*
my_data
,
void
*
pair_data
,
void
*
pair_data2
,
void
*
pair_data3
,
const
int
bn_group
,
const
at
::
Tensor
&
magic_tensor
,
const
int
occupancy
,
const
int
grid_dim_x
,
const
bool
coop
);
at
::
Tensor
nhwc_bn_addrelu_fwd_eval
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
z
,
const
at
::
Tensor
&
scale
,
const
at
::
Tensor
&
bias
,
const
at
::
Tensor
&
running_mean
,
const
at
::
Tensor
&
running_inv_var
,
const
at
::
Tensor
&
ret_cta
,
const
int
bn_group
,
const
float
momentum
,
const
float
epsilon
);
std
::
vector
<
at
::
Tensor
>
nhwc_bn_addrelu_bwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
dy
,
const
at
::
Tensor
&
scale
,
const
at
::
Tensor
&
bias
,
const
at
::
Tensor
&
running_mean
,
const
at
::
Tensor
&
running_inv_var
,
const
at
::
Tensor
&
minibatch_mean
,
const
at
::
Tensor
&
minibatch_inv_var
,
const
at
::
Tensor
&
bitmask
,
const
at
::
Tensor
&
ret_cta
,
const
float
momentum
,
const
float
epsilon
,
void
*
my_data
,
void
*
pair_data
,
void
*
pair_data2
,
void
*
pair_data3
,
const
int
bn_group
,
const
at
::
Tensor
&
magic_tensor
,
const
int
occupancy
,
const
int
grid_dim_x
,
const
bool
coop
);
int nhwc_bn_fwd_occupancy();
int nhwc_bn_bwd_occupancy();

int nhwc_bn_addrelu_fwd_occupancy();
int nhwc_bn_addrelu_bwd_occupancy();

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("get_buffer_size", &get_buffer_size, "get_buffer_size");
  m.def("get_data_ptr", &get_data_ptr, "get_data_ptr");
  m.def("get_remote_data_ptr", &get_remote_data_ptr, "get_remote_data_ptr");
  m.def("close_remote_data", &close_remote_data, "close_remote_data");
  m.def("bn_fwd_nhwc", &nhwc_bn_fwd_train, "bn_fwd_nhwc");
  m.def("bn_fwd_eval_nhwc", &nhwc_bn_fwd_eval, "bn_fwd_eval_nhwc");
  m.def("bn_bwd_nhwc", &nhwc_bn_bwd, "bn_bwd_nhwc");
  m.def("bn_fwd_nhwc_occupancy", &nhwc_bn_fwd_occupancy, "bn_fwd_nhwc_occupancy");
  m.def("bn_bwd_nhwc_occupancy", &nhwc_bn_bwd_occupancy, "bn_bwd_nhwc_occupancy");
  m.def("bn_addrelu_fwd_nhwc", &nhwc_bn_addrelu_fwd_train, "bn_addrelu_fwd_nhwc");
  m.def("bn_addrelu_fwd_eval_nhwc", &nhwc_bn_addrelu_fwd_eval, "bn_addrelu_fwd_eval_nhwc");
  m.def("bn_addrelu_bwd_nhwc", &nhwc_bn_addrelu_bwd, "bn_addrelu_bwd_nhwc");
  m.def("bn_addrelu_fwd_nhwc_occupancy", &nhwc_bn_addrelu_fwd_occupancy, "bn_addrelu_fwd_nhwc_occupancy");
  m.def("bn_addrelu_bwd_nhwc_occupancy", &nhwc_bn_addrelu_bwd_occupancy, "bn_addrelu_bwd_nhwc_occupancy");
}
apex/contrib/csrc/groupbn/ipc.cu deleted 100644 → 0
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include "compat.h"
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
template
<
>
struct
std
::
hash
<
cudaIpcMemHandle_t
>
{
size_t
operator
()
(
const
cudaIpcMemHandle_t
&
handle
)
const
{
size_t
hash
=
0
;
uint8_t
*
ptr
=
(
uint8_t
*
)
&
handle
;
assert
(
sizeof
(
uint8_t
)
==
1
);
for
(
int
i
=
0
;
i
<
sizeof
(
cudaIpcMemHandle_t
);
i
++
)
{
hash
+=
*
ptr
;
ptr
++
;
}
return
hash
;
}
};
template
<
>
struct
std
::
equal_to
<
cudaIpcMemHandle_t
>
{
bool
operator
()
(
const
cudaIpcMemHandle_t
&
lhs
,
const
cudaIpcMemHandle_t
&
rhs
)
const
{
return
(
std
::
memcmp
((
void
*
)
&
lhs
,
(
void
*
)
&
rhs
,
sizeof
(
cudaIpcMemHandle_t
))
==
0
);
}
};
namespace {

namespace gpuipc {
//from: src/operator/nn/cudnn/nhwc_batch_norm_kernel.h
// The number of threads per pixel.
const int THREADS_PER_PIXEL = 16;
// The number of elements per ldg.
const int ELEMENTS_PER_LDG = 4;
// The number of reducing ops, each uses its own space : mean, var, dscale, dbias
const int REDUCE_OPS = 4;
// Maximum block.y supported - limited due to buffer allocation
const int MAX_BLOCK_Y = 256;
const int MAX_OFFSET = REDUCE_OPS * MAX_BLOCK_Y;
const int BYTES_PER_ELEM = 4;
// Buffer size per sync step
const int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET * THREADS_PER_PIXEL * 2 * ELEMENTS_PER_LDG * BYTES_PER_ELEM;
};
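// Worked check (not part of the original sources) of the constants above:
//   MAX_OFFSET               = 4 * 256                 = 1024
//   SINGLE_SYNC_BUFFER_BYTES = 1024 * 16 * 2 * 4 * 4   = 524288 bytes (512 KiB per sync step)
// so get_buffer_size(3), the three sync iterations used for bn_group == 8, asks for 1.5 MiB.
static_assert(4 * 256 * 16 * 2 * 4 * 4 == 524288, "one sync step occupies 512 KiB");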
class
IpcMemHandleRegistry
{
public:
void
*
getPtr
(
const
cudaIpcMemHandle_t
&
handle
,
int64_t
offset
)
{
if
(
registry_
.
count
(
handle
)
==
0
)
{
registry_
.
insert
(
std
::
make_pair
(
handle
,
RegistryEntry
()));
registry_
[
handle
].
dev_ptr
=
ipcOpenMem
(
handle
);
}
registry_
[
handle
].
ref_count
++
;
return
(((
uint8_t
*
)
registry_
[
handle
].
dev_ptr
)
+
offset
);
}
void
releasePtr
(
const
cudaIpcMemHandle_t
&
handle
)
{
if
(
registry_
.
count
(
handle
)
==
0
)
{
}
if
(
--
registry_
[
handle
].
ref_count
==
0
)
{
ipcCloseMem
(
registry_
[
handle
].
dev_ptr
);
registry_
.
erase
(
handle
);
}
}
struct
RegistryEntry
{
void
*
dev_ptr
;
int
ref_count
;
RegistryEntry
()
:
dev_ptr
(
NULL
)
,
ref_count
(
0
)
{}
};
protected:
std
::
unordered_map
<
cudaIpcMemHandle_t
,
RegistryEntry
>
registry_
;
void
*
ipcOpenMem
(
const
cudaIpcMemHandle_t
&
handle
)
{
void
*
data
;
cudaIpcOpenMemHandle
(
&
data
,
handle
,
cudaIpcMemLazyEnablePeerAccess
);
cudaCheckErrors
(
"ipc init"
);
return
data
;
}
void
ipcCloseMem
(
void
*
dev_ptr
)
{
cudaIpcCloseMemHandle
(
dev_ptr
);
cudaCheckErrors
(
"ipc close"
);
}
};
}
static IpcMemHandleRegistry ipc_mem_registry;

int64_t get_buffer_size(const int bn_sync_steps) {
  return bn_sync_steps * gpuipc::SINGLE_SYNC_BUFFER_BYTES;
}

void* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset) {
  cudaIpcMemHandle_t my_handle;
  memcpy((unsigned char *)(&my_handle), handle.DATA_PTR<uint8_t>(), sizeof(my_handle));
  return ipc_mem_registry.getPtr(my_handle, offset);
}

void close_remote_data(const at::Tensor& handle) {
  cudaIpcMemHandle_t my_handle;
  memcpy((unsigned char *)(&my_handle), handle.DATA_PTR<uint8_t>(), sizeof(my_handle));
  ipc_mem_registry.releasePtr(my_handle);
}

void* get_data_ptr(const at::Tensor& data) {
  return data.DATA_PTR<uint8_t>();
}
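// Sketch (assumption, not from the original sources): the producer side of the exchange that
// get_remote_data_ptr()/close_remote_data() complete. The function name and plumbing are
// hypothetical; only cudaMalloc/cudaIpcGetMemHandle are real CUDA runtime calls.
#include <cstring>

static void export_sync_buffer(void** dev_buf,
                               unsigned char handle_bytes[sizeof(cudaIpcMemHandle_t)],
                               int bn_sync_steps) {
  cudaMalloc(dev_buf, get_buffer_size(bn_sync_steps));     // one 512 KiB region per sync step
  cudaIpcMemHandle_t handle;
  cudaIpcGetMemHandle(&handle, *dev_buf);                  // export the allocation
  std::memcpy(handle_bytes, &handle, sizeof(handle));      // these bytes travel to the peer process
}
// The peer copies the bytes into a uint8 tensor, calls get_remote_data_ptr(tensor, offset) to map
// the buffer through cudaIpcOpenMemHandle, and later close_remote_data(tensor) to unmap it.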
apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h deleted 100644 → 0
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm_kernel.h
* \brief CUDA NHWC Batch Normalization code
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#include <hip/hip_fp16.h>
#endif
#include <stdint.h>
#include <algorithm>
#ifdef __HIP_PLATFORM_HCC__
using bitmask_t = uint64_t;
#define BITMASK_OFFSET 2
#define ONE_BITMASK 1UL
#else
using bitmask_t = unsigned int;
#define BITMASK_OFFSET 2
#define ONE_BITMASK 1U
#endif

#define DEVICE_FUNCTION static inline __device__

// CTA margin used by cooperative launch. Can be overridden by env var NHWC_BATCHNORM_LAUNCH_MARGIN.
#define NHWC_BATCHNORM_LAUNCH_MARGIN_MIN 3
#define NHWC_BATCHNORM_LAUNCH_MARGIN_DEFAULT NHWC_BATCHNORM_LAUNCH_MARGIN_MIN

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void syncwarp() {
#ifdef __HIP_PLATFORM_HCC__
  __builtin_amdgcn_wave_barrier();
#else
  __syncwarp();
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
DEVICE_FUNCTION T shfl_sync(T var, int src_lane) {
#ifdef __HIP_PLATFORM_HCC__
  return __shfl(var, src_lane);
#else
  return __shfl_sync(0xFFFFFFFFU, var, src_lane);
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION bitmask_t ballot(int predicate) {
#ifdef __HIP_PLATFORM_HCC__
  return __ballot(predicate);
#else
  return __ballot_sync(0xFFFFFFFFU, predicate);
#endif
}
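// Sketch (assumption, not from the original sources): the kind of warp-level reduction these two
// wrappers make portable. This butterfly variant is illustrative only; the reductions later in
// this file add across half-warps instead. The name warp_sum is hypothetical.
DEVICE_FUNCTION float warp_sum(float v) {
#ifdef __HIP_PLATFORM_HCC__
  const int THREADS_PER_WARP = 64;
#else
  const int THREADS_PER_WARP = 32;
#endif
  const int lane = threadIdx.x % THREADS_PER_WARP;
  for (int offset = THREADS_PER_WARP / 2; offset > 0; offset /= 2)
    v += shfl_sync(v, lane ^ offset);   // every lane ends up with the full warp sum
  return v;
}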
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, int ELEMENTS_PER_LDG>
struct PackedStorage {
  enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG };
  typedef T Type;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int ELEMENTS_PER_LDG>
struct PackedStorage<uint16_t, ELEMENTS_PER_LDG> {
  enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG / 2 };
  typedef int Type;
};
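// Worked check (not part of the original sources): with ELEMENTS_PER_LDG == 4, fp16 data is
// carried as two 32-bit ints (two halves packed per int, Type == int), while fp32 data keeps
// four floats per load.
static_assert(PackedStorage<uint16_t, 4>::PACKED_ELEMENTS_PER_LDG == 2, "4 halves pack into 2 ints");
static_assert(PackedStorage<float, 4>::PACKED_ELEMENTS_PER_LDG == 4, "floats are not packed");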
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
from_float
(
int
(
&
dst
)[
N
],
const
float
(
&
src
)[
2
*
N
])
{
// Convert from two f32s to two f16s (mantissa LSB rounds to nearest even)
// (From 64-bit to 32-bit)
half
*
dst_
=
(
half
*
)
dst
;
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
#ifdef __HIP_PLATFORM_HCC__
dst_
[
2
*
i
]
=
__float2half
(
src
[
2
*
i
]);
dst_
[
2
*
i
+
1
]
=
__float2half
(
src
[
2
*
i
+
1
]);
#else
uint16_t
lo
,
hi
;
asm
volatile
(
"cvt.rn.f16.f32 %0, %1;"
:
"=h"
(
lo
)
:
"f"
(
src
[
2
*
i
+
0
]));
asm
volatile
(
"cvt.rn.f16.f32 %0, %1;"
:
"=h"
(
hi
)
:
"f"
(
src
[
2
*
i
+
1
]));
asm
volatile
(
"mov.b32 %0, {%1, %2};"
:
"=r"
(
dst
[
i
])
:
"h"
(
lo
),
"h"
(
hi
));
#endif
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
from_float
(
float
(
&
dst
)[
N
],
const
float
(
&
src
)[
N
])
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
dst
[
i
]
=
src
[
i
];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
to_float
(
float
(
&
dst
)[
2
*
N
],
int
(
&
src
)[
N
])
{
// Convert from two f16s to two f32s (From 32-bit to 64-bit)
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
#ifdef __HIP_PLATFORM_HCC__
half
*
src_
=
(
half
*
)
src
;
dst
[
2
*
i
]
=
__half2float
(
src_
[
2
*
i
]);
dst
[
2
*
i
+
1
]
=
__half2float
(
src_
[
2
*
i
+
1
]);
#else
uint16_t
lo
,
hi
;
asm
volatile
(
"mov.b32 {%0, %1}, %2;"
:
"=h"
(
lo
),
"=h"
(
hi
)
:
"r"
(
src
[
i
]));
asm
volatile
(
"cvt.f32.f16 %0, %1;"
:
"=f"
(
dst
[
2
*
i
+
0
])
:
"h"
(
lo
));
asm
volatile
(
"cvt.f32.f16 %0, %1;"
:
"=f"
(
dst
[
2
*
i
+
1
])
:
"h"
(
hi
));
#endif
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
to_float
(
float
(
&
dst
)[
N
],
float
(
&
src
)[
N
])
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
dst
[
i
]
=
src
[
i
];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
ldg
(
int
(
&
dst
)[
1
],
const
uint16_t
*
gmem
)
{
dst
[
0
]
=
__ldg
((
const
int
*
)
gmem
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
ldg_stream
(
int
(
&
dst
)[
1
],
const
uint16_t
*
gmem
)
{
#ifdef __HIP_PLATFORM_HCC__
dst
[
0
]
=
__ldg
((
const
int
*
)
gmem
);
#else
unsigned
int
tmp
;
asm
volatile
(
"ld.global.cs.nc.s32 %0, [%1];"
:
"=r"
(
tmp
)
:
"l"
((
const
uint
*
)
gmem
));
dst
[
0
]
=
tmp
;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
ldg
(
int
(
&
dst
)[
2
],
const
uint16_t
*
gmem
)
{
int2
tmp
=
__ldg
((
const
int2
*
)
gmem
);
dst
[
0
]
=
tmp
.
x
;
dst
[
1
]
=
tmp
.
y
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
ldg_stream
(
int
(
&
dst
)[
2
],
const
uint16_t
*
gmem
)
{
#ifdef __HIP_PLATFORM_HCC__
int2
tmp
=
__ldg
((
const
int2
*
)
gmem
);
dst
[
0
]
=
tmp
.
x
;
dst
[
1
]
=
tmp
.
y
;
#else
int2
tmp
;
asm
volatile
(
"ld.global.cs.nc.v2.s32 {%0,%1}, [%2];"
:
"=r"
(
tmp
.
x
),
"=r"
(
tmp
.
y
)
:
"l"
((
const
int2
*
)
gmem
));
dst
[
0
]
=
tmp
.
x
;
dst
[
1
]
=
tmp
.
y
;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
ldg
(
float
(
&
dst
)[
N
],
const
uint16_t
*
gmem
)
{
int
tmp
[
N
/
2
];
ldg
(
tmp
,
gmem
);
to_float
(
dst
,
tmp
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
ldg_stream
(
float
(
&
dst
)[
N
],
const
uint16_t
*
gmem
)
{
int
tmp
[
N
/
2
];
ldg_stream
(
tmp
,
gmem
);
to_float
(
dst
,
tmp
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
stg
(
uint16_t
*
gmem
,
int
(
&
src
)[
1
])
{
reinterpret_cast
<
int
*>
(
gmem
)[
0
]
=
src
[
0
];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
stg_stream
(
uint16_t
*
gmem
,
int
(
&
src
)[
1
])
{
#ifdef __HIP_PLATFORM_HCC__
reinterpret_cast
<
int
*>
(
gmem
)[
0
]
=
src
[
0
];
#else
unsigned
int
tmp
=
src
[
0
];
asm
volatile
(
"st.global.cs.s32 [%0], %1;"
::
"l"
((
uint
*
)
gmem
)
,
"r"
(
tmp
));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
stg
(
uint16_t
*
gmem
,
int
(
&
src
)[
2
])
{
#ifdef __HIP_PLATFORM_HCC__
half
*
gmem_
=
(
half
*
)
gmem
;
half
*
src_
=
(
half
*
)
src
;
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
gmem_
[
i
]
=
src_
[
i
];
}
#else
reinterpret_cast
<
int2
*>
(
gmem
)[
0
]
=
make_int2
(
src
[
0
],
src
[
1
]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
stg_stream
(
uint16_t
*
gmem
,
int
(
&
src
)[
2
])
{
#ifdef __HIP_PLATFORM_HCC__
half
*
gmem_
=
(
half
*
)
gmem
;
half
*
src_
=
(
half
*
)
src
;
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
gmem_
[
i
]
=
src_
[
i
];
}
#else
asm
volatile
(
"st.global.cs.v2.s32 [%0], {%1,%2};"
::
"l"
((
uint
*
)
gmem
)
,
"r"
(
src
[
0
]),
"r"
(
src
[
1
]));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
stg
(
uint16_t
*
gmem
,
float
(
&
src
)[
N
])
{
int
tmp
[
N
/
2
];
from_float
(
tmp
,
src
);
stg
(
gmem
,
tmp
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
stg_stream
(
uint16_t
*
gmem
,
float
(
&
src
)[
N
])
{
int
tmp
[
N
/
2
];
from_float
(
tmp
,
src
);
stg_stream
(
gmem
,
tmp
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef __HIP_PLATFORM_HCC__
DEVICE_FUNCTION
void
stg
(
uint16_t
*
gmem
,
float
(
&
src
)[
4
])
{
half
*
gmem_
=
(
half
*
)
gmem
;
gmem_
[
0
]
=
__float2half
(
src
[
0
]);
gmem_
[
1
]
=
__float2half
(
src
[
1
]);
gmem_
[
2
]
=
__float2half
(
src
[
2
]);
gmem_
[
3
]
=
__float2half
(
src
[
3
]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
stg_stream
(
uint16_t
*
gmem
,
float
(
&
src
)[
4
])
{
half
*
gmem_
=
(
half
*
)
gmem
;
gmem_
[
0
]
=
__float2half
(
src
[
0
]);
gmem_
[
1
]
=
__float2half
(
src
[
1
]);
gmem_
[
2
]
=
__float2half
(
src
[
2
]);
gmem_
[
3
]
=
__float2half
(
src
[
3
]);
}
#endif
DEVICE_FUNCTION
void
read_from_gmem
(
float
(
&
dst
)[
2
],
const
float
*
gmem
,
int
idx
)
{
#ifdef __HIP_PLATFORM_HCC__
dst
[
0
]
=
gmem
[
2
*
idx
];
dst
[
1
]
=
gmem
[
2
*
idx
+
1
];
#else
float2
tmp
=
__ldg
(
reinterpret_cast
<
const
float2
*>
(
&
gmem
[
2
*
idx
]));
dst
[
0
]
=
tmp
.
x
;
dst
[
1
]
=
tmp
.
y
;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
read_from_gmem
(
float
(
&
dst
)[
4
],
const
float
*
gmem
,
int
idx
)
{
#ifdef __HIP_PLATFORM_HCC__
dst
[
0
]
=
gmem
[
4
*
idx
];
dst
[
1
]
=
gmem
[
4
*
idx
+
1
];
dst
[
2
]
=
gmem
[
4
*
idx
+
2
];
dst
[
3
]
=
gmem
[
4
*
idx
+
3
];
#else
float4
tmp
=
__ldg
(
reinterpret_cast
<
const
float4
*>
(
&
gmem
[
4
*
idx
]));
dst
[
0
]
=
tmp
.
x
;
dst
[
1
]
=
tmp
.
y
;
dst
[
2
]
=
tmp
.
z
;
dst
[
3
]
=
tmp
.
w
;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
read_from_smem
(
float
(
&
x
)[
2
],
const
float
*
smem
,
int
idx
)
{
#ifdef __HIP_PLATFORM_HCC__
x
[
0
]
=
smem
[
2
*
idx
];
x
[
1
]
=
smem
[
2
*
idx
+
1
];
#else
float2
tmp
=
*
(
const
float2
*
)
&
smem
[
2
*
idx
];
x
[
0
]
=
tmp
.
x
;
x
[
1
]
=
tmp
.
y
;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
read_from_smem
(
int
(
&
x
)[
1
],
const
int
*
smem
,
int
idx
)
{
x
[
0
]
=
smem
[
idx
];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
read_from_smem
(
float
(
&
x
)[
4
],
const
float
*
smem
,
int
idx
)
{
#ifdef __HIP_PLATFORM_HCC__
x
[
0
]
=
smem
[
4
*
idx
];
x
[
1
]
=
smem
[
4
*
idx
+
1
];
x
[
2
]
=
smem
[
4
*
idx
+
2
];
x
[
3
]
=
smem
[
4
*
idx
+
3
];
#else
float4
tmp
=
*
(
const
float4
*
)
&
smem
[
4
*
idx
];
x
[
0
]
=
tmp
.
x
;
x
[
1
]
=
tmp
.
y
;
x
[
2
]
=
tmp
.
z
;
x
[
3
]
=
tmp
.
w
;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
read_from_smem
(
int
(
&
x
)[
2
],
const
int
*
smem
,
int
idx
)
{
#ifdef __HIP_PLATFORM_HCC__
x
[
0
]
=
smem
[
2
*
idx
];
x
[
1
]
=
smem
[
2
*
idx
+
1
];
#else
int2
tmp
=
*
(
const
int2
*
)
&
smem
[
2
*
idx
];
x
[
0
]
=
tmp
.
x
;
x
[
1
]
=
tmp
.
y
;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
write_to_gmem
(
float
*
gmem
,
int
idx
,
const
float
(
&
src
)[
2
])
{
#ifdef __HIP_PLATFORM_HCC__
gmem
[
2
*
idx
]
=
src
[
0
];
gmem
[
2
*
idx
+
1
]
=
src
[
1
];
#else
reinterpret_cast
<
float2
*>
(
&
gmem
[
2
*
idx
])[
0
]
=
make_float2
(
src
[
0
],
src
[
1
]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
write_to_gmem
(
float
*
gmem
,
int
idx
,
const
float
(
&
src
)[
4
])
{
#ifdef __HIP_PLATFORM_HCC__
gmem
[
4
*
idx
]
=
src
[
0
];
gmem
[
4
*
idx
+
1
]
=
src
[
1
];
gmem
[
4
*
idx
+
2
]
=
src
[
2
];
gmem
[
4
*
idx
+
3
]
=
src
[
3
];
#else
reinterpret_cast
<
float4
*>
(
&
gmem
[
4
*
idx
])[
0
]
=
make_float4
(
src
[
0
],
src
[
1
],
src
[
2
],
src
[
3
]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
scaled_write_to_gmem
(
float
*
gmem
,
int
idx
,
const
float
(
&
src
)[
4
],
const
float
coeff
)
{
#ifdef __HIP_PLATFORM_HCC__
gmem
[
4
*
idx
]
=
src
[
0
]
*
coeff
;
gmem
[
4
*
idx
+
1
]
=
src
[
1
]
*
coeff
;
gmem
[
4
*
idx
+
2
]
=
src
[
2
]
*
coeff
;
gmem
[
4
*
idx
+
3
]
=
src
[
3
]
*
coeff
;
#else
reinterpret_cast
<
float4
*>
(
&
gmem
[
4
*
idx
])[
0
]
=
make_float4
(
src
[
0
]
*
coeff
,
src
[
1
]
*
coeff
,
src
[
2
]
*
coeff
,
src
[
3
]
*
coeff
);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
write_to_smem
(
float
*
smem
,
int
idx
,
const
float
(
&
x
)[
2
])
{
#ifdef __HIP_PLATFORM_HCC__
smem
[
2
*
idx
]
=
x
[
0
];
smem
[
2
*
idx
+
1
]
=
x
[
1
];
#else
reinterpret_cast
<
float2
*>
(
&
smem
[
2
*
idx
])[
0
]
=
make_float2
(
x
[
0
],
x
[
1
]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
write_to_smem
(
int
*
smem
,
int
idx
,
const
int
(
&
x
)[
1
])
{
smem
[
idx
]
=
x
[
0
];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
write_to_smem
(
float
*
smem
,
int
idx
,
const
float
(
&
x
)[
4
])
{
#ifdef __HIP_PLATFORM_HCC__
smem
[
4
*
idx
]
=
x
[
0
];
smem
[
4
*
idx
+
1
]
=
x
[
1
];
smem
[
4
*
idx
+
2
]
=
x
[
2
];
smem
[
4
*
idx
+
3
]
=
x
[
3
];
#else
reinterpret_cast
<
float4
*>
(
&
smem
[
4
*
idx
])[
0
]
=
make_float4
(
x
[
0
],
x
[
1
],
x
[
2
],
x
[
3
]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION
void
write_to_smem
(
int
*
smem
,
int
idx
,
const
int
(
&
x
)[
2
])
{
#ifdef __HIP_PLATFORM_HCC__
smem
[
2
*
idx
]
=
x
[
0
];
smem
[
2
*
idx
+
1
]
=
x
[
1
];
#else
reinterpret_cast
<
int2
*>
(
&
smem
[
2
*
idx
])[
0
]
=
make_int2
(
x
[
0
],
x
[
1
]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
zero_array
(
int
(
&
dst
)[
N
])
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
dst
[
i
]
=
0
;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
zero_array
(
float
(
&
dst
)[
N
])
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
dst
[
i
]
=
0.
f
;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
add
(
float
(
&
x
)[
N
],
const
float
(
&
y
)[
N
])
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
x
[
i
]
+=
y
[
i
];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
multiply
(
float
(
&
x
)[
N
],
const
float
(
&
y
)[
N
])
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
x
[
i
]
*=
y
[
i
];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
scale_
(
float
(
&
x
)[
N
],
float
scalar
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
x
[
i
]
*=
scalar
;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
normalize
(
float
(
&
x
)[
N
],
const
float
(
&
bias
)[
N
],
const
float
(
&
scale
)[
N
],
const
float
(
&
m1
)[
N
])
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
x
[
i
]
=
bias
[
i
]
+
scale
[
i
]
*
(
x
[
i
]
-
m1
[
i
]);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Storage
>
DEVICE_FUNCTION
Storage
relu
(
Storage
in
)
{
Storage
zero
=
(
Storage
)
0.
f
;
return
(
in
<
zero
)
?
zero
:
in
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
relu_activation
(
float
(
&
x
)[
N
])
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
x
[
i
]
=
relu
(
x
[
i
]);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
THREADS_PER_CTA
>
DEVICE_FUNCTION
void
parallel_sums_16x2
(
float
*
smem
,
float
(
&
x
)[
4
],
int
nhw
,
void
*
params_my_data
,
void
**
params_pair_datas
,
int
off
,
const
int
magic
,
const
int
sync_iters
)
{
// The size of a warp.
#ifdef __HIP_PLATFORM_HCC__
const
int
THREADS_PER_WARP
=
64
;
#else
const
int
THREADS_PER_WARP
=
32
;
#endif
// The number of warps in a CTA.
const
int
WARPS_PER_CTA
=
THREADS_PER_CTA
/
THREADS_PER_WARP
;
// The number of threads per pixel.
const
int
THREADS_PER_PIXEL
=
16
;
// The number of elements per ldg.
const
int
ELEMENTS_PER_LDG
=
4
;
// The number of reducing ops, each uses its own space : mean, var, dscale, dbias
const
int
REDUCE_OPS
=
4
;
// Maximum block.y supported - limited due to buffer allocation
const
int
MAX_BLOCK_Y
=
256
;
const
int
MAX_OFFSET
=
REDUCE_OPS
*
MAX_BLOCK_Y
;
// The warp decomposition.
const
int
warp_id
=
threadIdx
.
x
/
THREADS_PER_WARP
;
const
int
lane_id
=
threadIdx
.
x
%
THREADS_PER_WARP
;
// total size of data per sync iter
const
int
data_total
=
MAX_OFFSET
*
THREADS_PER_PIXEL
*
ELEMENTS_PER_LDG
*
2
;
#ifdef __HIP_PLATFORM_HCC__
for
(
int
offset
=
THREADS_PER_PIXEL
;
offset
<=
THREADS_PER_WARP
>>
1
;
offset
<<=
1
)
{
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
x
[
i
]
+=
shfl_sync
(
x
[
i
],
offset
+
lane_id
);
}
}
#else
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
x
[
i
]
+=
shfl_sync
(
x
[
i
],
THREADS_PER_PIXEL
+
lane_id
);
}
#endif
// The warp leaders, write to SMEM.
if
(
lane_id
<
THREADS_PER_PIXEL
)
{
write_to_smem
(
smem
,
warp_id
*
THREADS_PER_PIXEL
+
lane_id
,
x
);
}
// The data is in SMEM. Do the final reduction.
__syncthreads
();
// The 1st warp does all the work.
// We do the final reduction each half-warp sequentially reduces the final values.
if
(
warp_id
==
0
)
{
read_from_smem
(
x
,
smem
,
threadIdx
.
x
);
#pragma unroll
for
(
int
offset
=
1
;
offset
<
WARPS_PER_CTA
/
(
THREADS_PER_WARP
/
THREADS_PER_PIXEL
);
++
offset
)
{
float
y
[
ELEMENTS_PER_LDG
];
// Read the mean and variance from the other pixel.
read_from_smem
(
y
,
smem
,
threadIdx
.
x
+
offset
*
THREADS_PER_WARP
);
// Compute the updated sum.
add
(
x
,
y
);
}
#ifdef __HIP_PLATFORM_HCC__
for
(
int
offset
=
THREADS_PER_WARP
>>
1
;
offset
>=
THREADS_PER_PIXEL
;
offset
>>=
1
)
{
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
x
[
i
]
+=
shfl_sync
(
x
[
i
],
offset
+
lane_id
);
}
}
#else
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
x
[
i
]
+=
shfl_sync
(
x
[
i
],
THREADS_PER_PIXEL
+
lane_id
);
}
#endif
// Make sure the data was read from SMEM.
syncwarp
();
// Store the final values.
if
(
threadIdx
.
x
<
THREADS_PER_PIXEL
)
{
// probably could do it earlier, before sync
#ifndef __HIP_PLATFORM_HCC__ // bn_group > 1 is not enabled on HIP
for
(
int
sync_iter
=
0
;
sync_iter
<
sync_iters
;
++
sync_iter
)
{
//float* params_pair_data = (reinterpret_cast<float**>(params_pair_datas))[sync_iter];
void
*
params_pair_data
=
params_pair_datas
[
sync_iter
];
// skip the space consumed by previous sync iterations
const
int
xbuf_offset
=
sync_iter
*
data_total
;
// data starts after flags, but have to skip previous
const
int
data_offset
=
xbuf_offset
+
off
*
ELEMENTS_PER_LDG
*
THREADS_PER_PIXEL
*
2
+
ELEMENTS_PER_LDG
*
threadIdx
.
x
*
2
;
// after sums for this GPU were computed, let CTA0 broadcast the sum to over GPU
if
(
blockIdx
.
x
==
0
)
{
volatile
float
*
write_data
=
&
((
reinterpret_cast
<
float
*>
(
params_pair_data
))[
data_offset
]);
// write the data to memory region to be reflected to other GPU
asm
volatile
(
"st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};"
::
"l"
(
write_data
)
,
"f"
(
x
[
0
]),
"r"
(
magic
),
"f"
(
x
[
2
]),
"r"
(
magic
));
asm
volatile
(
"st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};"
::
"l"
(
write_data
+
4
)
,
"f"
(
x
[
1
]),
"r"
(
magic
),
"f"
(
x
[
3
]),
"r"
(
magic
));
}
// now each CTA (on each GPU) reads the data written by CTA 0 of the other GPU
volatile
float
*
read_data
=
&
((
reinterpret_cast
<
float
*>
(
params_my_data
))[
data_offset
]);
float
other
[
4
];
uint32_t
other_flag_a
,
other_flag_b
;
do
{
asm
volatile
(
"ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
:
"=f"
(
other
[
0
]),
"=r"
(
other_flag_a
),
"=f"
(
other
[
2
]),
"=r"
(
other_flag_b
)
:
"l"
(
read_data
));
}
while
((
other_flag_a
!=
magic
)
||
(
other_flag_b
!=
magic
));
do
{
asm
volatile
(
"ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
:
"=f"
(
other
[
1
]),
"=r"
(
other_flag_a
),
"=f"
(
other
[
3
]),
"=r"
(
other_flag_b
)
:
"l"
(
read_data
+
4
));
}
while
((
other_flag_a
!=
magic
)
||
(
other_flag_b
!=
magic
));
add
(
x
,
other
);
}
#endif
// finally, after syncing up and accounting for partial sums from
// other GPUs as required, write the result
write_to_smem
(
smem
,
threadIdx
.
x
,
x
);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
THREADS_PER_CTA
>
DEVICE_FUNCTION
void
parallel_sums_8x4
(
float
*
smem
,
float
(
&
x
)[
4
],
int
nhw
)
{
// The size of a warp.
#ifdef __HIP_PLATFORM_HCC__
const
int
THREADS_PER_WARP
=
64
;
#else
const
int
THREADS_PER_WARP
=
32
;
#endif
// The number of warps in a CTA.
const
int
WARPS_PER_CTA
=
THREADS_PER_CTA
/
THREADS_PER_WARP
;
// The number of threads per pixel.
const
int
THREADS_PER_PIXEL
=
8
;
// The number of elements per ldg.
const
int
ELEMENTS_PER_LDG
=
4
;
// The warp decomposition.
const
int
warp_id
=
threadIdx
.
x
/
THREADS_PER_WARP
;
const
int
lane_id
=
threadIdx
.
x
%
THREADS_PER_WARP
;
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
x
[
i
]
+=
shfl_sync
(
x
[
i
],
THREADS_PER_PIXEL
+
lane_id
);
x
[
i
]
+=
shfl_sync
(
x
[
i
],
THREADS_PER_PIXEL
*
2
+
lane_id
);
}
// The warp leaders, write to SMEM.
if
(
lane_id
<
THREADS_PER_PIXEL
)
{
write_to_smem
(
smem
,
warp_id
*
THREADS_PER_PIXEL
+
lane_id
,
x
);
}
// The data is in SMEM. Do the final reduction.
__syncthreads
();
// The 1st warp does all the work.
// We do the final reduction each half-warp sequentially reduces the final values.
if
(
warp_id
==
0
)
{
read_from_smem
(
x
,
smem
,
threadIdx
.
x
);
#pragma unroll
for
(
int
offset
=
1
;
offset
<
WARPS_PER_CTA
/
(
THREADS_PER_WARP
/
THREADS_PER_PIXEL
);
++
offset
)
{
float
y
[
ELEMENTS_PER_LDG
];
// Read the mean and variance from the other pixel.
read_from_smem
(
y
,
smem
,
threadIdx
.
x
+
offset
*
THREADS_PER_WARP
);
// Compute the updated sum.
add
(
x
,
y
);
}
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
x
[
i
]
+=
shfl_sync
(
x
[
i
],
THREADS_PER_PIXEL
+
lane_id
);
x
[
i
]
+=
shfl_sync
(
x
[
i
],
THREADS_PER_PIXEL
*
2
+
lane_id
);
}
// Make sure the data was read from SMEM.
syncwarp
();
// Store the final values.
if
(
threadIdx
.
x
<
THREADS_PER_PIXEL
)
{
write_to_smem
(
smem
,
threadIdx
.
x
,
x
);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
THREADS_PER_CTA
,
int
THREADS_PER_PIXEL
,
int
ELEMENTS_PER_LDG
>
DEVICE_FUNCTION
void
parallel_sums
(
float
*
smem
,
float
(
&
x
)[
ELEMENTS_PER_LDG
],
int
nhw
)
{
// The size of a warp.
#ifdef __HIP_PLATFORM_HCC__
const
int
THREADS_PER_WARP
=
64
;
#else
const
int
THREADS_PER_WARP
=
32
;
#endif
const
int
WARPS_PER_CTA
=
THREADS_PER_CTA
/
THREADS_PER_WARP
;
// The warp decomposition.
const
int
warp_id
=
threadIdx
.
x
/
THREADS_PER_WARP
;
const
int
lane_id
=
threadIdx
.
x
%
THREADS_PER_WARP
;
// total size of data per sync iter
#ifdef __HIP_PLATFORM_HCC__
for
(
int
offset
=
THREADS_PER_PIXEL
;
offset
<=
THREADS_PER_WARP
>>
1
;
offset
<<=
1
)
{
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
x
[
i
]
+=
shfl_sync
(
x
[
i
],
offset
+
lane_id
);
}
}
#else
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
x
[
i
]
+=
shfl_sync
(
x
[
i
],
THREADS_PER_PIXEL
+
lane_id
);
}
#endif
// The warp leaders, write to SMEM.
if
(
lane_id
<
THREADS_PER_PIXEL
)
{
write_to_smem
(
smem
,
warp_id
*
THREADS_PER_PIXEL
+
lane_id
,
x
);
}
// The data is in SMEM. Do the final reduction.
__syncthreads
();
// The 1st warp does all the work.
// We do the final reduction each half-warp sequentially reduces the final values.
if
(
warp_id
==
0
)
{
read_from_smem
(
x
,
smem
,
threadIdx
.
x
);
#pragma unroll
for
(
int
offset
=
1
;
offset
<
WARPS_PER_CTA
/
(
THREADS_PER_WARP
/
THREADS_PER_PIXEL
);
++
offset
)
{
float
y
[
ELEMENTS_PER_LDG
];
// Read the mean and variance from the other pixel.
read_from_smem
(
y
,
smem
,
threadIdx
.
x
+
offset
*
THREADS_PER_WARP
);
// Compute the updated sum.
add
(
x
,
y
);
}
#ifdef __HIP_PLATFORM_HCC__
for
(
int
offset
=
THREADS_PER_WARP
>>
1
;
offset
>=
THREADS_PER_PIXEL
;
offset
>>=
1
)
{
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
x
[
i
]
+=
shfl_sync
(
x
[
i
],
offset
+
lane_id
);
}
}
#else
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
x
[
i
]
+=
shfl_sync
(
x
[
i
],
THREADS_PER_PIXEL
+
lane_id
);
}
#endif
// Make sure the data was read from SMEM.
syncwarp
();
// Store the final values.
if
(
threadIdx
.
x
<
THREADS_PER_PIXEL
)
{
// probably could do it earlier, before sync
write_to_smem
(
smem
,
threadIdx
.
x
,
x
);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
THREADS_PER_PIXEL
,
int
ELEMENTS_PER_LDG
>
struct
ParallelSums
{
template
<
int
THREADS_PER_CTA
>
DEVICE_FUNCTION
void
dispatch
(
float
*
smem
,
float
(
&
x
)[
ELEMENTS_PER_LDG
],
int
nhw
)
{
parallel_sums
<
THREADS_PER_CTA
,
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>
(
smem
,
x
,
nhw
);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/*
template<>
struct ParallelSums<16, 4> {
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {
parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, 0, 0, 0, 0, 0);
}
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatchX(float *smem, float (&x)[4], int nhw, void* params_my_data, void** params_pair_datas, int off, const int magic, const unsigned int& sync_iters) {
parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, params_my_data, params_pair_datas, off, magic, sync_iters);
}
};
template<>
struct ParallelSums<8, 4> {
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {
parallel_sums_8x4<THREADS_PER_CTA>(smem, x, nhw);
}
};
*/
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
static
inline
int
div_up
(
int
m
,
int
n
)
{
return
(
m
+
n
-
1
)
/
n
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// It is expected that all threads in the CTA enter this function!
DEVICE_FUNCTION
void
inter_block_sync
(
int
*
gmem_retired_ctas
,
int
expected_count
,
bool
master
)
{
// Register the CTA.
if
(
threadIdx
.
x
==
0
)
{
// Issue the membar.
__threadfence
();
// Notify that the CTA is done.
int
val_to_add
=
1
;
if
(
master
)
{
val_to_add
=
-
(
expected_count
-
1
);
}
atomicAdd
(
gmem_retired_ctas
,
val_to_add
);
}
// Are all CTAs done?
if
(
threadIdx
.
x
==
0
)
{
int
retired_ctas
=
-
1
;
do
{
__threadfence
();
#ifdef __HIP_PLATFORM_HCC__
retired_ctas
=
__ldg
((
const
int
*
)
gmem_retired_ctas
);
#else
asm
volatile
(
"ld.global.cg.b32 %0, [%1];"
:
"=r"
(
retired_ctas
)
:
"l"
(
gmem_retired_ctas
));
#endif
}
while
(
retired_ctas
!=
0
);
}
__syncthreads
();
}
////////////////////////////////////////////////////////////////////////////////////////////////////
struct
NhwcBatchNormFwdInferenceParams
{
// The input/output tensors.
uint16_t
*
gmem_src
,
*
gmem_dst
,
*
gmem_src1
;
// the final mean and variance as calculated during the training process
float
*
gmem_mean
,
*
gmem_var
;
// The bias/scale.
float
*
gmem_bias
,
*
gmem_scale
;
// The dimensions.
int
nhw
,
c
;
// epsilon
float
var_eps
;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// No DESIRED_OCCUPANCY launch bounds needed, as this is not launched cooperatively
template
<
typename
Storage
,
int
THREADS_PER_CTA
,
int
THREADS_PER_PIXEL
,
int
ELEMENTS_PER_LDG
,
bool
USE_RELU
,
bool
USE_ADD_RELU
>
__global__
__launch_bounds__
(
THREADS_PER_CTA
)
void
nhwc_batch_norm_fwd_inference
(
NhwcBatchNormFwdInferenceParams
params
)
{
// The number of pixels loaded in a single LDG.
const
int
PIXELS_PER_LDG
=
THREADS_PER_CTA
/
THREADS_PER_PIXEL
;
// The number of C elements per CTA.
const
int
C_ELEMENTS_PER_CTA
=
THREADS_PER_PIXEL
*
ELEMENTS_PER_LDG
;
// The start position in the NHW dimension where the CTA starts.
const
int
cta_nhw_stride
=
gridDim
.
x
*
PIXELS_PER_LDG
;
// Compute the NHW coordinate of the thread in the CTA.
const
int
thread_in_cta_nhw
=
threadIdx
.
x
/
THREADS_PER_PIXEL
;
// thread's starting point in NHW
const
int
thread_nhw
=
thread_in_cta_nhw
+
blockIdx
.
x
*
PIXELS_PER_LDG
;
// The position in the C dimension where the CTA starts.
const
int
cta_c
=
blockIdx
.
y
*
C_ELEMENTS_PER_CTA
;
// Compute the C coordinate of the thread in the CTA.
const
int
thread_in_cta_c
=
threadIdx
.
x
%
THREADS_PER_PIXEL
;
// Compute the C coordinate of the thread.
const
int
thread_c
=
cta_c
+
thread_in_cta_c
*
ELEMENTS_PER_LDG
;
// Is the thread working on a valid C dimension?
const
int
is_valid_c
=
thread_c
<
params
.
c
;
float
mean
[
ELEMENTS_PER_LDG
],
var
[
ELEMENTS_PER_LDG
];
float
scale
[
ELEMENTS_PER_LDG
],
bias
[
ELEMENTS_PER_LDG
];
zero_array
(
mean
);
zero_array
(
var
);
zero_array
(
scale
);
zero_array
(
bias
);
if
(
is_valid_c
)
{
read_from_gmem
(
var
,
&
params
.
gmem_var
[
cta_c
],
thread_in_cta_c
);
read_from_gmem
(
scale
,
&
params
.
gmem_scale
[
cta_c
],
thread_in_cta_c
);
read_from_gmem
(
mean
,
&
params
.
gmem_mean
[
cta_c
],
thread_in_cta_c
);
read_from_gmem
(
bias
,
&
params
.
gmem_bias
[
cta_c
],
thread_in_cta_c
);
}
// Update the scale with the stddev and eps.
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
scale
[
i
]
*=
rsqrtf
(
var
[
i
]
+
params
.
var_eps
);
}
// The base pointers for reading/writing
uint16_t
*
const
gmem_src
=
&
params
.
gmem_src
[
thread_c
];
uint16_t
*
const
gmem_dst
=
&
params
.
gmem_dst
[
thread_c
];
const
uint16_t
*
gmem_src1
=
nullptr
;
if
(
USE_ADD_RELU
)
{
gmem_src1
=
&
params
.
gmem_src1
[
thread_c
];
}
// apply BN
for
(
int
nhw
=
thread_nhw
;
nhw
<
params
.
nhw
;
nhw
+=
cta_nhw_stride
)
{
float
x_math
[
ELEMENTS_PER_LDG
];
zero_array
(
x_math
);
if
(
is_valid_c
)
{
ldg
(
x_math
,
&
gmem_src
[
nhw
*
params
.
c
]);
}
// Normalize and apply activation function
normalize
(
x_math
,
bias
,
scale
,
mean
);
if
(
USE_ADD_RELU
)
{
float
x1_math
[
ELEMENTS_PER_LDG
];
ldg
(
x1_math
,
&
gmem_src1
[
nhw
*
params
.
c
]);
add
(
x_math
,
x1_math
);
relu_activation
(
x_math
);
}
else
if
(
USE_RELU
)
{
relu_activation
(
x_math
);
}
if
(
is_valid_c
)
{
stg
(
&
gmem_dst
[
nhw
*
params
.
c
],
x_math
);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
struct
NhwcBatchNormFwdParams
{
// The input/output tensors.
uint16_t
*
gmem_src
,
*
gmem_dst
,
*
gmem_src1
;
// The bias/scale.
float
*
gmem_bias
,
*
gmem_scale
;
// running mean/var (refer BN API from cudnn doc)
float
*
gmem_running_mean
,
*
gmem_running_var
;
// saved mean/var (refer BN API from cudnn doc)
float
*
gmem_saved_mean
,
*
gmem_saved_var
;
// ReLU bitmask
bitmask_t
*
gmem_relu_bitmask
;
// The dimensions.
int
nhw
,
c
;
// factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
float
svar_inv_count
;
// factor to scale sum of squared errors to get running variance. Should be 1/nhw or 1/(nhw-1).
float
rvar_inv_count
;
// The buffer to do the reduction for mean, stddev and count.
float
*
gmem_sums
;
// The buffer to count items in the different CTAs.
int
*
gmem_counts
;
// The counters of retired CTAs.
int
*
gmem_retired_ctas
;
// The epsilon to apply to the computation of the variance.
float
var_eps
;
// outer loop count
int
outer_loops
;
// exponential average factor
float
exp_avg_factor
;
// number of CTAs along .x dimension
int
c_blks
;
void
*
my_data
;
void
*
pair_datas
[
4
];
int
magic
;
int
sync_iters
;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Storage
,
int
THREADS_PER_CTA
,
int
THREADS_PER_PIXEL
,
int
PIXELS_PER_THREAD_IN_REGISTERS
,
int
PIXELS_PER_THREAD_IN_SMEM
,
int
ELEMENTS_PER_LDG
,
int
USE_ONLINE_APPROACH
,
int
OUTER_LOOPS_
,
bool
USE_RELU
,
bool
USE_ADD_RELU
,
int
DESIRED_OCCUPANCY
>
__global__
__launch_bounds__
(
THREADS_PER_CTA
,
DESIRED_OCCUPANCY
)
void
nhwc_batch_norm_fwd
(
NhwcBatchNormFwdParams
params
)
{
// The number of pixels loaded in a single LDG.
const
int
PIXELS_PER_LDG
=
THREADS_PER_CTA
/
THREADS_PER_PIXEL
;
// The number of pixels computed per CTA stored in registers.
const
int
PIXELS_PER_CTA_IN_REGISTERS
=
PIXELS_PER_THREAD_IN_REGISTERS
*
PIXELS_PER_LDG
;
// The number of pixels computed per CTA stored in SMEM.
const
int
PIXELS_PER_CTA_IN_SMEM
=
PIXELS_PER_THREAD_IN_SMEM
*
PIXELS_PER_LDG
;
// The number of C elements per CTA.
const
int
C_ELEMENTS_PER_CTA
=
THREADS_PER_PIXEL
*
ELEMENTS_PER_LDG
;
// Shared memory to do CTA-wide parallel sums.
__shared__
float
smem
[
THREADS_PER_PIXEL
*
(
THREADS_PER_CTA
/
warpSize
)
*
ELEMENTS_PER_LDG
];
// Compute the NHW coordinate of the thread in the CTA.
const
int
thread_in_cta_nhw
=
threadIdx
.
x
/
THREADS_PER_PIXEL
;
// The adapter for the storage.
typedef
PackedStorage
<
Storage
,
ELEMENTS_PER_LDG
>
PackedStorage_
;
// The data type for packed storage in SMEM.
typedef
typename
PackedStorage_
::
Type
PackedStorageType
;
// The number of elements in the packed storage.
const
int
PACKED_ELEMENTS_PER_LDG
=
PackedStorage_
::
PACKED_ELEMENTS_PER_LDG
;
// Registers to keep the data live for the persistent approach.
PackedStorageType
x_storage
[
PIXELS_PER_THREAD_IN_REGISTERS
][
PACKED_ELEMENTS_PER_LDG
];
// Shared memory buffer to store the extra pixels.
extern
__shared__
PackedStorageType
smem_storage_packed
[];
#ifdef __HIP_PLATFORM_HCC__
const
half
zero_h
=
__float2half
(
0.0
F
);
#endif
for
(
int
c_blk_index
=
blockIdx
.
y
;
c_blk_index
<
params
.
c_blks
;
c_blk_index
+=
gridDim
.
y
)
{
// The position in the NHW dimension where the CTA starts.
int
cta_nhw_regs
=
blockIdx
.
x
*
PIXELS_PER_CTA_IN_REGISTERS
;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int
cta_nhw_smem
=
blockIdx
.
x
*
PIXELS_PER_CTA_IN_SMEM
;
// The position in the C dimension where the CTA starts.
const
int
cta_c
=
c_blk_index
*
C_ELEMENTS_PER_CTA
;
// Compute the C coordinate of the thread in the CTA.
const
int
thread_in_cta_c
=
threadIdx
.
x
%
THREADS_PER_PIXEL
;
// Compute the C coordinate of the thread.
int
thread_c
=
cta_c
+
thread_in_cta_c
*
ELEMENTS_PER_LDG
;
// Is the thread working on a valid C dimension?
const
int
is_valid_c
=
thread_c
<
params
.
c
;
// Clamp thread_c so that we load from valid locations even if we don't use the value
if
(
!
is_valid_c
)
thread_c
=
params
.
c
-
4
;
// Single pass numerically stable algorithm, see:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
//
// n = 0, mean = 0.0, M2 = 0.0
//
// for x in data:
// n += 1
// delta = x - mean
// mean += delta/n
// delta2 = x - mean
// M2 += delta*delta2
//
// if n < 2:
// return float('nan')
// else:
// return M2 / (n - 1)
// Register to store the number of elements read so far.
float
count
=
0.
f
,
mean
[
ELEMENTS_PER_LDG
],
m2
[
ELEMENTS_PER_LDG
];
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
mean
[
i
]
=
0.
f
;
m2
[
i
]
=
0.
f
;
}
// The number of elements loaded by this CTA.
int
cta_count
=
0
;
// The base pointer to load from.
const
uint16_t
*
gmem_src
=
&
params
.
gmem_src
[
thread_c
];
// outer loops
int
OUTER_LOOPS
=
OUTER_LOOPS_
==
1
?
1
:
params
.
outer_loops
;
// Load the batch of elements. Compute the mean/var across those elements.
const
int
pixels_per_iteration
=
PIXELS_PER_CTA_IN_REGISTERS
*
gridDim
.
x
;
if
(
OUTER_LOOPS_
!=
1
)
{
// We cannot load everything to store persistently, so let's makes sure registers and
// smem are fully utilized, offset is evenly divisible by 32
int
offset
=
(
pixels_per_iteration
*
OUTER_LOOPS
+
PIXELS_PER_CTA_IN_SMEM
*
gridDim
.
x
-
params
.
nhw
)
&
~
31
;
cta_nhw_regs
-=
offset
;
cta_nhw_smem
-=
offset
;
}
#pragma unroll 1
for
(
int
loop_i
=
0
;
loop_i
<
OUTER_LOOPS
;
++
loop_i
)
{
// The nhw position.
int
nhw_regs
=
cta_nhw_regs
+
loop_i
*
pixels_per_iteration
;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count
+=
max
(
min
(
nhw_regs
+
PIXELS_PER_CTA_IN_REGISTERS
,
params
.
nhw
)
-
max
(
nhw_regs
,
0
),
0
);
// Load the data and compute the local mean/sum and the variance.
if
(
USE_ONLINE_APPROACH
)
{
// Read the elements from memory.
float
is_valid
[
PIXELS_PER_THREAD_IN_REGISTERS
];
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
nhw_regs
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
zero_array
(
x_storage
[
i
]);
is_valid
[
i
]
=
0.
f
;
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
#ifndef __HIP_PLATFORM_HCC__
if
(
loop_i
==
OUTER_LOOPS
-
1
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
}
else
{
#endif
ldg
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
is_valid
[
i
]
=
1.
f
;
}
}
// Do the math.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
// Convert to float.
float
x_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
// Update the count.
count
+=
is_valid
[
i
];
// Invert the count.
float
inv_count
=
is_valid
[
i
]
?
1.
f
/
count
:
0.
f
;
// Update the mean and m2 using deltas.
#pragma unroll
for
(
int
j
=
0
;
j
<
ELEMENTS_PER_LDG
;
++
j
)
{
float
delta0
=
x_math
[
j
]
-
mean
[
j
];
mean
[
j
]
+=
delta0
*
inv_count
;
float
delta1
=
x_math
[
j
]
-
mean
[
j
];
m2
[
j
]
+=
delta0
*
delta1
*
is_valid
[
i
];
}
}
}
else
{
// Read the elements from memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
nhw_regs
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
zero_array
(
x_storage
[
i
]);
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
if
(
loop_i
==
OUTER_LOOPS
-
1
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
}
else
{
ldg
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
}
count
+=
1.
f
;
}
}
// Sum the elements in registers.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
// Convert to float.
float
x_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
// Update the mean and m2 using deltas.
#pragma unroll
for
(
int
j
=
0
;
j
<
ELEMENTS_PER_LDG
;
++
j
)
{
mean
[
j
]
+=
x_math
[
j
];
}
}
// Compute the mean.
float
inv_count
=
1.
f
/
count
;
#pragma unroll
for
(
int
j
=
0
;
j
<
ELEMENTS_PER_LDG
;
++
j
)
{
mean
[
j
]
*=
inv_count
;
}
// Compute the variance.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
// Convert to float.
float
x_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
// Is it a valid pixel?
float
is_valid
=
i
<
static_cast
<
int
>
(
count
)
?
1.
f
:
0.
f
;
// Update the mean and m2 using deltas.
#pragma unroll
for
(
int
j
=
0
;
j
<
ELEMENTS_PER_LDG
;
++
j
)
{
m2
[
j
]
+=
(
x_math
[
j
]
-
mean
[
j
])
*
(
x_math
[
j
]
-
mean
[
j
])
*
is_valid
;
}
}
}
}
// The elements to load and store in SMEM.
int
smem_nhw
=
OUTER_LOOPS
*
pixels_per_iteration
+
cta_nhw_smem
;
// Load elements from SMEM, update the CTA count.
int
pixels_in_smem
=
min
(
smem_nhw
+
PIXELS_PER_CTA_IN_SMEM
,
params
.
nhw
)
-
max
(
smem_nhw
,
0
);
if
(
pixels_in_smem
>
0
)
{
cta_count
+=
pixels_in_smem
;
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_SMEM
;
++
i
)
{
const
int
idx
=
smem_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
float
is_pixel_valid
=
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
?
1.
f
:
0.
f
;
PackedStorageType
x_storage_local
[
PACKED_ELEMENTS_PER_LDG
];
ldg_stream
(
x_storage_local
,
&
gmem_src
[(
is_pixel_valid
?
idx
:
0
)
*
params
.
c
]);
// The offset to store in SMEM.
const
int
offset
=
i
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
// Store in SMEM.
write_to_smem
(
&
smem_storage_packed
[
offset
],
threadIdx
.
x
,
x_storage_local
);
// Update the count.
count
+=
is_pixel_valid
;
// Invert the count.
float
inv_count
=
is_pixel_valid
?
1.
f
/
count
:
0.
f
;
float
x_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage_local
);
// Update the mean and m2 using deltas.
#pragma unroll
for
(
int
j
=
0
;
j
<
ELEMENTS_PER_LDG
;
++
j
)
{
float
delta0
=
x_math
[
j
]
-
mean
[
j
];
mean
[
j
]
+=
delta0
*
inv_count
;
float
delta1
=
x_math
[
j
]
-
mean
[
j
];
m2
[
j
]
+=
delta0
*
delta1
*
is_pixel_valid
;
}
}
}
// We scale the mean by the number of elements. It brings more stability.
float
m1
[
ELEMENTS_PER_LDG
];
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
m1
[
i
]
=
mean
[
i
]
*
count
;
}
// Run the parallel sum accross the CTA to get the local sum.
#ifdef __HIP_PLATFORM_HCC__
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
template
dispatch
<
THREADS_PER_CTA
>(
#else
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
#endif
smem
,
m1
,
thread_in_cta_nhw
);
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
m1
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// Adjust the variance.
float
inv_cta_count
=
1.
f
/
static_cast
<
float
>
(
cta_count
);
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
float
mean_diff
=
m1
[
i
]
*
inv_cta_count
-
mean
[
i
];
m2
[
i
]
=
m2
[
i
]
+
mean_diff
*
mean_diff
*
count
;
}
// Run the parallel sum accross the CTA to get the local adjusted variance.
#ifdef __HIP_PLATFORM_HCC__
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
template
dispatch
<
THREADS_PER_CTA
>(
#else
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
#endif
smem
,
m2
,
thread_in_cta_nhw
);
// The workspace in global memory is distributed across the different CTA.
int
gmem_sums_offset
=
c_blk_index
*
gridDim
.
x
*
C_ELEMENTS_PER_CTA
*
2
;
// Write the data for the CTA to global memory.
float
*
gmem_sums
=
&
params
.
gmem_sums
[
gmem_sums_offset
];
if
(
threadIdx
.
x
<
THREADS_PER_PIXEL
)
{
const
int
idx
=
blockIdx
.
x
*
THREADS_PER_PIXEL
+
threadIdx
.
x
;
write_to_gmem
(
&
gmem_sums
[
0
],
idx
,
m1
);
write_to_gmem
(
&
gmem_sums
[
C_ELEMENTS_PER_CTA
*
gridDim
.
x
],
idx
,
m2
);
}
// The memory location to store the number of pixels per CTA.
int
*
gmem_counts
=
&
params
.
gmem_counts
[
c_blk_index
*
gridDim
.
x
];
if
(
threadIdx
.
x
==
0
)
{
gmem_counts
[
blockIdx
.
x
]
=
cta_count
;
}
// Read the bias and scale.
float
bias
[
ELEMENTS_PER_LDG
],
scale
[
ELEMENTS_PER_LDG
];
if
(
is_valid_c
)
{
read_from_gmem
(
bias
,
&
params
.
gmem_bias
[
cta_c
],
thread_in_cta_c
);
read_from_gmem
(
scale
,
&
params
.
gmem_scale
[
cta_c
],
thread_in_cta_c
);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int
*
gmem_retired_ctas
=
&
params
.
gmem_retired_ctas
[
c_blk_index
%
(
2
*
gridDim
.
y
)];
inter_block_sync
(
gmem_retired_ctas
,
gridDim
.
x
,
blockIdx
.
x
==
0
);
// Reset the mean to compute the global mean.
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
m1
[
i
]
=
0.
f
;
}
// Build the global mean.
#pragma unroll 1
for
(
int
idx
=
threadIdx
.
x
;
idx
<
THREADS_PER_PIXEL
*
gridDim
.
x
;
idx
+=
THREADS_PER_CTA
)
{
float
tmp
[
ELEMENTS_PER_LDG
];
read_from_gmem
(
tmp
,
gmem_sums
,
idx
);
add
(
m1
,
tmp
);
}
#ifndef __HIP_PLATFORM_HCC__
if
(
params
.
sync_iters
>
0
)
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatchX
<
THREADS_PER_CTA
>
(
smem
,
m1
,
thread_in_cta_nhw
,
params
.
my_data
,
params
.
pair_datas
,
4
*
c_blk_index
+
3
,
params
.
magic
,
params
.
sync_iters
);
}
else
{
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
template
dispatch
<
THREADS_PER_CTA
>(
#else
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
#endif
smem
,
m1
,
thread_in_cta_nhw
);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
m1
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// Normalize the mean.
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
m1
[
i
]
=
m1
[
i
]
*
params
.
svar_inv_count
;
}
// Reset the variance.
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
m2
[
i
]
=
0.
f
;
}
// for add+relu fusion
const
uint16_t
*
gmem_src1
=
nullptr
;
if
(
USE_ADD_RELU
)
{
gmem_src1
=
&
params
.
gmem_src1
[
thread_c
];
}
// Build the global variance.
#pragma unroll 1
for
(
int
idx
=
threadIdx
.
x
;
idx
<
THREADS_PER_PIXEL
*
gridDim
.
x
;
idx
+=
THREADS_PER_CTA
)
{
// Read the means computed by different CTAs (again). Reuse tmp if we have 1 iteration.
float
tmp_mean
[
ELEMENTS_PER_LDG
],
tmp_var
[
ELEMENTS_PER_LDG
];
read_from_gmem
(
tmp_mean
,
&
gmem_sums
[
0
],
idx
);
read_from_gmem
(
tmp_var
,
&
gmem_sums
[
C_ELEMENTS_PER_CTA
*
gridDim
.
x
],
idx
);
// Read the number of pixels visited by a given CTA.
cta_count
=
__ldg
(
&
gmem_counts
[
idx
/
THREADS_PER_PIXEL
]);
// Compute the diff to update the variance.
float
mean_diff
[
ELEMENTS_PER_LDG
],
inv_cta_count
=
1.
f
/
static_cast
<
float
>
(
cta_count
);
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
mean_diff
[
i
]
=
m1
[
i
]
-
tmp_mean
[
i
]
*
inv_cta_count
;
}
// Update the variance.
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
m2
[
i
]
+=
tmp_var
[
i
]
+
mean_diff
[
i
]
*
mean_diff
[
i
]
*
static_cast
<
float
>
(
cta_count
);
}
}
#ifndef __HIP_PLATFORM_HCC__
if
(
params
.
sync_iters
>
0
)
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatchX
<
THREADS_PER_CTA
>
(
smem
,
m2
,
thread_in_cta_nhw
,
params
.
my_data
,
params
.
pair_datas
,
4
*
c_blk_index
+
2
,
params
.
magic
,
params
.
sync_iters
);
}
else
{
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
template
dispatch
<
THREADS_PER_CTA
>(
#else
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
#endif
smem
,
m2
,
thread_in_cta_nhw
);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads
();
read_from_smem
(
m2
,
smem
,
thread_in_cta_c
);
// Finalize the stddev.
// becasue saved var and running var may have different denominator, we don't do it here
// scale_(m2, inv_count);
// store the saved mean/var
float
svarinv
[
ELEMENTS_PER_LDG
];
bool
is_valid_for_saving
=
is_valid_c
&&
blockIdx
.
x
==
0
&&
thread_in_cta_nhw
==
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
svarinv
[
i
]
=
rsqrtf
(
m2
[
i
]
*
params
.
svar_inv_count
+
params
.
var_eps
);
}
if
(
is_valid_for_saving
)
{
write_to_gmem
(
params
.
gmem_saved_mean
,
thread_c
/
ELEMENTS_PER_LDG
,
m1
);
write_to_gmem
(
params
.
gmem_saved_var
,
thread_c
/
ELEMENTS_PER_LDG
,
svarinv
);
}
// store the running mean/var
float
rmean
[
ELEMENTS_PER_LDG
],
rvar
[
ELEMENTS_PER_LDG
];
zero_array
(
rmean
);
zero_array
(
rvar
);
if
(
params
.
exp_avg_factor
!=
1.
f
&&
is_valid_for_saving
)
{
read_from_gmem
(
rmean
,
params
.
gmem_running_mean
,
thread_c
/
ELEMENTS_PER_LDG
);
read_from_gmem
(
rvar
,
params
.
gmem_running_var
,
thread_c
/
ELEMENTS_PER_LDG
);
}
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
rmean
[
i
]
=
(
1.
f
-
params
.
exp_avg_factor
)
*
rmean
[
i
]
+
\
params
.
exp_avg_factor
*
m1
[
i
];
rvar
[
i
]
=
(
1.
f
-
params
.
exp_avg_factor
)
*
rvar
[
i
]
+
\
params
.
exp_avg_factor
*
(
m2
[
i
]
*
params
.
rvar_inv_count
);
}
if
(
is_valid_for_saving
)
{
write_to_gmem
(
params
.
gmem_running_mean
,
thread_c
/
ELEMENTS_PER_LDG
,
rmean
);
write_to_gmem
(
params
.
gmem_running_var
,
thread_c
/
ELEMENTS_PER_LDG
,
rvar
);
}
// Update the scale with the stddev and eps.
multiply
(
scale
,
svarinv
);
// The base pointer to write to.
uint16_t
*
const
gmem_dst
=
&
params
.
gmem_dst
[
thread_c
];
bitmask_t
*
const
gmem_relu_bitmask
=
params
.
gmem_relu_bitmask
+
#ifdef __HIP_PLATFORM_HCC__
((
params
.
nhw
+
3
)
&
~
3
)
*
2
*
c_blk_index
;
#else
((
params
.
nhw
+
31
)
&
~
31
)
*
2
*
c_blk_index
;
#endif
// Store the elements in registers.
#pragma unroll 1
for
(
int
loop_i
=
OUTER_LOOPS
-
1
;
loop_i
>=
0
;
--
loop_i
)
{
// The value for nhw.
int
out_nhw
=
cta_nhw_regs
+
loop_i
*
pixels_per_iteration
;
// Normalize the elements and write to memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
out_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
const
bool
is_valid_nhw
=
static_cast
<
unsigned
int
>
(
idx
)
<
static_cast
<
unsigned
int
>
(
params
.
nhw
);
const
bool
is_valid
=
is_valid_nhw
&&
is_valid_c
;
// Convert to float.
float
x_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
// Normalize and apply activation function
normalize
(
x_math
,
bias
,
scale
,
m1
);
if
(
USE_ADD_RELU
)
{
float
x1_math
[
ELEMENTS_PER_LDG
];
ldg_stream
(
x1_math
,
&
gmem_src1
[(
is_valid
?
idx
:
0
)
*
params
.
c
]);
add
(
x_math
,
x1_math
);
bitmask_t
relu_mask
;
#ifdef __HIP_PLATFORM_HCC__
int
lane_id
=
threadIdx
.
x
&
63
;
#else
int
lane_id
=
threadIdx
.
x
&
31
;
#endif
#pragma unroll
for
(
int
j
=
0
;
j
<
ELEMENTS_PER_LDG
;
++
j
)
{
#ifdef __HIP_PLATFORM_HCC__
bool
rectified
=
__hle
(
__float2half
(
x_math
[
j
]),
zero_h
);
#else
bool
rectified
=
x_math
[
j
]
<
0
;
#endif
bitmask_t
local_relu_mask
=
ballot
(
rectified
);
if
(
lane_id
==
j
)
{
// Thread 0 remembers the relu_mask from the first time through this
// loop, Thread 1 the next, Thread 2 the next, and Thread 3 the last.
relu_mask
=
local_relu_mask
;
}
if
(
rectified
)
{
x_math
[
j
]
=
0.0
F
;
}
}
if
(
is_valid_nhw
&&
(
lane_id
<
ELEMENTS_PER_LDG
))
{
gmem_relu_bitmask
[
idx
*
BITMASK_OFFSET
+
lane_id
]
=
relu_mask
;
}
}
else
if
(
USE_RELU
)
{
relu_activation
(
x_math
);
}
// Write back.
if
(
is_valid
)
{
stg_stream
(
&
gmem_dst
[
idx
*
params
.
c
],
x_math
);
}
}
// The next value of nhw.
out_nhw
-=
pixels_per_iteration
;
// Read the next elements from memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
out_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
}
}
}
// Normalize the elements from SMEM and write them out.
if
(
pixels_in_smem
>
0
)
{
#pragma unroll 2
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_SMEM
;
++
i
)
{
const
int
idx
=
smem_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
const
bool
is_valid_nhw
=
static_cast
<
unsigned
int
>
(
idx
)
<
static_cast
<
unsigned
int
>
(
params
.
nhw
);
const
bool
is_valid
=
is_valid_nhw
&&
is_valid_c
;
// Read from SMEM.
const
int
offset
=
i
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
PackedStorageType
x_storage_local
[
PACKED_ELEMENTS_PER_LDG
];
read_from_smem
(
x_storage_local
,
&
smem_storage_packed
[
offset
],
threadIdx
.
x
);
float
x_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage_local
);
// Normalize and apply activation function
normalize
(
x_math
,
bias
,
scale
,
m1
);
if
(
USE_ADD_RELU
)
{
float
x1_math
[
ELEMENTS_PER_LDG
];
ldg_stream
(
x1_math
,
&
gmem_src1
[(
is_valid
?
idx
:
0
)
*
params
.
c
]);
add
(
x_math
,
x1_math
);
bitmask_t
relu_mask
;
#ifdef __HIP_PLATFORM_HCC__
int
lane_id
=
threadIdx
.
x
&
63
;
#else
int
lane_id
=
threadIdx
.
x
&
31
;
#endif
#pragma unroll
for
(
int
j
=
0
;
j
<
ELEMENTS_PER_LDG
;
++
j
)
{
#ifdef __HIP_PLATFORM_HCC__
bool
rectified
=
__hle
(
__float2half
(
x_math
[
j
]),
zero_h
);
#else
bool
rectified
=
x_math
[
j
]
<
0
;
#endif
bitmask_t
local_relu_mask
=
ballot
(
rectified
);
if
(
lane_id
==
j
)
{
relu_mask
=
local_relu_mask
;
}
if
(
rectified
)
{
x_math
[
j
]
=
0.0
F
;
}
}
if
(
is_valid_nhw
&&
(
lane_id
<
ELEMENTS_PER_LDG
))
{
gmem_relu_bitmask
[
idx
*
BITMASK_OFFSET
+
lane_id
]
=
relu_mask
;
}
}
else
if
(
USE_RELU
)
{
relu_activation
(
x_math
);
}
// Write back.
if
(
is_valid
)
{
stg_stream
(
&
gmem_dst
[
idx
*
params
.
c
],
x_math
);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads
();
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
struct
NhwcBatchNormBwdParams
{
// The input/output tensors.
uint16_t
*
gmem_src
,
*
gmem_dy
,
*
gmem_dst
,
*
gmem_dst1
;
// dscale/dbias
float
*
gmem_dscale
,
*
gmem_dbias
;
// The scale and bias.
float
*
gmem_scale
,
*
gmem_bias
;
// The mean/inv-var saved from fwd pass
float
*
gmem_saved_mean
,
*
gmem_saved_var
;
// ReLU bitmask
bitmask_t
*
gmem_relu_bitmask
;
// The dimensions.
int
nhw
,
c
;
// factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
float
svar_inv_count
;
// The buffer to do the reduction for dscale and dbias
float
*
gmem_sums
;
// The counters of retired CTAs.
int
*
gmem_retired_ctas
;
// outer loop count
int
outer_loops
;
// number of CTAs along .x dimension
int
c_blks
;
void
*
my_data
;
void
*
pair_datas
[
4
];
int
magic
;
int
sync_iters
;
float
wgrad_coeff
;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
relu_bwd
(
float
(
&
dy
)[
N
],
const
float
(
&
x
)[
N
],
const
float
(
&
mean_var_scale_bias
)[
N
],
const
float
(
&
var_scale
)[
N
],
bool
valid_data
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
N
;
++
j
)
{
float
y
=
(
x
[
j
]
*
var_scale
[
j
])
+
mean_var_scale_bias
[
j
];
if
((
y
<=
0.
f
)
&&
valid_data
)
{
dy
[
j
]
=
0.
f
;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
relu_bwd
(
float
(
&
dy
)[
N
],
const
float
(
&
y
)[
N
],
bool
valid_data
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
N
;
++
j
)
{
if
((
y
[
j
]
<=
0.
f
)
&&
valid_data
)
{
dy
[
j
]
=
0.
f
;
}
}
}
template
<
int
N
>
DEVICE_FUNCTION
void
relu_bwd
(
float
(
&
dy
)[
N
],
const
bool
(
&
rectified
)[
N
],
bool
valid_data
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
N
;
++
j
)
{
if
(
rectified
[
j
]
&&
valid_data
)
{
dy
[
j
]
=
0.
f
;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
relu_bwd_for_dx
(
float
(
&
dy
)[
N
],
const
float
(
&
x
)[
N
],
const
float
(
&
mean_var_scale_bias
)[
N
],
const
float
(
&
var_scale
)[
N
])
{
#pragma unroll
for
(
int
j
=
0
;
j
<
N
;
++
j
)
{
float
y
=
(
x
[
j
]
*
var_scale
[
j
])
+
mean_var_scale_bias
[
j
];
if
(
y
<=
0.
f
)
{
dy
[
j
]
=
0.
f
;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
relu_bwd_for_dx
(
float
(
&
dy
)[
N
],
const
float
(
&
y
)[
N
])
{
#pragma unroll
for
(
int
j
=
0
;
j
<
N
;
++
j
)
{
if
(
y
[
j
]
<=
0.
f
)
{
dy
[
j
]
=
0.
f
;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
bwd_update
(
float
(
&
dscale
)[
N
],
float
(
&
dbias
)[
N
],
const
float
(
&
dy
)[
N
],
const
float
(
&
x
)[
N
],
const
float
(
&
mean
)[
N
],
float
inv_count
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
N
;
++
j
)
{
float
delta0
=
dy
[
j
]
-
dbias
[
j
];
dbias
[
j
]
+=
delta0
*
inv_count
;
delta0
=
(
dy
[
j
]
*
(
x
[
j
]
-
mean
[
j
]))
-
dscale
[
j
];
dscale
[
j
]
+=
delta0
*
inv_count
;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
DEVICE_FUNCTION
void
bwd_dx
(
float
(
&
dx
)[
N
],
const
float
(
&
dy
)[
N
],
const
float
(
&
var
)[
N
],
const
float
(
&
x
)[
N
],
const
float
(
&
mean
)[
N
],
const
float
(
&
dscale
)[
N
],
const
float
(
&
dbias
)[
N
],
float
inv_count
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
N
;
++
j
)
{
float
tmp1
=
dy
[
j
]
-
(
dbias
[
j
]
*
inv_count
);
float
tmp2
=
dscale
[
j
]
*
inv_count
;
float
tmp3
=
x
[
j
]
-
mean
[
j
];
dx
[
j
]
=
var
[
j
]
*
(
tmp1
-
(
tmp2
*
tmp3
));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Storage
,
int
THREADS_PER_CTA
,
int
THREADS_PER_PIXEL
,
int
PIXELS_PER_THREAD_IN_REGISTERS
,
int
PIXELS_PER_THREAD_IN_SMEM
,
int
ELEMENTS_PER_LDG
,
int
USE_ONLINE_APPROACH
,
int
OUTER_LOOPS_
,
int
DESIRED_OCCUPANCY
>
__global__
__launch_bounds__
(
THREADS_PER_CTA
,
DESIRED_OCCUPANCY
)
void
nhwc_batch_norm_bwd
(
NhwcBatchNormBwdParams
params
)
{
// The number of pixels loaded in a single LDG.
const
int
PIXELS_PER_LDG
=
THREADS_PER_CTA
/
THREADS_PER_PIXEL
;
// The number of pixels computed per CTA stored in registers.
const
int
PIXELS_PER_CTA_IN_REGISTERS
=
PIXELS_PER_THREAD_IN_REGISTERS
*
PIXELS_PER_LDG
;
// The number of pixels computed per CTA stored in SMEM.
const
int
PIXELS_PER_CTA_IN_SMEM
=
PIXELS_PER_THREAD_IN_SMEM
*
PIXELS_PER_LDG
;
// The number of C elements per CTA.
const
int
C_ELEMENTS_PER_CTA
=
THREADS_PER_PIXEL
*
ELEMENTS_PER_LDG
;
// Shared memory to do CTA-wide parallel sums.
__shared__
float
smem
[
THREADS_PER_PIXEL
*
(
THREADS_PER_CTA
/
warpSize
)
*
ELEMENTS_PER_LDG
];
// The adapter for the storage.
typedef
PackedStorage
<
Storage
,
ELEMENTS_PER_LDG
>
PackedStorage_
;
// The data type for packed storage in SMEM.
typedef
typename
PackedStorage_
::
Type
PackedStorageType
;
// The number of elements in the packed storage.
const
int
PACKED_ELEMENTS_PER_LDG
=
PackedStorage_
::
PACKED_ELEMENTS_PER_LDG
;
// Registers to keep the data live for the persistent approach.
PackedStorageType
x_storage
[
PIXELS_PER_THREAD_IN_REGISTERS
][
PACKED_ELEMENTS_PER_LDG
];
PackedStorageType
dy_storage
[
PIXELS_PER_THREAD_IN_REGISTERS
][
PACKED_ELEMENTS_PER_LDG
];
// Shared memory buffer to store the extra pixels.
extern
__shared__
PackedStorageType
smem_storage_packed
[];
for
(
int
c_blk_index
=
blockIdx
.
y
;
c_blk_index
<
params
.
c_blks
;
c_blk_index
+=
gridDim
.
y
)
{
// The position in the NHW dimension where the CTA starts.
int
cta_nhw_regs
=
blockIdx
.
x
*
PIXELS_PER_CTA_IN_REGISTERS
;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int
cta_nhw_smem
=
blockIdx
.
x
*
PIXELS_PER_CTA_IN_SMEM
;
// Compute the NHW coordinate of the thread in the CTA.
const
int
thread_in_cta_nhw
=
threadIdx
.
x
/
THREADS_PER_PIXEL
;
// The position in the C dimension where the CTA starts.
const
int
cta_c
=
c_blk_index
*
C_ELEMENTS_PER_CTA
;
// Compute the C coordinate of the thread in the CTA.
const
int
thread_in_cta_c
=
threadIdx
.
x
%
THREADS_PER_PIXEL
;
// Compute the C coordinate of the thread.
const
int
thread_c
=
cta_c
+
thread_in_cta_c
*
ELEMENTS_PER_LDG
;
// Is the thread working on a valid C dimension?
const
int
is_valid_c
=
thread_c
<
params
.
c
;
// Registers to store the mean used for entire duration
float
mean
[
ELEMENTS_PER_LDG
];
zero_array
(
mean
);
if
(
is_valid_c
)
{
read_from_gmem
(
mean
,
params
.
gmem_saved_mean
,
thread_c
/
ELEMENTS_PER_LDG
);
}
// accumulation related registers
float
count
=
0.
f
,
dscale
[
ELEMENTS_PER_LDG
],
dbias
[
ELEMENTS_PER_LDG
];
zero_array
(
dscale
);
zero_array
(
dbias
);
// The number of elements loaded by this CTA.
int
cta_count
=
0
;
// The base pointers to load from.
const
uint16_t
*
gmem_src
=
&
params
.
gmem_src
[
thread_c
];
const
uint16_t
*
gmem_dy
=
&
params
.
gmem_dy
[
thread_c
];
// outer loops
int
OUTER_LOOPS
=
OUTER_LOOPS_
==
1
?
1
:
params
.
outer_loops
;
// Load the batch of elements. Compute sum across them
const
int
pixels_per_iteration
=
PIXELS_PER_CTA_IN_REGISTERS
*
gridDim
.
x
;
if
(
OUTER_LOOPS_
!=
1
)
{
// We cannot load everything to store persistently, so let's makes sure registers and
// smem are fully utilized
int
offset
=
params
.
nhw
-
pixels_per_iteration
*
OUTER_LOOPS
-
PIXELS_PER_CTA_IN_SMEM
*
gridDim
.
x
;
cta_nhw_regs
+=
offset
;
cta_nhw_smem
+=
offset
;
}
#pragma unroll 1
for
(
int
loop_i
=
0
;
loop_i
<
OUTER_LOOPS
;
++
loop_i
)
{
// The nhw position.
int
nhw_regs
=
cta_nhw_regs
+
loop_i
*
pixels_per_iteration
;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count
+=
max
(
0
,
min
(
PIXELS_PER_CTA_IN_REGISTERS
,
params
.
nhw
-
nhw_regs
));
// Read the elements from memory.
float
is_valid
[
PIXELS_PER_THREAD_IN_REGISTERS
];
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
nhw_regs
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
zero_array
(
x_storage
[
i
]);
zero_array
(
dy_storage
[
i
]);
is_valid
[
i
]
=
0.
f
;
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
if
(
loop_i
==
OUTER_LOOPS
-
1
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage
[
i
],
&
gmem_dy
[
idx
*
params
.
c
]);
}
else
{
ldg
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg
(
dy_storage
[
i
],
&
gmem_dy
[
idx
*
params
.
c
]);
}
is_valid
[
i
]
=
1.
f
;
}
}
// Do the math.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
// Convert to float and update
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
to_float
(
dy_math
,
dy_storage
[
i
]);
// Update the count.
count
+=
is_valid
[
i
];
// Invert the count.
float
inv_count
=
is_valid
[
i
]
?
1.
f
/
count
:
0.
f
;
bwd_update
(
dscale
,
dbias
,
dy_math
,
x_math
,
mean
,
inv_count
);
}
}
// The elements to load and store in SMEM.
int
smem_nhw
=
OUTER_LOOPS
*
pixels_per_iteration
+
cta_nhw_smem
;
// Load elements from SMEM, update the CTA count.
int
pixels_in_smem
=
min
(
PIXELS_PER_CTA_IN_SMEM
,
params
.
nhw
-
smem_nhw
);
if
(
pixels_in_smem
>
0
)
{
cta_count
+=
pixels_in_smem
;
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_SMEM
;
++
i
)
{
const
int
idx
=
smem_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
bool
is_pixel_valid
=
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
);
PackedStorageType
x_storage_local
[
PACKED_ELEMENTS_PER_LDG
],
dy_storage_local
[
PACKED_ELEMENTS_PER_LDG
];
zero_array
(
x_storage_local
);
zero_array
(
dy_storage_local
);
if
(
is_pixel_valid
)
{
ldg_stream
(
x_storage_local
,
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage_local
,
&
gmem_dy
[
idx
*
params
.
c
]);
}
// The offset to store in SMEM.
int
offset
=
i
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
// Store in SMEM.
write_to_smem
(
&
smem_storage_packed
[
offset
],
threadIdx
.
x
,
x_storage_local
);
offset
+=
PIXELS_PER_THREAD_IN_SMEM
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
write_to_smem
(
&
smem_storage_packed
[
offset
],
threadIdx
.
x
,
dy_storage_local
);
// Update the count.
count
+=
is_pixel_valid
;
// Invert the count.
float
inv_count
=
is_pixel_valid
?
1.
f
/
count
:
0.
f
;
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage_local
);
to_float
(
dy_math
,
dy_storage_local
);
bwd_update
(
dscale
,
dbias
,
dy_math
,
x_math
,
mean
,
inv_count
);
}
}
// We scale the mean by the number of elements. It brings more stability.
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
dbias
[
i
]
*=
count
;
dscale
[
i
]
*=
count
;
}
// dscale parallel sum
#ifdef __HIP_PLATFORM_HCC__
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
template
dispatch
<
THREADS_PER_CTA
>(
#else
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
#endif
smem
,
dscale
,
thread_in_cta_nhw
);
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dscale
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// dbias parallel sum
#ifdef __HIP_PLATFORM_HCC__
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
template
dispatch
<
THREADS_PER_CTA
>(
#else
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
#endif
smem
,
dbias
,
thread_in_cta_nhw
);
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dbias
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// The workspace in global memory is distributed across the different CTA.
int
gmem_sums_offset
=
c_blk_index
*
gridDim
.
x
*
C_ELEMENTS_PER_CTA
*
2
;
// Write the data for the CTA to global memory.
float
*
gmem_sums
=
&
params
.
gmem_sums
[
gmem_sums_offset
];
if
(
threadIdx
.
x
<
THREADS_PER_PIXEL
)
{
const
int
idx
=
blockIdx
.
x
*
THREADS_PER_PIXEL
+
threadIdx
.
x
;
write_to_gmem
(
&
gmem_sums
[
0
],
idx
,
dscale
);
write_to_gmem
(
&
gmem_sums
[
C_ELEMENTS_PER_CTA
*
gridDim
.
x
],
idx
,
dbias
);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int
*
gmem_retired_ctas
=
&
params
.
gmem_retired_ctas
[
c_blk_index
%
(
2
*
gridDim
.
y
)];
inter_block_sync
(
gmem_retired_ctas
,
gridDim
.
x
,
blockIdx
.
x
==
0
);
// Reset the accumulators for global summation
zero_array
(
dscale
);
zero_array
(
dbias
);
// Build the global accumulation
#pragma unroll 1
for
(
int
idx
=
threadIdx
.
x
;
idx
<
THREADS_PER_PIXEL
*
gridDim
.
x
;
idx
+=
THREADS_PER_CTA
)
{
float
tmp1
[
ELEMENTS_PER_LDG
],
tmp2
[
ELEMENTS_PER_LDG
];
read_from_gmem
(
tmp1
,
gmem_sums
,
idx
);
read_from_gmem
(
tmp2
,
gmem_sums
+
C_ELEMENTS_PER_CTA
*
gridDim
.
x
,
idx
);
#pragma unroll
for
(
int
i
=
0
;
i
<
ELEMENTS_PER_LDG
;
++
i
)
{
dscale
[
i
]
+=
tmp1
[
i
];
dbias
[
i
]
+=
tmp2
[
i
];
}
}
// dscale parallel sum
#ifndef __HIP_PLATFORM_HCC__
if
(
params
.
sync_iters
>
0
)
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatchX
<
THREADS_PER_CTA
>
(
smem
,
dscale
,
thread_in_cta_nhw
,
params
.
my_data
,
params
.
pair_datas
,
4
*
c_blk_index
+
1
,
params
.
magic
,
params
.
sync_iters
);
}
else
{
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
template
dispatch
<
THREADS_PER_CTA
>(
#else
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
#endif
smem
,
dscale
,
thread_in_cta_nhw
);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dscale
,
smem
,
thread_in_cta_c
);
__syncthreads
();
// dbias parallel sum
#ifndef __HIP_PLATFORM_HCC__
if
(
params
.
sync_iters
>
0
)
{
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatchX
<
THREADS_PER_CTA
>
(
smem
,
dbias
,
thread_in_cta_nhw
,
params
.
my_data
,
params
.
pair_datas
,
4
*
c_blk_index
+
0
,
params
.
magic
,
params
.
sync_iters
);
}
else
{
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
template
dispatch
<
THREADS_PER_CTA
>(
#else
ParallelSums
<
THREADS_PER_PIXEL
,
ELEMENTS_PER_LDG
>::
dispatch
<
THREADS_PER_CTA
>
(
#endif
smem
,
dbias
,
thread_in_cta_nhw
);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads
();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem
(
dbias
,
smem
,
thread_in_cta_c
);
// inv-var
float
var
[
ELEMENTS_PER_LDG
];
zero_array
(
var
);
if
(
is_valid_c
)
{
read_from_gmem
(
var
,
params
.
gmem_saved_var
,
thread_c
/
ELEMENTS_PER_LDG
);
}
// Normalize the dscale.
multiply
(
dscale
,
var
);
// store dscale/dbias
bool
is_valid_for_saving
=
is_valid_c
&&
blockIdx
.
x
==
0
&&
thread_in_cta_nhw
==
0
;
if
(
is_valid_for_saving
)
{
if
(
params
.
sync_iters
>
0
)
{
scaled_write_to_gmem
(
params
.
gmem_dscale
,
thread_c
/
ELEMENTS_PER_LDG
,
dscale
,
params
.
wgrad_coeff
);
scaled_write_to_gmem
(
params
.
gmem_dbias
,
thread_c
/
ELEMENTS_PER_LDG
,
dbias
,
params
.
wgrad_coeff
);
}
else
{
write_to_gmem
(
params
.
gmem_dscale
,
thread_c
/
ELEMENTS_PER_LDG
,
dscale
);
write_to_gmem
(
params
.
gmem_dbias
,
thread_c
/
ELEMENTS_PER_LDG
,
dbias
);
}
}
// scale
float
scale
[
ELEMENTS_PER_LDG
];
zero_array
(
scale
);
if
(
is_valid_c
)
{
read_from_gmem
(
scale
,
params
.
gmem_scale
,
thread_c
/
ELEMENTS_PER_LDG
);
}
// Further normalize the dscale to be used in dx calculation
multiply
(
dscale
,
var
);
// scale the inv-var as well, afterwards
multiply
(
var
,
scale
);
// inverse count
float
inv_count
=
params
.
svar_inv_count
;
// The base pointer to write to.
uint16_t
*
const
gmem_dst
=
&
params
.
gmem_dst
[
thread_c
];
// Store the elements in registers.
#pragma unroll 1
for
(
int
loop_i
=
OUTER_LOOPS
-
1
;
loop_i
>=
0
;
--
loop_i
)
{
// The value for nhw.
int
out_nhw
=
cta_nhw_regs
+
loop_i
*
pixels_per_iteration
;
// Normalize the elements and write to memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
// Convert to float.
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage
[
i
]);
to_float
(
dy_math
,
dy_storage
[
i
]);
float
dx
[
ELEMENTS_PER_LDG
];
bwd_dx
(
dx
,
dy_math
,
var
,
x_math
,
mean
,
dscale
,
dbias
,
inv_count
);
// Write back.
const
int
idx
=
out_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
stg_stream
(
&
gmem_dst
[
idx
*
params
.
c
],
dx
);
}
}
// The next value of nhw.
out_nhw
-=
pixels_per_iteration
;
// Read the next elements from memory.
#pragma unroll
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_REGISTERS
;
++
i
)
{
const
int
idx
=
out_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
if
(((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
)
{
ldg_stream
(
x_storage
[
i
],
&
gmem_src
[
idx
*
params
.
c
]);
ldg_stream
(
dy_storage
[
i
],
&
gmem_dy
[
idx
*
params
.
c
]);
}
}
}
// Normalize the elements from SMEM and write them out.
if
(
pixels_in_smem
>
0
)
{
for
(
int
i
=
0
;
i
<
PIXELS_PER_THREAD_IN_SMEM
;
++
i
)
{
const
int
idx
=
smem_nhw
+
thread_in_cta_nhw
+
i
*
PIXELS_PER_LDG
;
const
bool
is_valid
=
((
unsigned
int
)
idx
<
(
unsigned
int
)
params
.
nhw
)
&&
is_valid_c
;
if
(
is_valid
)
{
// Read from SMEM.
int
offset
=
i
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
PackedStorageType
x_storage_local
[
PACKED_ELEMENTS_PER_LDG
],
dy_storage_local
[
PACKED_ELEMENTS_PER_LDG
];
read_from_smem
(
x_storage_local
,
&
smem_storage_packed
[
offset
],
threadIdx
.
x
);
offset
+=
PIXELS_PER_THREAD_IN_SMEM
*
THREADS_PER_CTA
*
PACKED_ELEMENTS_PER_LDG
;
read_from_smem
(
dy_storage_local
,
&
smem_storage_packed
[
offset
],
threadIdx
.
x
);
float
x_math
[
ELEMENTS_PER_LDG
],
dy_math
[
ELEMENTS_PER_LDG
];
to_float
(
x_math
,
x_storage_local
);
to_float
(
dy_math
,
dy_storage_local
);
float
dx
[
ELEMENTS_PER_LDG
];
bwd_dx
(
dx
,
dy_math
,
var
,
x_math
,
mean
,
dscale
,
dbias
,
inv_count
);
// Write back.
stg_stream
(
&
gmem_dst
[
idx
*
params
.
c
],
dx
);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads
();
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Storage,
          int THREADS_PER_CTA,
          int THREADS_PER_PIXEL,
          int PIXELS_PER_THREAD_IN_REGISTERS,
          int PIXELS_PER_THREAD_IN_SMEM,
          int ELEMENTS_PER_LDG,
          int USE_ONLINE_APPROACH,
          int OUTER_LOOPS_,
          int DESIRED_OCCUPANCY>
__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
    void nhwc_batch_norm_bwd_relu(NhwcBatchNormBwdParams params) {
    // The number of pixels loaded in a single LDG.
    const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
    // The number of pixels computed per CTA stored in registers.
    const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
    // The number of pixels computed per CTA stored in SMEM.
    const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM * PIXELS_PER_LDG;
    // The number of C elements per CTA.
    const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG;

    // Shared memory to do CTA-wide parallel sums.
    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG];

    // The adapter for the storage.
    typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
    // The data type for packed storage in SMEM.
    typedef typename PackedStorage_::Type PackedStorageType;
    // The number of elements in the packed storage.
    const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;
    // Registers to keep the data live for the persistent approach.
    PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
    PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];

    // Shared memory buffer to store the extra pixels.
    extern __shared__ PackedStorageType smem_storage_packed[];

    for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
        // The position in the NHW dimension where the CTA starts.
        int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
        // The position in the NHW dimension where the CTA starts for the portion in SMEM.
        int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
        // Compute the NHW coordinate of the thread in the CTA.
        const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;

        // The position in the C dimension where the CTA starts.
        const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
        // Compute the C coordinate of the thread in the CTA.
        const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
        // Compute the C coordinate of the thread.
        const int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG;
        // Is the thread working on a valid C dimension?
        const int is_valid_c = thread_c < params.c;

        // Registers to store the mean/var/scale/bias used for the entire duration
        // Register usage optimizations:
        // 1. Can combine bias - (mean * var * scale) into a single register
        // 2. Can combine var * scale into a single register
        float varscale[ELEMENTS_PER_LDG];
        zero_array(varscale);
        if (is_valid_c) {
            read_from_gmem(varscale, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);
        }
        float tmp[ELEMENTS_PER_LDG];
        zero_array(tmp);
        if (is_valid_c) {
            read_from_gmem(tmp, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);
        }
        multiply(varscale, tmp);
        float mean[ELEMENTS_PER_LDG];
        zero_array(mean);
        if (is_valid_c) {
            read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);
        }
        zero_array(tmp);
        if (is_valid_c) {
            read_from_gmem(tmp, params.gmem_bias, thread_c/ELEMENTS_PER_LDG);
        }
        float mean_var_scale_bias[ELEMENTS_PER_LDG];
        #pragma unroll
        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
            mean_var_scale_bias[i] = tmp[i] - (mean[i] * varscale[i]);
        }

        // accumulation related registers
        float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];
        zero_array(dscale);
        zero_array(dbias);

        // The number of elements loaded by this CTA.
        int cta_count = 0;
        // The base pointers to load from.
        const uint16_t *gmem_src = &params.gmem_src[thread_c];
        const uint16_t *gmem_dy = &params.gmem_dy[thread_c];

        // outer loops
        int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops;
        // Load the batch of elements. Compute sum across them
        const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS * gridDim.x;

        if (OUTER_LOOPS_ != 1) {
            // We cannot load everything to store persistently, so let's makes sure registers and
            // smem are fully utilized
            int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS -
                PIXELS_PER_CTA_IN_SMEM * gridDim.x;
            cta_nhw_regs += offset;
            cta_nhw_smem += offset;
        }

        #pragma unroll 1
        for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
            // The nhw position.
            int nhw_regs = cta_nhw_regs + loop_i * pixels_per_iteration;
            // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
            cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw - nhw_regs));

            // Read the elements from memory.
            float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
            #pragma unroll
            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
                const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG;
                zero_array(x_storage[i]);
                zero_array(dy_storage[i]);
                is_valid[i] = 0.f;
                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
                    if (loop_i == OUTER_LOOPS - 1) {
                        ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
                        ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
                    } else {
                        ldg(x_storage[i], &gmem_src[idx*params.c]);
                        ldg(dy_storage[i], &gmem_dy[idx*params.c]);
                    }
                    is_valid[i] = 1.f;
                }
            }

            // Do the math.
            #pragma unroll
            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
                // Convert to float and update
                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
                to_float(x_math, x_storage[i]);
                to_float(dy_math, dy_storage[i]);

                // Update the count.
                count += is_valid[i];
                // Invert the count.
                float inv_count = is_valid[i] ? 1.f / count : 0.f;

                relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_valid[i]);
                bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
            }
        }

        // The elements to load and store in SMEM.
        int smem_nhw = OUTER_LOOPS * pixels_per_iteration + cta_nhw_smem;
        // Load elements from SMEM, update the CTA count.
        int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw - smem_nhw);
        if (pixels_in_smem > 0) {
            cta_count += pixels_in_smem;
            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
                const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
                bool is_pixel_valid = (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c);
                PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
                    dy_storage_local[PACKED_ELEMENTS_PER_LDG];
                zero_array(x_storage_local);
                zero_array(dy_storage_local);
                if (is_pixel_valid) {
                    ldg_stream(x_storage_local, &gmem_src[idx*params.c]);
                    ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);
                }

                // The offset to store in SMEM.
                int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
                // Store in SMEM.
                write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
                offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
                write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);

                // Update the count.
                count += is_pixel_valid;
                // Invert the count.
                float inv_count = is_pixel_valid ? 1.f / count : 0.f;

                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
                to_float(x_math, x_storage_local);
                to_float(dy_math, dy_storage_local);

                relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_pixel_valid);
                bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
            }
        }

        // We scale the mean by the number of elements. It brings more stability.
        #pragma unroll
        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
            dbias[i] *= count;
            dscale[i] *= count;
        }

        // dscale parallel sum
#ifdef __HIP_PLATFORM_HCC__
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
            smem, dscale, thread_in_cta_nhw);
        __syncthreads();
        // The values in shared memory correspond to the CTA-wide sums.
        read_from_smem(dscale, smem, thread_in_cta_c);
        __syncthreads();

        // dbias parallel sum
#ifdef __HIP_PLATFORM_HCC__
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
            smem, dbias, thread_in_cta_nhw);
        __syncthreads();
        // The values in shared memory correspond to the CTA-wide sums.
        read_from_smem(dbias, smem, thread_in_cta_c);
        __syncthreads();

        // The workspace in global memory is distributed across the different CTA.
        int gmem_sums_offset = c_blk_index * gridDim.x * C_ELEMENTS_PER_CTA * 2;
        // Write the data for the CTA to global memory.
        float *gmem_sums = &params.gmem_sums[gmem_sums_offset];
        if (threadIdx.x < THREADS_PER_PIXEL) {
            const int idx = blockIdx.x * THREADS_PER_PIXEL + threadIdx.x;
            write_to_gmem(&gmem_sums[0], idx, dscale);
            write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);
        }

        // The counters to count how many CTAs have retired at this point.
        // A given cta uses the same counter every other time through the outer loop.
        int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
        inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);

        // Reset the accumulators for global summation
        zero_array(dscale);
        zero_array(dbias);

        // Build the global accumulation
        #pragma unroll 1
        for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
            float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];
            read_from_gmem(tmp1, gmem_sums, idx);
            read_from_gmem(tmp2, gmem_sums + C_ELEMENTS_PER_CTA*gridDim.x, idx);

            #pragma unroll
            for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
                dscale[i] += tmp1[i];
                dbias[i] += tmp2[i];
            }
        }

        // dscale parallel sum
#ifndef __HIP_PLATFORM_HCC__
        if (params.sync_iters > 0) {
            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
                smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas,
                4*c_blk_index+1, params.magic, params.sync_iters);
        } else {
#endif
#ifdef __HIP_PLATFORM_HCC__
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
            smem, dscale, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
        }
#endif
        __syncthreads();
        // The values in shared memory correspond to the CTA-wide sums.
        read_from_smem(dscale, smem, thread_in_cta_c);
        __syncthreads();

        // dbias parallel sum
#ifndef __HIP_PLATFORM_HCC__
        if (params.sync_iters > 0) {
            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
                smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas,
                4*c_blk_index+0, params.magic, params.sync_iters);
        } else {
#endif
#ifdef __HIP_PLATFORM_HCC__
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
            smem, dbias, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
        }
#endif
        __syncthreads();
        // The values in shared memory correspond to the CTA-wide sums.
        read_from_smem(dbias, smem, thread_in_cta_c);

        // Normalize the dscale.
        float var[ELEMENTS_PER_LDG];
        zero_array(var);
        if (is_valid_c) {
            read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);
        }
        multiply(dscale, var);

        // store dscale/dbias
        bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
        if (is_valid_for_saving) {
            if (params.sync_iters > 0) {
                scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);
                scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);
            } else {
                write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);
                write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);
            }
        }

        // Further normalize the dscale to be used in dx calculation
        float scale[ELEMENTS_PER_LDG];
        zero_array(scale);
        if (is_valid_c) {
            read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);
        }
        multiply(dscale, var);
        // scale the inv-var as well, afterwards
        multiply(var, scale);

        // inverse count
        float inv_count = params.svar_inv_count;

        // The base pointer to write to.
        uint16_t *const gmem_dst = &params.gmem_dst[thread_c];

        // Store the elements in registers.
        #pragma unroll 1
        for (int loop_i = OUTER_LOOPS - 1; loop_i >= 0; --loop_i) {
            // The value for nhw.
            int out_nhw = cta_nhw_regs + loop_i * pixels_per_iteration;

            // Normalize the elements and write to memory.
            #pragma unroll
            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
                // Convert to float.
                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
                to_float(x_math, x_storage[i]);
                to_float(dy_math, dy_storage[i]);

                relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var);

                float dx[ELEMENTS_PER_LDG];
                bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);

                // Write back.
                const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
                    stg_stream(&gmem_dst[idx*params.c], dx);
                }
            }

            // The next value of nhw.
            out_nhw -= pixels_per_iteration;

            // Read the next elements from memory.
            #pragma unroll
            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
                const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
                    ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
                    ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
                }
            }
        }

        // Normalize the elements from SMEM and write them out.
        if (pixels_in_smem > 0) {
            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
                const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
                const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
                if (is_valid) {
                    // Read from SMEM.
                    int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
                    PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
                        dy_storage_local[PACKED_ELEMENTS_PER_LDG];
                    read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
                    offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
                    read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);

                    float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
                    to_float(x_math, x_storage_local);
                    to_float(dy_math, dy_storage_local);

                    relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var);

                    float dx[ELEMENTS_PER_LDG];
                    bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);

                    // Write back.
                    stg_stream(&gmem_dst[idx*params.c], dx);
                }
            }
        }

        // We're about to start on the next c-blk. Needed?
        __syncthreads();
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Storage,
          int THREADS_PER_CTA,
          int THREADS_PER_PIXEL,
          int PIXELS_PER_THREAD_IN_REGISTERS,
          int PIXELS_PER_THREAD_IN_SMEM,
          int ELEMENTS_PER_LDG,
          int USE_ONLINE_APPROACH,
          int OUTER_LOOPS_,
          int DESIRED_OCCUPANCY>
__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
    void nhwc_batch_norm_bwd_add_relu(NhwcBatchNormBwdParams params) {
    // The number of pixels loaded in a single LDG.
    const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
    // The number of pixels computed per CTA stored in registers.
    const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
    // The number of pixels computed per CTA stored in SMEM.
    const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM * PIXELS_PER_LDG;
    // The number of C elements per CTA.
    const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG;

    // Shared memory to do CTA-wide parallel sums.
    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG];

    // The adapter for the storage.
    typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
    // The data type for packed storage in SMEM.
    typedef typename PackedStorage_::Type PackedStorageType;
    // The number of elements in the packed storage.
    const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;
    // Registers to keep the data live for the persistent approach.
    PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
    PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];

    // Shared memory buffer to store the extra pixels.
    extern __shared__ PackedStorageType smem_storage_packed[];

    for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
        // The position in the NHW dimension where the CTA starts.
        int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
        // The position in the NHW dimension where the CTA starts for the portion in SMEM.
        int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
        // Compute the NHW coordinate of the thread in the CTA.
        const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;

        // The position in the C dimension where the CTA starts.
        const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
        // Compute the C coordinate of the thread in the CTA.
        const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
        // Compute the C coordinate of the thread.
        const int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG;
        // Is the thread working on a valid C dimension?
        const int is_valid_c = thread_c < params.c;

        float mean[ELEMENTS_PER_LDG];
        zero_array(mean);
        if (is_valid_c) {
            read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);
        }

        // accumulation related registers
        float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];
        zero_array(dscale);
        zero_array(dbias);

        // The number of elements loaded by this CTA.
        int cta_count = 0;
        // The base pointers to load from.
        const uint16_t *gmem_src = &params.gmem_src[thread_c];
        const uint16_t *gmem_dy = &params.gmem_dy[thread_c];
        uint16_t *gmem_dst1 = &params.gmem_dst1[thread_c];

        // outer loops
        int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops;
        // Load the batch of elements. Compute sum across them
        const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS * gridDim.x;

        if (OUTER_LOOPS_ != 1) {
            // We cannot load everything to store persistently, so let's makes sure registers and
            // smem are fully utilized, offset is evenly divisible by 32
            int offset = (pixels_per_iteration * OUTER_LOOPS +
                PIXELS_PER_CTA_IN_SMEM * gridDim.x - params.nhw) & ~31;
            cta_nhw_regs -= offset;
            cta_nhw_smem -= offset;
        }

        const bitmask_t *const gmem_relu_bitmask = params.gmem_relu_bitmask +
#ifdef __HIP_PLATFORM_HCC__
            ((params.nhw + 3) & ~3) * 2 * c_blk_index;
#else
            ((params.nhw + 31) & ~31) * 2 * c_blk_index;
#endif

        #pragma unroll 1
        for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
            // The nhw position.
            int nhw_regs = cta_nhw_regs + loop_i * pixels_per_iteration;
            // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
            cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw - nhw_regs));

#ifdef __HIP_PLATFORM_HCC__
            int lane_id = threadIdx.x & 63;
#else
            int lane_id = threadIdx.x & 31;
#endif

            // Read the elements from memory.
            float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
            bitmask_t relu_mask[PIXELS_PER_THREAD_IN_REGISTERS];
            #pragma unroll
            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
                const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG;
                zero_array(x_storage[i]);
                zero_array(dy_storage[i]);
                is_valid[i] = 0.f;
                const bool is_valid_nhw =
                    static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);
                if (is_valid_nhw) {
                    if (is_valid_c) {
                        if (loop_i == OUTER_LOOPS - 1) {
                            ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
                            ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
                        } else {
                            ldg(x_storage[i], &gmem_src[idx*params.c]);
                            ldg(dy_storage[i], &gmem_dy[idx*params.c]);
                        }
                        is_valid[i] = 1.f;
                    }
                    if (lane_id < ELEMENTS_PER_LDG) {
                        relu_mask[i] = gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id];
                    }
                }
            }

            // Do the math.
            #pragma unroll
            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
                const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG;
                // Convert to float and update
                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
                bool rectified[ELEMENTS_PER_LDG];
                #pragma unroll
                for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
                    rectified[j] = ((shfl_sync(relu_mask[i], j) & (ONE_BITMASK << lane_id)) != 0);
                }
                to_float(x_math, x_storage[i]);
                to_float(dy_math, dy_storage[i]);

                // Update the count.
                count += is_valid[i];
                // Invert the count.
                float inv_count = is_valid[i] ? 1.f / count : 0.f;

                relu_bwd(dy_math, rectified, is_valid[i]);
                bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);

                // Lastly we need 'dy' only for BN, so store the 'relu-dgrad'ed version
                from_float(dy_storage[i], dy_math);

                // dZ for elementwise add
                if (is_valid[i]) {
                    if (loop_i == OUTER_LOOPS - 1) {
                        stg_stream(&gmem_dst1[idx*params.c], dy_storage[i]);
                    } else {
                        stg(&gmem_dst1[idx*params.c], dy_storage[i]);
                    }
                }
            }
        }

        // The elements to load and store in SMEM.
        int smem_nhw = OUTER_LOOPS * pixels_per_iteration + cta_nhw_smem;
        // Load elements from SMEM, update the CTA count.
        int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw - smem_nhw);
        if (pixels_in_smem > 0) {
            cta_count += pixels_in_smem;
            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
                const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
                const bool is_pixel_valid_nhw =
                    static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);
                const bool is_pixel_valid = is_pixel_valid_nhw && is_valid_c;
                PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
                    dy_storage_local[PACKED_ELEMENTS_PER_LDG];
                bitmask_t relu_mask;
#ifdef __HIP_PLATFORM_HCC__
                int lane_id = threadIdx.x & 63;
#else
                int lane_id = threadIdx.x & 31;
#endif
                zero_array(x_storage_local);
                zero_array(dy_storage_local);
                if (is_pixel_valid_nhw) {
                    if (is_valid_c) {
                        ldg_stream(x_storage_local, &gmem_src[idx*params.c]);
                        ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);
                    }
                    if (lane_id < ELEMENTS_PER_LDG) {
                        relu_mask = gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id];
                    }
                }
                bool rectified[ELEMENTS_PER_LDG];
                #pragma unroll
                for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
                    rectified[j] = ((shfl_sync(relu_mask, j) & (ONE_BITMASK << lane_id)) != 0);
                }

                // The offset to store in SMEM.
                int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
                // Store in SMEM.
                write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
                offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;

                // Update the count.
                count += is_pixel_valid;
                // Invert the count.
                float inv_count = is_pixel_valid ? 1.f / count : 0.f;

                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
                to_float(x_math, x_storage_local);
                to_float(dy_math, dy_storage_local);

                relu_bwd(dy_math, rectified, is_pixel_valid);
                bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
                from_float(dy_storage_local, dy_math);

                // dZ for elementwise add
                if (is_pixel_valid) {
                    stg_stream(&gmem_dst1[idx*params.c], dy_storage_local);
                }
                // only store the 'relu-dgrad'ed version!
                write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);
            }
        }

        // We scale the mean by the number of elements. It brings more stability.
        #pragma unroll
        for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
            dbias[i] *= count;
            dscale[i] *= count;
        }

        // dscale parallel sum
#ifdef __HIP_PLATFORM_HCC__
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
            smem, dscale, thread_in_cta_nhw);
        __syncthreads();
        // The values in shared memory correspond to the CTA-wide sums.
        read_from_smem(dscale, smem, thread_in_cta_c);
        __syncthreads();

        // dbias parallel sum
#ifdef __HIP_PLATFORM_HCC__
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
            smem, dbias, thread_in_cta_nhw);
        __syncthreads();
        // The values in shared memory correspond to the CTA-wide sums.
        read_from_smem(dbias, smem, thread_in_cta_c);
        __syncthreads();

        // The workspace in global memory is distributed across the different CTA.
        int gmem_sums_offset = c_blk_index * gridDim.x * C_ELEMENTS_PER_CTA * 2;
        // Write the data for the CTA to global memory.
        float *gmem_sums = &params.gmem_sums[gmem_sums_offset];
        if (threadIdx.x < THREADS_PER_PIXEL) {
            const int idx = blockIdx.x * THREADS_PER_PIXEL + threadIdx.x;
            write_to_gmem(&gmem_sums[0], idx, dscale);
            write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);
        }

        // The counters to count how many CTAs have retired at this point.
        // A given cta uses the same counter every other time through the outer loop.
        int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
        inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);

        // Reset the accumulators for global summation
        zero_array(dscale);
        zero_array(dbias);

        // Build the global accumulation
        #pragma unroll 1
        for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
            float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];
            read_from_gmem(tmp1, gmem_sums, idx);
            read_from_gmem(tmp2, gmem_sums + C_ELEMENTS_PER_CTA*gridDim.x, idx);

            #pragma unroll
            for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
                dscale[i] += tmp1[i];
                dbias[i] += tmp2[i];
            }
        }

        // dscale parallel sum
#ifndef __HIP_PLATFORM_HCC__
        if (params.sync_iters > 0) {
            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
                smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas,
                4*c_blk_index+1, params.magic, params.sync_iters);
        } else {
#endif
#ifdef __HIP_PLATFORM_HCC__
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
            smem, dscale, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
        }
#endif
        __syncthreads();
        // The values in shared memory correspond to the CTA-wide sums.
        read_from_smem(dscale, smem, thread_in_cta_c);
        __syncthreads();

        // dbias parallel sum
#ifndef __HIP_PLATFORM_HCC__
        if (params.sync_iters > 0) {
            ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
                smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas,
                4*c_blk_index+0, params.magic, params.sync_iters);
        } else {
#endif
#ifdef __HIP_PLATFORM_HCC__
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
        ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
            smem, dbias, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
        }
#endif
        __syncthreads();
        // The values in shared memory correspond to the CTA-wide sums.
        read_from_smem(dbias, smem, thread_in_cta_c);

        // Normalize the dscale.
        float var[ELEMENTS_PER_LDG];
        zero_array(var);
        if (is_valid_c) {
            read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);
        }
        multiply(dscale, var);

        // store dscale/dbias
        bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
        if (is_valid_for_saving) {
            if (params.sync_iters > 0) {
                scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);
                scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);
            } else {
                write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);
                write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);
            }
        }

        // Further normalize the dscale to be used in dx calculation
        float scale[ELEMENTS_PER_LDG];
        zero_array(scale);
        if (is_valid_c) {
            read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);
        }
        multiply(dscale, var);
        // scale the inv-var as well, afterwards
        multiply(var, scale);

        // inverse count
        float inv_count = params.svar_inv_count;

        // The base pointer to write to.
        uint16_t *const gmem_dst = &params.gmem_dst[thread_c];

        // Store the elements in registers.
        #pragma unroll 1
        for (int loop_i = OUTER_LOOPS - 1; loop_i >= 0; --loop_i) {
            // The value for nhw.
            int out_nhw = cta_nhw_regs + loop_i * pixels_per_iteration;

            // Normalize the elements and write to memory.
            #pragma unroll
            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
                const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
                const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
                // Convert to float.
                float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
                to_float(x_math, x_storage[i]);
                to_float(dy_math, dy_storage[i]);

                float dx[ELEMENTS_PER_LDG];
                bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);

                // Write back.
                if (is_valid) {
                    stg_stream(&gmem_dst[idx*params.c], dx);
                }
            }

            // The next value of nhw.
            out_nhw -= pixels_per_iteration;

            // Read the next elements from memory.
            #pragma unroll
            for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
                const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
                float y[ELEMENTS_PER_LDG];
                zero_array(y);
                if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
                    ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
                    ldg_stream(dy_storage[i], &gmem_dst1[idx*params.c]);
                }
            }
        }

        // Normalize the elements from SMEM and write them out.
        if (pixels_in_smem > 0) {
            for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
                const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
                const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
                if (is_valid) {
                    // Read from SMEM.
                    int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
                    PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
                        dy_storage_local[PACKED_ELEMENTS_PER_LDG];
                    read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
                    offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
                    read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);

                    float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
                    to_float(x_math, x_storage_local);
                    to_float(dy_math, dy_storage_local);

                    float dx[ELEMENTS_PER_LDG];
                    bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);

                    // Write back.
                    stg_stream(&gmem_dst[idx*params.c], dx);
                }
            }
        }

        // We're about to start on the next c-blk. Needed?
        __syncthreads();
    }
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
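
The bwd_add_relu kernel above recovers the forward ReLU decision from a packed bitmask: each mask word for a pixel is held by one lane, broadcast across the warp with shfl_sync, and every lane then tests its own bit. The following standalone CUDA sketch (hypothetical kernel and buffer names, assuming the 32-lane non-HIP path) only illustrates that broadcast-and-test pattern; it is not part of the deleted file.

// decode_mask_demo: lane 0 of each warp owns a 32-bit mask word; broadcast it
// and let every lane test its own bit, mirroring the
// shfl_sync(relu_mask, j) & (ONE_BITMASK << lane_id) pattern above.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void decode_mask_demo(const unsigned int *mask_words, int *rectified) {
    const int lane_id = threadIdx.x & 31;
    const int warp_id = threadIdx.x >> 5;
    // Broadcast the warp's mask word from lane 0 to all lanes.
    unsigned int word = __shfl_sync(0xffffffffu, mask_words[warp_id], 0);
    // Each lane checks whether its own bit was set (activation was positive).
    rectified[threadIdx.x] = (word & (1u << lane_id)) != 0;
}

int main() {
    unsigned int h_mask = 0xA5A5A5A5u;
    unsigned int *d_mask; int *d_out;
    cudaMalloc(&d_mask, sizeof(unsigned int));
    cudaMalloc(&d_out, 32 * sizeof(int));
    cudaMemcpy(d_mask, &h_mask, sizeof(unsigned int), cudaMemcpyHostToDevice);
    decode_mask_demo<<<1, 32>>>(d_mask, d_out);
    int h_out[32];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    for (int i = 0; i < 32; ++i) printf("%d", h_out[i]);
    printf("\n");
    cudaFree(d_mask);
    cudaFree(d_out);
    return 0;
}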
apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp deleted 100644 → 0 View file @ 2a4864d5
#include <torch/torch.h>
#include <vector>
#include <cstdint>
void index_mul_2d_float_foward_cuda(at::Tensor &out, const at::Tensor &in1,
                                    const at::Tensor &in2, const at::Tensor &idx1);

void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1, at::Tensor &grad_in2,
                                      const at::Tensor &grad_out, const at::Tensor &in1,
                                      const at::Tensor &in2, const at::Tensor &idx1);

void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out, at::Tensor &grad_in1,
                                               at::Tensor &grad_in2, const at::Tensor &grad_out,
                                               const at::Tensor &grad_grad_in1,
                                               const at::Tensor &grad_grad_in2,
                                               const at::Tensor &in1, const at::Tensor &in2,
                                               const at::Tensor &idx1);

void index_mul_2d_half_foward_cuda(at::Tensor &out, const at::Tensor &in1,
                                   const at::Tensor &in2, const at::Tensor &idx1);

void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1, at::Tensor &grad_in2,
                                     const at::Tensor &grad_out, const at::Tensor &in1,
                                     const at::Tensor &in2, const at::Tensor &idx1);

void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out, at::Tensor &grad_in1,
                                              at::Tensor &grad_in2, const at::Tensor &grad_out,
                                              const at::Tensor &grad_grad_in1,
                                              const at::Tensor &grad_grad_in2,
                                              const at::Tensor &in1, const at::Tensor &in2,
                                              const at::Tensor &idx1);

#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

void index_mul_2d_float_forward(at::Tensor &out, const at::Tensor &in1,
                                const at::Tensor &in2, const at::Tensor &idx1) {
    return index_mul_2d_float_foward_cuda(out, in1, in2, idx1);
}

void index_mul_2d_float_backward(at::Tensor &grad_in1, at::Tensor &grad_in2,
                                 const at::Tensor &grad_out, const at::Tensor &in1,
                                 const at::Tensor &in2, const at::Tensor &idx1) {
    return index_mul_2d_float_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1);
}

void index_mul_2d_float_backwrad_backward(at::Tensor &grad_grad_out, at::Tensor &grad_in1,
                                          at::Tensor &grad_in2, const at::Tensor &grad_out,
                                          const at::Tensor &grad_grad_in1,
                                          const at::Tensor &grad_grad_in2,
                                          const at::Tensor &in1, const at::Tensor &in2,
                                          const at::Tensor &idx1) {
    return index_mul_2d_float_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out,
                                                     grad_grad_in1, grad_grad_in2, in1, in2, idx1);
}

void index_mul_2d_half_forward(at::Tensor &out, const at::Tensor &in1,
                               const at::Tensor &in2, const at::Tensor &idx1) {
    return index_mul_2d_half_foward_cuda(out, in1, in2, idx1);
}

void index_mul_2d_half_backward(at::Tensor &grad_in1, at::Tensor &grad_in2,
                                const at::Tensor &grad_out, const at::Tensor &in1,
                                const at::Tensor &in2, const at::Tensor &idx1) {
    return index_mul_2d_half_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1);
}

void index_mul_2d_half_backwrad_backward(at::Tensor &grad_grad_out, at::Tensor &grad_in1,
                                         at::Tensor &grad_in2, const at::Tensor &grad_out,
                                         const at::Tensor &grad_grad_in1,
                                         const at::Tensor &grad_grad_in2,
                                         const at::Tensor &in1, const at::Tensor &in2,
                                         const at::Tensor &idx1) {
    return index_mul_2d_half_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out,
                                                    grad_grad_in1, grad_grad_in2, in1, in2, idx1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("float_forward", &index_mul_2d_float_forward,
          "index mul float calculation forward (CUDA)");
    m.def("float_backward", &index_mul_2d_float_backward,
          "index mul float calculation backward (CUDA)");
    m.def("float_backward_backward", &index_mul_2d_float_backwrad_backward,
          "index mul float calculation backward backward (CUDA)");
    m.def("half_forward", &index_mul_2d_half_forward,
          "index mul half calculation forward (CUDA)");
    m.def("half_backward", &index_mul_2d_half_backward,
          "index mul half calculation backward (CUDA)");
    m.def("half_backward_backward", &index_mul_2d_half_backwrad_backward,
          "index mul half calculation backward backward (CUDA)");
}
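
For reference, the operation these bindings expose is a row-gathered elementwise product: out[i, :] = in1[idx1[i], :] * in2[i, :]. A minimal CPU sketch of that contract (hypothetical helper name, not part of the extension) can be handy when checking the CUDA results against a reference:

// index_mul_2d_reference: hypothetical CPU reference for
// out[i][k] = in1[idx1[i]][k] * in2[i][k], for verification only.
#include <cstdint>
#include <vector>

std::vector<float> index_mul_2d_reference(const std::vector<float> &in1,
                                          const std::vector<float> &in2,
                                          const std::vector<int64_t> &idx1,
                                          int64_t size, int64_t fea_dim) {
    std::vector<float> out(size * fea_dim);
    for (int64_t i = 0; i < size; ++i) {
        const float *row1 = &in1[idx1[i] * fea_dim];  // gathered row of in1
        const float *row2 = &in2[i * fea_dim];        // matching row of in2
        for (int64_t k = 0; k < fea_dim; ++k) {
            out[i * fea_dim + k] = row1[k] * row2[k];
        }
    }
    return out;
}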
apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu deleted 100644 → 0 View file @ 2a4864d5
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#ifdef ATEN_ATOMIC_HEADER
#include <ATen/cuda/Atomic.cuh>
#else
#include <THC/THCAtomics.cuh>
#endif
__global__ void index_mul_2d_float_dim64(
    float *out,
    const float *in1,
    const float *in2,
    const int64_t *idx1,
    const int64_t size)
{
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    constexpr int fea_dim = 64;

    if (start_idx < size) {
        int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx;
        int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx;
        float4 res, src1, src2;
        src1 = reinterpret_cast<const float4 *>(in1)[vec_idx1];
        src2 = reinterpret_cast<const float4 *>(in2)[vec_idx2];
        res.x = src1.x * src2.x;
        res.y = src1.y * src2.y;
        res.z = src1.z * src2.z;
        res.w = src1.w * src2.w;
        reinterpret_cast<float4 *>(out)[vec_idx2] = res;
    }
}

__global__ void index_mul_2d_float(
    float *out,
    const float *in1,
    const float *in2,
    const int64_t *idx1,
    const int64_t size,
    const int64_t fea_dim)
{
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    const int stride = blockDim.x;

    if (start_idx < size) {
        int64_t vec_idx1 = (idx1[start_idx] * fea_dim);
        int64_t vec_idx2 = (start_idx * fea_dim);
        for (int i = tidx; i < fea_dim; i += stride) {
            out[vec_idx2 + i] = in1[vec_idx1 + i] * in2[vec_idx2 + i];
        }
    }
}

__global__ void index_mul_2d_half(
    at::Half *out,
    const at::Half *in1,
    const at::Half *in2,
    const int64_t *idx1,
    const int64_t size,
    const int64_t fea_dim)
{
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    const int stride = blockDim.x;

    if (start_idx < size) {
        int64_t vec_idx1 = (idx1[start_idx] * fea_dim);
        int64_t vec_idx2 = (start_idx * fea_dim);
        for (int i = tidx; i < fea_dim; i += stride) {
            out[vec_idx2 + i] = at::Half(static_cast<float>(in1[vec_idx1 + i]) *
                                         static_cast<float>(in2[vec_idx2 + i]));
        }
    }
}

__global__ void index_mul_2d_grad_float_dim64(
    float *grad_in1,
    float *grad_in2,
    const float *grad_out,
    const float *in1,
    const float *in2,
    const int64_t *idx1,
    const int64_t size)
{
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    constexpr int fea_dim = 64;

    if (start_idx < size) {
        int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx;
        int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx;
        float4 src_in1, src_in2, src_grad_out, dst_grad_in2;
        src_grad_out = reinterpret_cast<const float4 *>(grad_out)[vec_idx2];
        src_in1 = reinterpret_cast<const float4 *>(in1)[vec_idx1];
        src_in2 = reinterpret_cast<const float4 *>(in2)[vec_idx2];

        int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4;
        gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_out.x * src_in2.x);
        gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_out.y * src_in2.y);
        gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_out.z * src_in2.z);
        gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_out.w * src_in2.w);

        dst_grad_in2.x = src_grad_out.x * src_in1.x;
        dst_grad_in2.y = src_grad_out.y * src_in1.y;
        dst_grad_in2.z = src_grad_out.z * src_in1.z;
        dst_grad_in2.w = src_grad_out.w * src_in1.w;
        reinterpret_cast<float4 *>(grad_in2)[vec_idx2] = dst_grad_in2;
    }
}

__global__ void index_mul_2d_grad_float(
    float *grad_in1,
    float *grad_in2,
    const float *grad_out,
    const float *in1,
    const float *in2,
    const int64_t *idx1,
    const int64_t size,
    const int64_t fea_dim)
{
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    const int stride = blockDim.x;

    if (start_idx < size) {
        int64_t vec_idx1 = idx1[start_idx] * fea_dim;
        int64_t vec_idx2 = start_idx * fea_dim;
        for (int i = tidx; i < fea_dim; i += stride) {
            float src_in1 = in1[vec_idx1 + i];
            float src_in2 = in2[vec_idx2 + i];
            float src_grad_out = grad_out[vec_idx2 + i];
            grad_in2[vec_idx2 + i] = src_grad_out * src_in1;
            gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_out * src_in2);
        }
    }
}

__global__ void index_mul_2d_grad_half(
    at::Half *grad_in1,
    at::Half *grad_in2,
    const at::Half *grad_out,
    const at::Half *in1,
    const at::Half *in2,
    const int64_t *idx1,
    const int64_t size,
    const int64_t fea_dim)
{
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    const int stride = blockDim.x;

    if (start_idx < size) {
        int64_t vec_idx1 = idx1[start_idx] * fea_dim;
        int64_t vec_idx2 = start_idx * fea_dim;
        for (int i = tidx; i < fea_dim; i += stride) {
            float src_in1 = static_cast<float>(in1[vec_idx1 + i]);
            float src_in2 = static_cast<float>(in2[vec_idx2 + i]);
            float src_grad_out = static_cast<float>(grad_out[vec_idx2 + i]);
            grad_in2[vec_idx2 + i] = at::Half(src_grad_out * src_in1);
            gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_out * src_in2));
        }
    }
}

__global__ void index_mul_2d_grad_grad_float_dim64(
    float *grad_grad_out,
    float *grad_in1,
    float *grad_in2,
    const float *grad_out,
    const float *grad_grad_in1,
    const float *grad_grad_in2,
    const float *in1,
    const float *in2,
    const int64_t *idx1,
    const int64_t size)
{
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    constexpr int fea_dim = 64;

    if (start_idx < size) {
        int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx;
        int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx;
        float4 src_grad_grad_in1, src_in1, src_grad_grad_in2, src_in2, src_grad_out;
        float4 dst_grad_grad_out, dst_grad_in2;
        src_grad_grad_in1 = reinterpret_cast<const float4 *>(grad_grad_in1)[vec_idx1];
        src_in1 = reinterpret_cast<const float4 *>(in1)[vec_idx1];
        src_grad_grad_in2 = reinterpret_cast<const float4 *>(grad_grad_in2)[vec_idx2];
        src_in2 = reinterpret_cast<const float4 *>(in2)[vec_idx2];
        dst_grad_grad_out.x = src_grad_grad_in1.x * src_in2.x + src_grad_grad_in2.x * src_in1.x;
        dst_grad_grad_out.y = src_grad_grad_in1.y * src_in2.y + src_grad_grad_in2.y * src_in1.y;
        dst_grad_grad_out.z = src_grad_grad_in1.z * src_in2.z + src_grad_grad_in2.z * src_in1.z;
        dst_grad_grad_out.w = src_grad_grad_in1.w * src_in2.w + src_grad_grad_in2.w * src_in1.w;
        reinterpret_cast<float4 *>(grad_grad_out)[vec_idx2] = dst_grad_grad_out;

        src_grad_out = reinterpret_cast<const float4 *>(grad_out)[vec_idx2];
        int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4;
        gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_grad_in2.x * src_grad_out.x);
        gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_grad_in2.y * src_grad_out.y);
        gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_grad_in2.z * src_grad_out.z);
        gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_grad_in2.w * src_grad_out.w);
        dst_grad_in2.x = src_grad_grad_in1.x * src_grad_out.x;
        dst_grad_in2.y = src_grad_grad_in1.y * src_grad_out.y;
        dst_grad_in2.z = src_grad_grad_in1.z * src_grad_out.z;
        dst_grad_in2.w = src_grad_grad_in1.w * src_grad_out.w;
        reinterpret_cast<float4 *>(grad_in2)[vec_idx2] = dst_grad_in2;
    }
}

__global__ void index_mul_2d_grad_grad_float(
    float *grad_grad_out,
    float *grad_in1,
    float *grad_in2,
    const float *grad_out,
    const float *grad_grad_in1,
    const float *grad_grad_in2,
    const float *in1,
    const float *in2,
    const int64_t *idx1,
    const int64_t size,
    const int64_t fea_dim)
{
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    const int stride = blockDim.x;

    if (start_idx < size) {
        int64_t vec_idx1 = idx1[start_idx] * fea_dim;
        int64_t vec_idx2 = start_idx * fea_dim;
        for (int i = tidx; i < fea_dim; i += stride) {
            float src_grad_grad_in1 = grad_grad_in1[vec_idx1 + i];
            float src_grad_grad_in2 = grad_grad_in2[vec_idx2 + i];
            float src_in1 = in1[vec_idx1 + i];
            float src_in2 = in2[vec_idx2 + i];
            float src_grad_out = grad_out[vec_idx2 + i];
            grad_grad_out[vec_idx2 + i] = src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1;
            grad_in2[vec_idx2 + i] = src_grad_grad_in1 * src_grad_out;
            gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_grad_in2 * src_grad_out);
        }
    }
}

__global__ void index_mul_2d_grad_grad_half(
    at::Half *grad_grad_out,
    at::Half *grad_in1,
    at::Half *grad_in2,
    const at::Half *grad_out,
    const at::Half *grad_grad_in1,
    const at::Half *grad_grad_in2,
    const at::Half *in1,
    const at::Half *in2,
    const int64_t *idx1,
    const int64_t size,
    const int64_t fea_dim)
{
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    const int stride = blockDim.x;

    if (start_idx < size) {
        int64_t vec_idx1 = idx1[start_idx] * fea_dim;
        int64_t vec_idx2 = start_idx * fea_dim;
        for (int i = tidx; i < fea_dim; i += stride) {
            float src_grad_grad_in1 = static_cast<float>(grad_grad_in1[vec_idx1 + i]);
            float src_grad_grad_in2 = static_cast<float>(grad_grad_in2[vec_idx2 + i]);
            float src_in1 = static_cast<float>(in1[vec_idx1 + i]);
            float src_in2 = static_cast<float>(in2[vec_idx2 + i]);
            float src_grad_out = static_cast<float>(grad_out[vec_idx2 + i]);
            grad_grad_out[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_in2 +
                                                   src_grad_grad_in2 * src_in1);
            grad_in2[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_grad_out);
            gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_grad_in2 * src_grad_out));
        }
    }
}

void index_mul_2d_float_foward_cuda(at::Tensor &out, const at::Tensor &in1,
                                    const at::Tensor &in2, const at::Tensor &idx1) {
    const int64_t size = in2.size(0);
    const int64_t fea_dim = in2.size(1);
    if (size < 0){
        return;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (fea_dim == 64) {
        const int BLOCK_THREADS_DIMX = 16;
        const int BLOCK_THREADS_DIMY = 16;
        const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
        index_mul_2d_float_dim64<<<BLOCK_NUMS, threads, 0, stream>>>(
            out.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(),
            idx1.data_ptr<int64_t>(), size);
    } else {
        const int BLOCK_THREADS_DIMX = 32;
        const int BLOCK_THREADS_DIMY = 8;
        const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
        index_mul_2d_float<<<BLOCK_NUMS, threads, 0, stream>>>(
            out.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(),
            idx1.data_ptr<int64_t>(), size, fea_dim);
    }
    AT_CUDA_CHECK(cudaGetLastError());
}

void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1, at::Tensor &grad_in2,
                                      const at::Tensor &grad_out, const at::Tensor &in1,
                                      const at::Tensor &in2, const at::Tensor &idx1) {
    const int64_t size = in2.size(0);
    const int64_t fea_dim = in2.size(1);
    if (size < 0){
        return;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (fea_dim == 64) {
        const int BLOCK_THREADS_DIMX = 16;
        const int BLOCK_THREADS_DIMY = 16;
        const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
        index_mul_2d_grad_float_dim64<<<BLOCK_NUMS, threads, 0, stream>>>(
            grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(),
            in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);
        AT_CUDA_CHECK(cudaGetLastError());
    } else {
        const int BLOCK_THREADS_DIMX = 32;
        const int BLOCK_THREADS_DIMY = 8;
        const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
        index_mul_2d_grad_float<<<BLOCK_NUMS, threads, 0, stream>>>(
            grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(),
            in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);
    }
}

void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out, at::Tensor &grad_in1,
                                               at::Tensor &grad_in2, const at::Tensor &grad_out,
                                               const at::Tensor &grad_grad_in1,
                                               const at::Tensor &grad_grad_in2,
                                               const at::Tensor &in1, const at::Tensor &in2,
                                               const at::Tensor &idx1) {
    const int64_t size = in2.size(0);
    const int64_t fea_dim = in2.size(1);
    if (size < 0){
        return;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (fea_dim == 64) {
        const int BLOCK_THREADS_DIMX = 16;
        const int BLOCK_THREADS_DIMY = 16;
        const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
        index_mul_2d_grad_grad_float_dim64<<<BLOCK_NUMS, threads, 0, stream>>>(
            grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
            grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(),
            grad_grad_in2.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(),
            idx1.data_ptr<int64_t>(), size);
    } else {
        const int BLOCK_THREADS_DIMX = 32;
        const int BLOCK_THREADS_DIMY = 8;
        const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
        index_mul_2d_grad_grad_float<<<BLOCK_NUMS, threads, 0, stream>>>(
            grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
            grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(),
            grad_grad_in2.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(),
            idx1.data_ptr<int64_t>(), size, fea_dim);
    }
    AT_CUDA_CHECK(cudaGetLastError());
}

void index_mul_2d_half_foward_cuda(at::Tensor &out, const at::Tensor &in1,
                                   const at::Tensor &in2, const at::Tensor &idx1) {
    const int64_t size = in2.size(0);
    const int64_t fea_dim = in2.size(1);
    if (size < 0){
        return;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const int BLOCK_THREADS_DIMX = 32;
    const int BLOCK_THREADS_DIMY = 8;
    const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
    dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
    index_mul_2d_half<<<BLOCK_NUMS, threads, 0, stream>>>(
        out.data_ptr<at::Half>(), in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(),
        idx1.data_ptr<int64_t>(), size, fea_dim);
    AT_CUDA_CHECK(cudaGetLastError());
}

void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1, at::Tensor &grad_in2,
                                     const at::Tensor &grad_out, const at::Tensor &in1,
                                     const at::Tensor &in2, const at::Tensor &idx1) {
    const int64_t size = in2.size(0);
    const int64_t fea_dim = in2.size(1);
    if (size < 0){
        return;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const int BLOCK_THREADS_DIMX = 32;
    const int BLOCK_THREADS_DIMY = 8;
    const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
    dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
    index_mul_2d_grad_half<<<BLOCK_NUMS, threads, 0, stream>>>(
        grad_in1.data_ptr<at::Half>(), grad_in2.data_ptr<at::Half>(), grad_out.data_ptr<at::Half>(),
        in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(),
        size, fea_dim);
}

void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out, at::Tensor &grad_in1,
                                              at::Tensor &grad_in2, const at::Tensor &grad_out,
                                              const at::Tensor &grad_grad_in1,
                                              const at::Tensor &grad_grad_in2,
                                              const at::Tensor &in1, const at::Tensor &in2,
                                              const at::Tensor &idx1) {
    const int64_t size = in2.size(0);
    const int64_t fea_dim = in2.size(1);
    if (size < 0){
        return;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const int BLOCK_THREADS_DIMX = 32;
    const int BLOCK_THREADS_DIMY = 8;
    const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
    dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
    index_mul_2d_grad_grad_half<<<BLOCK_NUMS, threads, 0, stream>>>(
        grad_grad_out.data_ptr<at::Half>(), grad_in1.data_ptr<at::Half>(),
        grad_in2.data_ptr<at::Half>(), grad_out.data_ptr<at::Half>(),
        grad_grad_in1.data_ptr<at::Half>(), grad_grad_in2.data_ptr<at::Half>(),
        in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(),
        size, fea_dim);
    AT_CUDA_CHECK(cudaGetLastError());
}
apex/contrib/csrc/layer_norm/ln.h deleted 100644 → 0 View file @ 2a4864d5
#pragma once
#include <unordered_map>
#include <functional>
#if defined(__HIP_PLATFORM_HCC__)
#include "hip/hip_fp16.h"
#include "hip/hip_bfloat16.h"
#else
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#endif
namespace layer_norm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Params>
struct LaunchParams{

    size_t workspace_bytes;
    size_t barrier_size;

    cudaDeviceProp * props;

    cudaStream_t stream;

    Params params;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct ParamsBase {
    ParamsBase()
        : ctas_per_col(0)
        , rows(0)
        , cols(0)
        , x(nullptr)
        , mu(nullptr)
        , rs(nullptr)
        , gamma(nullptr)
        , workspace(nullptr)
        , barrier(nullptr)
    {
    }

    // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
    int ctas_per_col;

    // Input is interpreted as matrix. We normalize across columns.
    int rows;
    int cols;

    // Common data pointers.
    void *x;
    void *mu;
    void *rs;
    void *gamma;

    // Multi-CTA workspace in gmem.
    void *workspace;

    // Multi-CTA sync barriers in gmem.
    int *barrier;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct FwdParams : public ParamsBase {
    FwdParams()
        : ParamsBase()
        , z(nullptr)
        , beta(nullptr)
        , epsilon(0.f)
    {
    }

    // Output of LN FWD.
    void *z;
    void *beta;
    float epsilon;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct BwdParams : public ParamsBase {
    BwdParams()
        : ParamsBase()
        , dz(nullptr)
        , dbeta_part(nullptr)
        , dgamma_part(nullptr)
        , dx(nullptr)
        , dbeta(nullptr)
        , dgamma(nullptr)
    {
    }

    // Input: gradient wrt. LN FWD output.
    void *dz;

    // Workspace for Wgrad pre-reduction.
    void *dbeta_part;
    void *dgamma_part;

    // Output: Dgrad.
    void *dx;
    // Output: Wgrad.
    void *dbeta;
    void *dgamma;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
using FunctionKey = uint64_t;
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;

extern FwdRegistry FWD_FUNCS;
extern BwdRegistry BWD_FUNCS;

////////////////////////////////////////////////////////////////////////////////////////////////////

using fp32 = float;
using fp16 = half;
#if defined(__HIP_PLATFORM_HCC__)
using bf16 = hip_bfloat16;
#else
using bf16 = nv_bfloat16;
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct TypeId{};

template<>
struct TypeId<fp16>{
    constexpr static uint32_t Value = 0;
};

template<>
struct TypeId<bf16>{
    constexpr static uint32_t Value = 1;
};

template<>
struct TypeId<fp32>{
    constexpr static uint32_t Value = 2;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int S>
struct Type2Key{
    constexpr static uint32_t Value = TypeId<T>::Value << S;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct WeightType2Key : public Type2Key<T, 0>{};

template<typename T>
struct InputType2Key : public Type2Key<T, 2>{};

template<typename T>
struct OutputType2Key : public Type2Key<T, 4>{};

template<typename T>
struct ComputeType2Key : public Type2Key<T, 6>{};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename W, typename I, typename O, typename C>
struct Types2Key{
    constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value
                                    | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
    constexpr static inline uint64_t get(const uint64_t hidden_size){
        constexpr uint64_t type_key = Value;
        return (type_key << 32) | hidden_size;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdRegistrar{
    FwdRegistrar(FwdFunction f){
        uint64_t key = Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
        FWD_FUNCS.insert({ key, f });
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdRegistrar{
    BwdRegistrar(BwdFunction f){
        uint64_t key = Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
        BWD_FUNCS.insert({ key, f });
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace layer_norm
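
Types2Key packs one 2-bit type id per role (weight, input, output, compute) into the upper 32 bits of the registry key and the hidden size into the lower 32 bits. The following standalone C++ sketch mirrors that packing; it re-derives the constants instead of including the header, so treat it as an illustration rather than a drop-in use of the file above.

// pack_key: mirrors the Types2Key scheme (fp16=0, bf16=1, fp32=2; shifts 0/2/4/6).
#include <cstdint>

constexpr uint64_t pack_key(uint32_t w, uint32_t i, uint32_t o, uint32_t c,
                            uint64_t hidden_size) {
    const uint32_t type_key = (w << 0) | (i << 2) | (o << 4) | (c << 6);
    return (static_cast<uint64_t>(type_key) << 32) | hidden_size;
}

// fp16 weights/input/output with fp32 compute at hidden size 1024:
// type_key = (2 << 6) = 0x80, so the full key is (0x80 << 32) | 1024.
static_assert(pack_key(0, 0, 0, 2, 1024) == ((0x80ull << 32) | 1024), "key packing");

int main() { return 0; }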