Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
f79993d9
Commit
f79993d9
authored
Oct 15, 2021
by
hubertlu-tw
Browse files
Merge remote-tracking branch 'upstream/master' into IFU-master-2021-10-15
parents
297ab210
1d5f7e55
Changes
117
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3078 additions
and
19 deletions
+3078
-19
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu
...ntrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu
+60
-0
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu
...ntrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu
+60
-0
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu
...ntrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu
+105
-0
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h
+558
-0
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h
+571
-0
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu
...ntrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu
+58
-0
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu
...ntrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu
+58
-0
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu
...ntrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu
+57
-0
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu
...ntrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu
+98
-0
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h
+336
-0
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_nl.h
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_nl.h
+343
-0
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_reload_v.h
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_reload_v.h
+322
-0
apex/contrib/csrc/fmha/src/fmha_kernel.h
apex/contrib/csrc/fmha/src/fmha_kernel.h
+169
-0
apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu
apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu
+177
-0
apex/contrib/csrc/fmha/src/fmha_utils.h
apex/contrib/csrc/fmha/src/fmha_utils.h
+92
-0
apex/contrib/csrc/groupbn/batch_norm.cu
apex/contrib/csrc/groupbn/batch_norm.cu
+6
-9
apex/contrib/csrc/groupbn/batch_norm.h
apex/contrib/csrc/groupbn/batch_norm.h
+1
-0
apex/contrib/csrc/groupbn/batch_norm_add_relu.cu
apex/contrib/csrc/groupbn/batch_norm_add_relu.cu
+6
-9
apex/contrib/csrc/groupbn/batch_norm_add_relu.h
apex/contrib/csrc/groupbn/batch_norm_add_relu.h
+1
-0
apex/contrib/csrc/groupbn/ipc.cu
apex/contrib/csrc/groupbn/ipc.cu
+0
-1
No files found.
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu
0 → 100644
View file @
f79993d9
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using
Kernel_traits
=
FMHA_kernel_traits
<
256
,
64
,
16
,
1
,
4
,
0x08u
>
;
extern
"C"
__global__
void
fmha_dgrad_fp16_256_64_sm80_kernel
(
Fused_multihead_attention_fprop_params
params
)
{
fmha
::
compute_dv_1xN
<
Kernel_traits
>
(
params
);
fmha
::
compute_dq_dk_1xN
<
Kernel_traits
>
(
params
);
}
void
run_fmha_dgrad_fp16_256_64_sm80
(
const
Fused_multihead_attention_fprop_params
&
params
,
cudaStream_t
stream
)
{
constexpr
int
smem_size_softmax
=
Kernel_traits
::
Cta_tile_p
::
M
*
Kernel_traits
::
Cta_tile_p
::
WARPS_N
*
sizeof
(
float
);
constexpr
int
smem_size_q
=
Kernel_traits
::
Smem_tile_q
::
BYTES_PER_TILE
;
constexpr
int
smem_size_v
=
Kernel_traits
::
Smem_tile_v
::
BYTES_PER_TILE
;
constexpr
int
smem_size_o
=
Kernel_traits
::
Smem_tile_o
::
BYTES_PER_TILE
;
using
Smem_tile_s
=
fmha
::
Smem_tile_mma_transposed
<
Kernel_traits
::
Cta_tile_p
>
;
constexpr
int
smem_size_s
=
Smem_tile_s
::
BYTES_PER_TILE
;
static_assert
(
smem_size_s
==
16
*
256
*
2
);
static_assert
(
smem_size_o
==
16
*
64
*
4
*
Kernel_traits
::
Cta_tile_p
::
WARPS_N
);
constexpr
int
smem_size_dv
=
smem_size_s
+
2
*
smem_size_q
+
smem_size_v
+
smem_size_softmax
;
constexpr
int
smem_size_dq_dk
=
smem_size_s
+
smem_size_o
+
smem_size_q
+
smem_size_v
;
constexpr
int
smem_size
=
std
::
max
(
smem_size_dv
,
smem_size_dq_dk
);
if
(
smem_size
>=
48
*
1024
)
{
FMHA_CHECK_CUDA
(
cudaFuncSetAttribute
(
fmha_dgrad_fp16_256_64_sm80_kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size
));
}
dim3
grid
(
params
.
h
,
params
.
b
);
fmha_dgrad_fp16_256_64_sm80_kernel
<<<
grid
,
Kernel_traits
::
THREADS
,
smem_size
,
stream
>>>
(
params
);
}
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu
0 → 100644
View file @
f79993d9
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using
Kernel_traits
=
FMHA_kernel_traits
<
384
,
64
,
16
,
1
,
8
,
0x08u
>
;
extern
"C"
__global__
void
fmha_dgrad_fp16_384_64_sm80_kernel
(
Fused_multihead_attention_fprop_params
params
)
{
fmha
::
compute_dv_1xN
<
Kernel_traits
>
(
params
);
fmha
::
compute_dq_dk_1xN
<
Kernel_traits
>
(
params
);
}
void
run_fmha_dgrad_fp16_384_64_sm80
(
const
Fused_multihead_attention_fprop_params
&
params
,
cudaStream_t
stream
)
{
constexpr
int
smem_size_softmax
=
Kernel_traits
::
Cta_tile_p
::
M
*
Kernel_traits
::
Cta_tile_p
::
WARPS_N
*
sizeof
(
float
);
constexpr
int
smem_size_q
=
Kernel_traits
::
Smem_tile_q
::
BYTES_PER_TILE
;
constexpr
int
smem_size_v
=
Kernel_traits
::
Smem_tile_v
::
BYTES_PER_TILE
;
constexpr
int
smem_size_o
=
Kernel_traits
::
Smem_tile_o
::
BYTES_PER_TILE
;
using
Smem_tile_s
=
fmha
::
Smem_tile_mma_transposed
<
Kernel_traits
::
Cta_tile_p
>
;
constexpr
int
smem_size_s
=
Smem_tile_s
::
BYTES_PER_TILE
;
static_assert
(
smem_size_s
==
16
*
384
*
2
);
static_assert
(
smem_size_o
==
16
*
64
*
4
*
Kernel_traits
::
Cta_tile_p
::
WARPS_N
);
constexpr
int
smem_size_dv
=
smem_size_s
+
2
*
smem_size_q
+
smem_size_v
+
smem_size_softmax
;
constexpr
int
smem_size_dq_dk
=
smem_size_s
+
smem_size_o
+
smem_size_q
+
smem_size_v
;
constexpr
int
smem_size
=
std
::
max
(
smem_size_dv
,
smem_size_dq_dk
);
if
(
smem_size
>=
48
*
1024
)
{
FMHA_CHECK_CUDA
(
cudaFuncSetAttribute
(
fmha_dgrad_fp16_384_64_sm80_kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size
));
}
dim3
grid
(
params
.
h
,
params
.
b
);
fmha_dgrad_fp16_384_64_sm80_kernel
<<<
grid
,
Kernel_traits
::
THREADS
,
smem_size
,
stream
>>>
(
params
);
}
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu
0 → 100644
View file @
f79993d9
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
#include "fmha_dgrad_kernel_1xN_reload_nl.h"
using
Kernel_traits
=
FMHA_kernel_traits
<
512
,
64
,
16
,
1
,
8
,
0x08u
>
;
extern
"C"
__global__
void
fmha_dgrad_fp16_512_64_sm80_kernel
(
Fused_multihead_attention_fprop_params
params
)
{
fmha
::
compute_dv_1xN
<
Kernel_traits
>
(
params
);
fmha
::
compute_dq_dk_1xN
<
Kernel_traits
>
(
params
);
}
template
<
int
CHUNKS
>
__global__
void
fmha_dgrad_fp16_512_64_sm80_nl_kernel
(
Fused_multihead_attention_fprop_params
params
){
fmha
::
compute_dv_1xN_nl
<
CHUNKS
,
Kernel_traits
>
(
params
);
fmha
::
compute_dq_dk_1xN_nl
<
CHUNKS
,
Kernel_traits
>
(
params
);
}
void
run_fmha_dgrad_fp16_512_64_sm80
(
const
Fused_multihead_attention_fprop_params
&
params
,
cudaStream_t
stream
)
{
constexpr
int
smem_size_softmax
=
Kernel_traits
::
Cta_tile_p
::
M
*
Kernel_traits
::
Cta_tile_p
::
WARPS_N
*
sizeof
(
float
);
constexpr
int
smem_size_q
=
Kernel_traits
::
Smem_tile_q
::
BYTES_PER_TILE
;
constexpr
int
smem_size_v
=
Kernel_traits
::
Smem_tile_v
::
BYTES_PER_TILE
;
constexpr
int
smem_size_o
=
Kernel_traits
::
Smem_tile_o
::
BYTES_PER_TILE
;
using
Smem_tile_s
=
fmha
::
Smem_tile_mma_transposed
<
Kernel_traits
::
Cta_tile_p
>
;
constexpr
int
smem_size_s
=
Smem_tile_s
::
BYTES_PER_TILE
;
static_assert
(
smem_size_s
==
16
*
512
*
2
);
static_assert
(
smem_size_o
==
16
*
64
*
4
*
Kernel_traits
::
Cta_tile_p
::
WARPS_N
);
constexpr
int
smem_size_dv
=
smem_size_s
+
2
*
smem_size_q
+
smem_size_v
+
smem_size_softmax
;
constexpr
int
smem_size_dq_dk
=
smem_size_s
+
smem_size_o
+
smem_size_q
+
smem_size_v
;
constexpr
int
smem_size
=
std
::
max
(
smem_size_dv
,
smem_size_dq_dk
);
if
(
smem_size
>=
48
*
1024
)
{
FMHA_CHECK_CUDA
(
cudaFuncSetAttribute
(
fmha_dgrad_fp16_512_64_sm80_kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size
));
}
dim3
grid
(
params
.
h
,
params
.
b
);
fmha_dgrad_fp16_512_64_sm80_kernel
<<<
grid
,
Kernel_traits
::
THREADS
,
smem_size
,
stream
>>>
(
params
);
}
void
run_fmha_dgrad_fp16_512_64_sm80_nl
(
const
Fused_multihead_attention_fprop_params
&
params
,
const
int
num_chunks
,
cudaStream_t
stream
)
{
constexpr
int
smem_size_softmax
=
Kernel_traits
::
Cta_tile_p
::
M
*
Kernel_traits
::
Cta_tile_p
::
WARPS_N
*
sizeof
(
float
);
constexpr
int
smem_size_q
=
Kernel_traits
::
Smem_tile_q
::
BYTES_PER_TILE
;
constexpr
int
smem_size_v
=
Kernel_traits
::
Smem_tile_v
::
BYTES_PER_TILE
;
constexpr
int
smem_size_o
=
Kernel_traits
::
Smem_tile_o
::
BYTES_PER_TILE
;
using
Smem_tile_s
=
fmha
::
Smem_tile_mma_transposed
<
Kernel_traits
::
Cta_tile_p
>
;
constexpr
int
smem_size_s
=
Smem_tile_s
::
BYTES_PER_TILE
;
static_assert
(
smem_size_s
==
16
*
512
*
2
);
static_assert
(
smem_size_o
==
16
*
64
*
4
*
Kernel_traits
::
Cta_tile_p
::
WARPS_N
);
constexpr
int
smem_size_dv
=
smem_size_s
+
2
*
smem_size_q
+
smem_size_v
+
smem_size_softmax
;
constexpr
int
smem_size_dq_dk
=
smem_size_s
+
smem_size_o
+
smem_size_q
+
smem_size_v
;
constexpr
int
smem_size
=
std
::
max
(
smem_size_dv
,
smem_size_dq_dk
);
auto
kernel
=
fmha_dgrad_fp16_512_64_sm80_nl_kernel
<
2
>
;
if
(
num_chunks
==
2
)
{
kernel
=
fmha_dgrad_fp16_512_64_sm80_nl_kernel
<
2
>
;
}
else
if
(
num_chunks
==
3
)
{
kernel
=
fmha_dgrad_fp16_512_64_sm80_nl_kernel
<
3
>
;
}
else
{
assert
(
false
&&
"Unsupperted number of chunks"
);
}
if
(
smem_size
>=
48
*
1024
)
{
FMHA_CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size
));
}
dim3
grid
(
params
.
h
,
params
.
b
,
num_chunks
);
kernel
<<<
grid
,
Kernel_traits
::
THREADS
,
smem_size
,
stream
>>>
(
params
);
FMHA_CHECK_CUDA
(
cudaPeekAtLastError
());
}
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h
0 → 100644
View file @
f79993d9
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace
fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Kernel_traits
,
typename
Params
>
inline
__device__
void
compute_dv_1xN
(
const
Params
&
params
)
{
// The description of the CTA tile for the 1st batched GEMM.
using
Cta_tile_p
=
typename
Kernel_traits
::
Cta_tile_p
;
// The description of the CTA tile for the 2nd batched GEMM.
using
Cta_tile_dv
=
fmha
::
Cta_tile_extd
<
Cta_tile_p
::
N
,
Cta_tile_p
::
K
,
Cta_tile_p
::
M
,
Cta_tile_p
::
WARPS_N
,
1
,
Cta_tile_p
::
WARPS_M
>
;
static_assert
(
Cta_tile_dv
::
M
==
512
||
Cta_tile_dv
::
M
==
384
||
Cta_tile_dv
::
M
==
256
||
Cta_tile_dv
::
M
==
128
);
static_assert
(
Cta_tile_dv
::
N
==
64
);
static_assert
(
Cta_tile_dv
::
K
==
16
);
// The MMA tile for the 1st GEMM.
using
Mma_tile_p
=
fmha
::
Hmma_tile
<
Cta_tile_p
>
;
// The MMA tile for the 2nd GEMM.
using
Mma_tile_dv
=
fmha
::
Hmma_tile
<
Cta_tile_dv
>
;
// The global memory tile to load Q.
using
Gmem_tile_q
=
typename
Kernel_traits
::
Gmem_tile_q
;
// The shared memory tile to swizzle Q.
// using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
using
Smem_tile_q
=
fmha
::
Smem_tile_a
<
Cta_tile_p
,
fmha
::
Row
,
Gmem_tile_q
::
BYTES_PER_LDG
,
2
>
;
// The shared memory tile to reload Q as fragment b.
using
Smem_tile_qt
=
fmha
::
Smem_tile_b
<
Cta_tile_dv
,
fmha
::
Row
,
Gmem_tile_q
::
BYTES_PER_LDG
,
2
>
;
// The global memory tile to load K.
using
Gmem_tile_k
=
typename
Kernel_traits
::
Gmem_tile_k
;
// The shared memory tile to swizzle K.
using
Smem_tile_k
=
typename
Kernel_traits
::
Smem_tile_k
;
// The global memory tile to load V.
using
Gmem_tile_v
=
typename
Kernel_traits
::
Gmem_tile_v
;
// The shared memory tile to swizzle V.
using
Smem_tile_v
=
typename
Kernel_traits
::
Smem_tile_v
;
// The global memory tile to store O.
using
Gmem_tile_o
=
typename
Kernel_traits
::
Gmem_tile_o
;
// The shared memory tile to swizzle O.
using
Smem_tile_o
=
typename
Kernel_traits
::
Smem_tile_o
;
// The global memory tile to store dV.
using
Gmem_tile_dv
=
typename
Kernel_traits
::
Gmem_tile_v
;
// The shared memory tile to swizzle dV.
using
Smem_tile_dv
=
fmha
::
Smem_tile_mma_epilogue
<
Cta_tile_dv
>
;
static_assert
(
Smem_tile_dv
::
NUM_LDS
==
Gmem_tile_dv
::
LDGS
);
static_assert
(
Smem_tile_dv
::
THREADS_PER_ROW
==
Gmem_tile_dv
::
THREADS_PER_ROW
);
using
Gmem_tile_s
=
typename
Kernel_traits
::
Gmem_tile_s
;
using
Smem_tile_st
=
typename
Kernel_traits
::
Smem_tile_st
;
using
Gmem_tile_do
=
typename
Kernel_traits
::
Gmem_tile_do
;
// Shared memory.
extern
__shared__
char
smem_
[];
// The block index for the batch.
const
int
bidb
=
blockIdx
.
y
;
// The block index for the head.
const
int
bidh
=
blockIdx
.
x
;
// The thread index.
const
int
tidx
=
threadIdx
.
x
;
const
BlockInfoPadded
<
Kernel_traits
::
THREADS
>
binfo
(
params
,
bidb
,
bidh
,
tidx
);
if
(
binfo
.
stop_early
()
)
return
;
Mask
<
Cta_tile_p
>
mask
(
params
,
binfo
,
tidx
);
// Allocate the global memory tile loader for Q.
Gmem_tile_do
gmem_q
(
params
,
binfo
,
tidx
);
// treating dout as Q
// Allocate the shared memory tile loader for Q.
Smem_tile_q
smem_q
(
&
smem_
[
0
],
tidx
);
Smem_tile_qt
smem_qt
(
&
smem_
[
0
],
tidx
);
Smem_tile_st
smem_s
(
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
+
Smem_tile_k
::
BYTES_PER_TILE
],
tidx
);
// Allocate the global memory tile loader for K.
Gmem_tile_k
gmem_k
(
params
,
2
,
binfo
,
tidx
);
// treating V as K
// Allocate the shared memory tile loader for K.
Smem_tile_k
smem_k
(
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
],
tidx
);
// Trigger the loads for Q.
gmem_q
.
load
(
smem_q
);
// Trigger the loads for K.
gmem_k
.
load
(
smem_k
);
// Commit the data for Q and K to shared memory.
gmem_q
.
commit
(
smem_q
);
gmem_k
.
commit
(
smem_k
);
// Make sure the data is in shared memory.
__syncthreads
();
// Load the fragments for Q.
typename
Smem_tile_q
::
Fragment
frag_q
[
2
][
Mma_tile_p
::
MMAS_M
];
smem_q
.
load
(
frag_q
[
0
],
0
);
typename
Smem_tile_qt
::
Fragment
frag_qt
[
2
][
Mma_tile_dv
::
MMAS_N
];
static_assert
(
Smem_tile_qt
::
Fragment
::
NUM_REGS
==
4
);
static_assert
(
Mma_tile_dv
::
MMAS_K
==
1
);
smem_qt
.
load
(
frag_qt
[
0
],
0
);
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename
Smem_tile_k
::
Fragment
frag_k
[
2
][
Mma_tile_p
::
MMAS_N
];
smem_k
.
load
(
frag_k
[
0
],
0
);
enum
{
BITS_PER_ELT_S
=
sizeof
(
fmha
::
A_type
)
*
8
};
Gmem_tile_s
gmem_s
(
params
.
s_ptr
,
params
,
tidx
);
// Create the object to do the softmax.
using
Softmax
=
fmha
::
Softmax
<
Cta_tile_p
,
Kernel_traits
>
;
Softmax
softmax
(
params
,
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
+
Smem_tile_st
::
BYTES_PER_TILE
+
Smem_tile_k
::
BYTES_PER_TILE
],
bidb
,
tidx
);
enum
{
THREADS_PER_ROW
=
32
};
enum
{
M
=
Mma_tile_p
::
MMAS_M
};
enum
{
N
=
Mma_tile_p
::
MMAS_N
};
// Declare the accumulators for the 2nd gemm.
fmha
::
Fragment_accumulator
acc_dv
[
Mma_tile_dv
::
MMAS_M
][
Mma_tile_dv
::
MMAS_N
];
fmha
::
Clear_accumulator
<
fmha
::
Accumulator_type
,
Cta_tile_dv
::
WARPS_K
>::
apply
(
acc_dv
);
enum
{
STEPS
=
Cta_tile_p
::
N
/
Cta_tile_p
::
M
};
// Load over the entire sequence length.
for
(
int
l
=
0
;
l
<
STEPS
;
l
++
)
{
const
int
loop
=
l
*
Cta_tile_p
::
M
;
if
(
loop
>=
binfo
.
actual_seqlen
)
break
;
// Load S
uint4
s_regs
[
M
][
N
];
gmem_s
.
load
(
s_regs
,
mask
);
fmha
::
Fragment_accumulator
acc_p
[
Mma_tile_p
::
MMAS_M
][
Mma_tile_p
::
MMAS_N
];
fmha
::
Clear_accumulator
<
fmha
::
Accumulator_type
,
Cta_tile_p
::
WARPS_K
>::
apply
(
acc_p
);
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for
(
int
ki
=
1
;
ki
<
Mma_tile_p
::
MMAS_K
;
++
ki
)
{
// Trigger the load from shared memory for the next series of Q values.
smem_q
.
load
(
frag_q
[
ki
&
1
],
ki
);
smem_k
.
load
(
frag_k
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_p
,
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)
&
1
]);
}
// Store s * dmask to smem for transpose
smem_s
.
store
(
s_regs
);
// Declare the accumulators for the 1st gemm.
// Do the final stage of math.
{
int
ki
=
Mma_tile_p
::
MMAS_K
;
fmha
::
gemm
(
acc_p
,
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)
&
1
]);
}
// Trigger the load for the next Q values. We're using double buffering, so reading qt is safe
if
(
l
<
STEPS
-
1
)
{
smem_q
.
move_to_next_write_buffer
();
gmem_q
.
move
();
gmem_q
.
load
(
smem_q
);
}
// Convert from the accumulator type to FP32 for Softmax.
softmax
.
unpack
(
acc_p
);
float
s_mat
[
2
*
M
][
4
*
N
];
#pragma unroll
for
(
int
mi
=
0
;
mi
<
M
;
mi
++
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
uint4
&
dst
=
s_regs
[
mi
][
ni
];
fmha
::
half2_to_float2
(
s_mat
[
2
*
mi
+
0
][
4
*
ni
+
0
],
s_mat
[
2
*
mi
+
0
][
4
*
ni
+
1
],
dst
.
x
);
fmha
::
half2_to_float2
(
s_mat
[
2
*
mi
+
0
][
4
*
ni
+
2
],
s_mat
[
2
*
mi
+
0
][
4
*
ni
+
3
],
dst
.
y
);
fmha
::
half2_to_float2
(
s_mat
[
2
*
mi
+
1
][
4
*
ni
+
0
],
s_mat
[
2
*
mi
+
1
][
4
*
ni
+
1
],
dst
.
z
);
fmha
::
half2_to_float2
(
s_mat
[
2
*
mi
+
1
][
4
*
ni
+
2
],
s_mat
[
2
*
mi
+
1
][
4
*
ni
+
3
],
dst
.
w
);
}
}
#pragma unroll
for
(
int
mi
=
0
;
mi
<
M
;
mi
++
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
2
;
ii
++
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
#pragma unroll
for
(
int
jj
=
0
;
jj
<
4
;
jj
++
)
{
float
&
s_dmask
=
s_mat
[
2
*
mi
+
ii
][
4
*
ni
+
jj
];
const
bool
drop
=
reinterpret_cast
<
const
uint32_t
&>
(
s_dmask
)
&
0x80000000
;
const
float
d_s
=
drop
?
0.
f
:
softmax
.
elt_
[
2
*
mi
+
ii
][
4
*
ni
+
jj
]
*
params
.
rp_dropout
;
s_dmask
=
fabsf
(
s_dmask
);
softmax
.
elt_
[
2
*
mi
+
ii
][
4
*
ni
+
jj
]
=
d_s
*
fabsf
(
s_dmask
);
}
}
}
}
float
p_sum
[
2
*
M
];
softmax
.
template
reduce
<
fmha
::
Sum_
>(
p_sum
);
const
float
scalef
=
reinterpret_cast
<
const
float
&>
(
params
.
scale_softmax
);
#pragma unroll
for
(
int
mi
=
0
;
mi
<
M
;
mi
++
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
2
;
ii
++
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
#pragma unroll
for
(
int
jj
=
0
;
jj
<
4
;
jj
++
)
{
softmax
.
elt_
[
2
*
mi
+
ii
][
4
*
ni
+
jj
]
-=
p_sum
[
2
*
mi
+
ii
]
*
(
s_mat
[
2
*
mi
+
ii
][
4
*
ni
+
jj
])
;
softmax
.
elt_
[
2
*
mi
+
ii
][
4
*
ni
+
jj
]
*=
scalef
;
}
}
}
}
typename
Smem_tile_st
::
Fragment
frag_s
[
Mma_tile_dv
::
MMAS_K
][
Mma_tile_dv
::
MMAS_M
];
smem_s
.
load
(
frag_s
);
for
(
int
ki
=
0
;
ki
<
Mma_tile_dv
::
MMAS_K
;
ki
++
)
{
for
(
int
mi
=
0
;
mi
<
Mma_tile_dv
::
MMAS_M
;
mi
++
)
{
for
(
int
ii
=
0
;
ii
<
Smem_tile_st
::
Fragment
::
NUM_REGS
;
ii
++
)
{
frag_s
[
ki
][
mi
].
reg
(
ii
)
=
fmha
::
hmul2
(
frag_s
[
ki
][
mi
].
reg
(
ii
),
params
.
scale_dropout
);
frag_s
[
ki
][
mi
].
reg
(
ii
)
=
fmha
::
hrelu2
(
frag_s
[
ki
][
mi
].
reg
(
ii
));
}
}
}
gmem_s
.
store
(
softmax
.
elt_
,
mask
);
gmem_s
.
move
();
#pragma unroll
for
(
int
ki
=
1
;
ki
<
Mma_tile_dv
::
MMAS_K
;
++
ki
)
{
// Trigger the load from shared memory for the next series of Q values.
smem_qt
.
load
(
frag_qt
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_dv
::
MMAS_K
;
fmha
::
gemm
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
}
// Commit the values for Q into shared memory.
if
(
l
<
STEPS
-
1
)
{
gmem_q
.
commit
(
smem_q
);
}
// Make sure we are reading from the correct buffer.
smem_q
.
move_to_next_read_buffer
();
smem_qt
.
move_to_next_read_buffer
();
// Make sure the data is in shared memory.
__syncthreads
();
// Trigger the loads for the values of Q for the next iteration.
smem_q
.
load
(
frag_q
[
0
],
0
);
smem_k
.
load
(
frag_k
[
0
],
0
);
smem_qt
.
load
(
frag_qt
[
0
],
0
);
}
// Outer loop over the sequence length.
// Epilogue swizzle for dV
Smem_tile_dv
smem_dv
(
&
smem_
[
Kernel_traits
::
Smem_tile_q
::
BYTES_PER_TILE
],
tidx
);
smem_dv
.
store
(
acc_dv
);
__syncthreads
();
uint4
dv_out
[
Smem_tile_dv
::
NUM_LDS
];
smem_dv
.
load
(
dv_out
);
Qkv_params
dv_params
;
dv_params
.
qkv_ptr
=
params
.
dqkv_ptr
;
dv_params
.
qkv_stride_in_bytes
=
params
.
qkv_stride_in_bytes
;
dv_params
.
h
=
params
.
h
;
Gmem_tile_dv
gmem_dv
(
dv_params
,
2
,
binfo
,
tidx
);
gmem_dv
.
store
(
dv_out
);
}
template
<
typename
Kernel_traits
,
typename
Params
>
inline
__device__
void
compute_dq_dk_1xN
(
const
Params
&
params
)
{
// The description of the CTA tile for the 1st batched GEMM.
using
Cta_tile_p
=
typename
Kernel_traits
::
Cta_tile_p
;
using
Cta_tile_o
=
typename
Kernel_traits
::
Cta_tile_o
;
// The description of the CTA tile for the 2nd batched GEMM.
using
Cta_tile_dk
=
fmha
::
Cta_tile_extd
<
Cta_tile_p
::
N
,
Cta_tile_p
::
K
,
Cta_tile_p
::
M
,
Cta_tile_p
::
WARPS_N
,
1
,
Cta_tile_p
::
WARPS_M
>
;
static_assert
(
Cta_tile_dk
::
M
==
512
||
Cta_tile_dk
::
M
==
384
||
Cta_tile_dk
::
M
==
256
||
Cta_tile_dk
::
M
==
128
);
static_assert
(
Cta_tile_dk
::
N
==
64
);
static_assert
(
Cta_tile_dk
::
K
==
16
);
// The MMA tile for the 1st GEMM.
using
Mma_tile_p
=
fmha
::
Hmma_tile
<
Cta_tile_p
>
;
using
Mma_tile_o
=
fmha
::
Hmma_tile
<
Cta_tile_o
>
;
// The MMA tile for the 2nd GEMM.
using
Mma_tile_dk
=
fmha
::
Hmma_tile
<
Cta_tile_dk
>
;
// The global memory tile to load Q.
using
Gmem_tile_q
=
typename
Kernel_traits
::
Gmem_tile_q
;
// The shared memory tile to swizzle Q.
using
Smem_tile_q
=
typename
Kernel_traits
::
Smem_tile_q
;
// The global memory tile to load K.
using
Gmem_tile_k
=
typename
Kernel_traits
::
Gmem_tile_v
;
// The shared memory tile to swizzle K.
using
Smem_tile_k
=
typename
Kernel_traits
::
Smem_tile_v
;
// K is used like V in fprop
// The global memory tile to load V.
using
Gmem_tile_v
=
typename
Kernel_traits
::
Gmem_tile_v
;
// The shared memory tile to swizzle V.
using
Smem_tile_v
=
typename
Kernel_traits
::
Smem_tile_v
;
// The global memory tile to store O.
// using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
using
Gmem_tile_o
=
fmha
::
Gmem_tile_dq
<
Cta_tile_o
>
;
// The shared memory tile to swizzle O.
using
Smem_tile_o
=
typename
Kernel_traits
::
Smem_tile_o
;
// The global memory tile to store dK.
using
Gmem_tile_dk
=
typename
Kernel_traits
::
Gmem_tile_v
;
// The shared memory tile to swizzle dK.
using
Smem_tile_dk
=
fmha
::
Smem_tile_mma_epilogue
<
Cta_tile_dk
>
;
static_assert
(
Smem_tile_dk
::
NUM_LDS
==
Gmem_tile_dk
::
LDGS
);
static_assert
(
Smem_tile_dk
::
THREADS_PER_ROW
==
Gmem_tile_dk
::
THREADS_PER_ROW
);
// The shared memory tile to reload Q transposed.
using
Smem_tile_qt
=
fmha
::
Smem_tile_b
<
Cta_tile_dk
,
fmha
::
Row
,
Gmem_tile_q
::
BYTES_PER_LDG
,
1
>
;
using
Gmem_tile_s
=
typename
Kernel_traits
::
Gmem_tile_s
;
using
Smem_tile_st
=
typename
Kernel_traits
::
Smem_tile_st
;
enum
{
M
=
Mma_tile_p
::
MMAS_M
};
enum
{
N
=
Mma_tile_p
::
MMAS_N
};
static_assert
(
M
==
Mma_tile_o
::
MMAS_M
);
static_assert
(
N
==
Mma_tile_o
::
MMAS_K
);
// Shared memory.
extern
__shared__
char
smem_
[];
// The block index for the batch.
const
int
bidb
=
blockIdx
.
y
;
// The block index for the head.
const
int
bidh
=
blockIdx
.
x
;
// The thread index.
const
int
tidx
=
threadIdx
.
x
;
const
BlockInfoPadded
<
Kernel_traits
::
THREADS
>
binfo
(
params
,
bidb
,
bidh
,
tidx
);
if
(
binfo
.
stop_early
()
)
return
;
Mask
<
Cta_tile_p
>
mask
(
params
,
binfo
,
tidx
);
// Allocate the global memory tile loader for Q.
Gmem_tile_q
gmem_q
(
params
,
0
,
binfo
,
tidx
);
// Allocate the shared memory tile loader for Q.
Smem_tile_q
smem_q
(
&
smem_
[
0
],
tidx
);
Smem_tile_qt
smem_qt
(
&
smem_
[
0
],
tidx
);
Smem_tile_st
smem_s
(
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
+
Smem_tile_k
::
BYTES_PER_TILE
+
Smem_tile_o
::
BYTES_PER_TILE
],
tidx
);
// Allocate the global memory tile loader for K.
Gmem_tile_k
gmem_k
(
params
,
1
,
binfo
,
tidx
);
// Allocate the shared memory tile loader for K.
Smem_tile_k
smem_k
(
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
],
tidx
);
// Allocate the global memory tile loader for O.
Gmem_tile_o
gmem_o
(
params
,
binfo
,
tidx
);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o
smem_o
(
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
+
Smem_tile_k
::
BYTES_PER_TILE
],
tidx
);
// Trigger the loads for Q.
gmem_q
.
load
(
smem_q
);
// Trigger the loads for K.
gmem_k
.
load
(
smem_k
);
Gmem_tile_s
gmem_s
(
params
.
s_ptr
,
params
,
tidx
);
// Load dP
uint4
s_regs
[
M
][
N
];
gmem_s
.
load
(
s_regs
,
mask
);
gmem_s
.
move
();
// Commit the data for Q and K to shared memory.
gmem_q
.
commit
(
smem_q
);
gmem_k
.
commit
(
smem_k
);
// Make sure the data is in shared memory.
__syncthreads
();
typename
Smem_tile_qt
::
Fragment
frag_qt
[
2
][
Mma_tile_dk
::
MMAS_N
];
smem_qt
.
load
(
frag_qt
[
0
],
0
);
typename
Smem_tile_k
::
Fragment
frag_k
[
2
][
Mma_tile_o
::
MMAS_N
];
smem_k
.
load
(
frag_k
[
0
],
0
);
enum
{
BITS_PER_ELT_S
=
sizeof
(
fmha
::
A_type
)
*
8
};
enum
{
THREADS_PER_ROW
=
32
};
enum
{
STEPS
=
Cta_tile_p
::
N
/
Cta_tile_p
::
M
};
// Declare the accumulators for the 2nd gemm.
fmha
::
Fragment_accumulator
acc_dk
[
Mma_tile_dk
::
MMAS_M
][
Mma_tile_dk
::
MMAS_N
];
fmha
::
Clear_accumulator
<
fmha
::
Accumulator_type
,
Cta_tile_dk
::
WARPS_K
>::
apply
(
acc_dk
);
// Load over the entire sequence length.
for
(
int
l
=
0
;
l
<
STEPS
;
l
++
)
{
const
int
loop
=
l
*
Cta_tile_p
::
M
;
if
(
loop
>=
binfo
.
actual_seqlen
)
break
;
// Pack dP as Fragment_a
fmha
::
Fragment_a
<
fmha
::
Row
>
frag_p
[
Mma_tile_o
::
MMAS_K
][
Mma_tile_o
::
MMAS_M
];
#pragma unroll
for
(
int
mi
=
0
;
mi
<
M
;
mi
++
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
uint4
&
dst
=
s_regs
[
mi
][
ni
];
frag_p
[
ni
][
mi
].
reg
(
0
)
=
dst
.
x
;
// row 0, cols 0,1
frag_p
[
ni
][
mi
].
reg
(
1
)
=
dst
.
z
;
// row 8, cols 0,1
frag_p
[
ni
][
mi
].
reg
(
2
)
=
dst
.
y
;
// row 0, cols 8,9
frag_p
[
ni
][
mi
].
reg
(
3
)
=
dst
.
w
;
// row 8, cols 8,9
}
}
// Declare the accumulators for the 1st gemm.
fmha
::
Fragment_accumulator
acc_o
[
Mma_tile_o
::
MMAS_M
][
Mma_tile_o
::
MMAS_N
];
fmha
::
Clear_accumulator
<
fmha
::
Accumulator_type
,
Cta_tile_o
::
WARPS_K
>::
apply
(
acc_o
);
// Do this part of O = P^T * V^T. dQ = dP x dK
#pragma unroll
for
(
int
ki
=
1
;
ki
<
Mma_tile_o
::
MMAS_K
;
++
ki
)
{
// Trigger the load from shared memory for the next series of Q values.
smem_k
.
load
(
frag_k
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_o
,
frag_p
[
ki
-
1
],
frag_k
[(
ki
-
1
)
&
1
]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_o
::
MMAS_K
;
fmha
::
gemm
(
acc_o
,
frag_p
[
ki
-
1
],
frag_k
[(
ki
-
1
)
&
1
]);
}
// Store dP to smem for transpose
smem_s
.
store
(
s_regs
);
if
(
l
<
STEPS
-
1
)
{
// Load next part of S
gmem_s
.
load
(
s_regs
,
mask
);
gmem_s
.
move
();
smem_q
.
move_to_next_write_buffer
();
gmem_q
.
move
();
gmem_q
.
load
(
smem_q
);
}
// Loop over MMAS_M.
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Gmem_tile_o
::
LOOPS
;
++
ii
)
{
// Swizzle the elements and do the final reduction.
smem_o
.
store
(
acc_o
,
ii
);
// Make sure the data is in shared memory.
__syncthreads
();
// Load from shared memory.
uint4
out
[
Gmem_tile_o
::
STGS_PER_LOOP
];
smem_o
.
load
(
out
);
// Make sure the data was read from shared memory.
if
(
ii
<
Gmem_tile_o
::
LOOPS
-
1
)
{
__syncthreads
();
}
// Output the values.
gmem_o
.
store
(
out
,
ii
);
}
// Move to the next part of the output.
gmem_o
.
move
();
typename
Smem_tile_st
::
Fragment
frag_s
[
Mma_tile_dk
::
MMAS_K
][
Mma_tile_dk
::
MMAS_M
];
smem_s
.
load
(
frag_s
);
#pragma unroll
for
(
int
ki
=
1
;
ki
<
Mma_tile_dk
::
MMAS_K
;
++
ki
)
{
// Trigger the load from shared memory for the next series of Q values.
smem_qt
.
load
(
frag_qt
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_dk
,
frag_s
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_dk
::
MMAS_K
;
fmha
::
gemm
(
acc_dk
,
frag_s
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
}
// Commit the values for Q into shared memory.
if
(
l
<
STEPS
-
1
)
{
gmem_q
.
commit
(
smem_q
);
}
// Make sure the data is in shared memory.
__syncthreads
();
// Trigger the loads for the values of Q for the next iteration.
smem_qt
.
load
(
frag_qt
[
0
],
0
);
smem_k
.
load
(
frag_k
[
0
],
0
);
}
// Outer loop over the sequence length.
// Epilogue swizzle for dK
Smem_tile_dk
smem_dk
(
&
smem_
[
0
],
tidx
);
smem_dk
.
store
(
acc_dk
);
__syncthreads
();
uint4
dk_out
[
Smem_tile_dk
::
NUM_LDS
];
smem_dk
.
load
(
dk_out
);
Qkv_params
dk_params
;
dk_params
.
qkv_ptr
=
params
.
dqkv_ptr
;
dk_params
.
qkv_stride_in_bytes
=
params
.
qkv_stride_in_bytes
;
dk_params
.
h
=
params
.
h
;
Gmem_tile_dk
gmem_dk
(
dk_params
,
1
,
binfo
,
tidx
);
gmem_dk
.
store
(
dk_out
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace fmha
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h
0 → 100644
View file @
f79993d9
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace
fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
CHUNKS
,
typename
Kernel_traits
,
typename
Params
>
inline
__device__
void
compute_dv_1xN_nl
(
const
Params
&
params
)
{
// The description of the CTA tile for the 1st batched GEMM.
using
Cta_tile_p
=
typename
Kernel_traits
::
Cta_tile_p
;
// The description of the CTA tile for the 2nd batched GEMM.
using
Cta_tile_dv
=
fmha
::
Cta_tile_extd
<
Cta_tile_p
::
N
,
Cta_tile_p
::
K
,
Cta_tile_p
::
M
,
Cta_tile_p
::
WARPS_N
,
1
,
Cta_tile_p
::
WARPS_M
>
;
static_assert
(
Cta_tile_dv
::
M
==
512
||
Cta_tile_dv
::
M
==
384
||
Cta_tile_dv
::
M
==
256
||
Cta_tile_dv
::
M
==
128
);
static_assert
(
Cta_tile_dv
::
N
==
64
);
static_assert
(
Cta_tile_dv
::
K
==
16
);
// The MMA tile for the 1st GEMM.
using
Mma_tile_p
=
fmha
::
Hmma_tile
<
Cta_tile_p
>
;
// The MMA tile for the 2nd GEMM.
using
Mma_tile_dv
=
fmha
::
Hmma_tile
<
Cta_tile_dv
>
;
// The global memory tile to load Q.
using
Gmem_tile_q
=
typename
Kernel_traits
::
Gmem_tile_q
;
// The shared memory tile to swizzle Q.
using
Smem_tile_q
=
fmha
::
Smem_tile_a
<
Cta_tile_p
,
fmha
::
Row
,
Gmem_tile_q
::
BYTES_PER_LDG
,
2
>
;
// The shared memory tile to reload Q as fragment b.
using
Smem_tile_qt
=
fmha
::
Smem_tile_b
<
Cta_tile_dv
,
fmha
::
Row
,
Gmem_tile_q
::
BYTES_PER_LDG
,
2
>
;
// The global memory tile to load K.
using
Gmem_tile_k
=
typename
Kernel_traits
::
Gmem_tile_k
;
// The shared memory tile to swizzle K.
using
Smem_tile_k
=
typename
Kernel_traits
::
Smem_tile_k
;
// The global memory tile to load V.
using
Gmem_tile_v
=
typename
Kernel_traits
::
Gmem_tile_v
;
// The shared memory tile to swizzle V.
using
Smem_tile_v
=
typename
Kernel_traits
::
Smem_tile_v
;
// The global memory tile to store dV.
using
Gmem_tile_dv
=
fmha
::
Gmem_tile_qkv
<
typename
Kernel_traits
::
Cta_tile_o
,
fmha
::
BITS_PER_ELEMENT_B
,
Cta_tile_p
::
N
,
//S,
Cta_tile_p
::
K
,
//D,
2
*
CHUNKS
>
;
// The shared memory tile to swizzle dV.
using
Smem_tile_dv
=
fmha
::
Smem_tile_mma_epilogue
<
Cta_tile_dv
>
;
static_assert
(
Smem_tile_dv
::
NUM_LDS
==
Gmem_tile_dv
::
LDGS
);
static_assert
(
Smem_tile_dv
::
THREADS_PER_ROW
==
Gmem_tile_dv
::
THREADS_PER_ROW
);
using
Gmem_tile_s
=
typename
Kernel_traits
::
Gmem_tile_s
;
using
Smem_tile_st
=
typename
Kernel_traits
::
Smem_tile_st
;
using
Gmem_tile_do
=
typename
Kernel_traits
::
Gmem_tile_do
;
// Shared memory.
extern
__shared__
char
smem_
[];
// The block index for the chunk.
const
int
bidc
=
blockIdx
.
z
;
// The block index for the batch.
const
int
bidb
=
blockIdx
.
y
;
// The block index for the head.
const
int
bidh
=
blockIdx
.
x
;
// The thread index.
const
int
tidx
=
threadIdx
.
x
;
const
BlockInfoPadded
<
Kernel_traits
::
THREADS
>
binfo
(
params
,
bidb
,
bidh
,
tidx
);
if
(
binfo
.
stop_early
()
)
return
;
fmha
::
Mask
<
Cta_tile_p
>
mask
(
params
,
binfo
,
tidx
);
// Allocate the global memory tile loader for Q.
Gmem_tile_do
gmem_q
(
params
,
binfo
,
tidx
);
// treating dout as Q
// Allocate the shared memory tile loader for Q.
Smem_tile_q
smem_q
(
&
smem_
[
0
],
tidx
);
Smem_tile_qt
smem_qt
(
&
smem_
[
0
],
tidx
);
Smem_tile_st
smem_s
(
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
+
Smem_tile_k
::
BYTES_PER_TILE
],
tidx
);
// Allocate the global memory tile loader for K.
Gmem_tile_k
gmem_k
(
params
,
2
,
binfo
,
tidx
);
// treating V as K
// Allocate the shared memory tile loader for K.
Smem_tile_k
smem_k
(
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
],
tidx
);
Gmem_tile_s
gmem_s
(
params
.
s_ptr
,
params
,
tidx
);
using
Noloop
=
Noloop_traits
<
CHUNKS
,
Cta_tile_p
>
;
Noloop
nl_traits
(
bidc
);
nl_traits
.
move_all
(
gmem_q
,
gmem_s
);
// Trigger the loads for Q.
gmem_q
.
load
(
smem_q
);
// Trigger the loads for K.
gmem_k
.
load
(
smem_k
);
// Commit the data for Q and K to shared memory.
gmem_q
.
commit
(
smem_q
);
gmem_k
.
commit
(
smem_k
);
// Make sure the data is in shared memory.
__syncthreads
();
// Load the fragments for Q.
typename
Smem_tile_q
::
Fragment
frag_q
[
2
][
Mma_tile_p
::
MMAS_M
];
smem_q
.
load
(
frag_q
[
0
],
0
);
typename
Smem_tile_qt
::
Fragment
frag_qt
[
2
][
Mma_tile_dv
::
MMAS_N
];
static_assert
(
Smem_tile_qt
::
Fragment
::
NUM_REGS
==
4
);
static_assert
(
Mma_tile_dv
::
MMAS_K
==
1
);
smem_qt
.
load
(
frag_qt
[
0
],
0
);
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename
Smem_tile_k
::
Fragment
frag_k
[
2
][
Mma_tile_p
::
MMAS_N
];
smem_k
.
load
(
frag_k
[
0
],
0
);
enum
{
BITS_PER_ELT_S
=
sizeof
(
fmha
::
A_type
)
*
8
};
// Create the object to do the softmax.
using
Softmax
=
fmha
::
Softmax
<
Cta_tile_p
,
Kernel_traits
>
;
Softmax
softmax
(
params
,
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
+
Smem_tile_st
::
BYTES_PER_TILE
+
Smem_tile_k
::
BYTES_PER_TILE
],
bidb
,
tidx
);
enum
{
THREADS_PER_ROW
=
32
};
enum
{
M
=
Mma_tile_p
::
MMAS_M
};
enum
{
N
=
Mma_tile_p
::
MMAS_N
};
// Declare the accumulators for the 2nd gemm.
fmha
::
Fragment_accumulator
acc_dv
[
Mma_tile_dv
::
MMAS_M
][
Mma_tile_dv
::
MMAS_N
];
fmha
::
Clear_accumulator
<
fmha
::
Accumulator_type
,
Cta_tile_dv
::
WARPS_K
>::
apply
(
acc_dv
);
// Load over the entire sequence length.
for
(
int
l
=
0
;
l
<
nl_traits
.
num_steps_
;
l
++
)
{
const
int
loop
=
nl_traits
.
offset_loop_count
(
l
);
if
(
loop
>=
binfo
.
actual_seqlen
)
break
;
uint4
s_regs
[
M
][
N
];
gmem_s
.
load
(
s_regs
,
mask
);
fmha
::
Fragment_accumulator
acc_p
[
Mma_tile_p
::
MMAS_M
][
Mma_tile_p
::
MMAS_N
];
fmha
::
Clear_accumulator
<
fmha
::
Accumulator_type
,
Cta_tile_p
::
WARPS_K
>::
apply
(
acc_p
);
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for
(
int
ki
=
1
;
ki
<
Mma_tile_p
::
MMAS_K
;
++
ki
)
{
// Trigger the load from shared memory for the next series of Q values.
smem_q
.
load
(
frag_q
[
ki
&
1
],
ki
);
smem_k
.
load
(
frag_k
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_p
,
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)
&
1
]);
}
smem_s
.
store
(
s_regs
);
// Declare the accumulators for the 1st gemm.
// Do the final stage of math.
{
int
ki
=
Mma_tile_p
::
MMAS_K
;
fmha
::
gemm
(
acc_p
,
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)
&
1
]);
}
// Trigger the load for the next Q values. We're using double buffering, so reading qt is safe
if
(
l
<
nl_traits
.
num_steps_
-
1
)
{
smem_q
.
move_to_next_write_buffer
();
gmem_q
.
move
();
gmem_q
.
load
(
smem_q
);
}
// Convert from the accumulator type to FP32 for Softmax.
softmax
.
unpack
(
acc_p
);
float
s_mat
[
2
*
M
][
4
*
N
];
#pragma unroll
for
(
int
mi
=
0
;
mi
<
M
;
mi
++
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
uint4
&
dst
=
s_regs
[
mi
][
ni
];
fmha
::
half2_to_float2
(
s_mat
[
2
*
mi
+
0
][
4
*
ni
+
0
],
s_mat
[
2
*
mi
+
0
][
4
*
ni
+
1
],
dst
.
x
);
fmha
::
half2_to_float2
(
s_mat
[
2
*
mi
+
0
][
4
*
ni
+
2
],
s_mat
[
2
*
mi
+
0
][
4
*
ni
+
3
],
dst
.
y
);
fmha
::
half2_to_float2
(
s_mat
[
2
*
mi
+
1
][
4
*
ni
+
0
],
s_mat
[
2
*
mi
+
1
][
4
*
ni
+
1
],
dst
.
z
);
fmha
::
half2_to_float2
(
s_mat
[
2
*
mi
+
1
][
4
*
ni
+
2
],
s_mat
[
2
*
mi
+
1
][
4
*
ni
+
3
],
dst
.
w
);
}
}
#pragma unroll
for
(
int
mi
=
0
;
mi
<
M
;
mi
++
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
2
;
ii
++
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
#pragma unroll
for
(
int
jj
=
0
;
jj
<
4
;
jj
++
)
{
float
&
s_dmask
=
s_mat
[
2
*
mi
+
ii
][
4
*
ni
+
jj
];
const
bool
drop
=
reinterpret_cast
<
const
uint32_t
&>
(
s_dmask
)
&
0x80000000
;
const
float
d_s
=
drop
?
0.
f
:
softmax
.
elt_
[
2
*
mi
+
ii
][
4
*
ni
+
jj
]
*
params
.
rp_dropout
;
s_dmask
=
fabsf
(
s_dmask
);
softmax
.
elt_
[
2
*
mi
+
ii
][
4
*
ni
+
jj
]
=
d_s
*
(
s_dmask
);
}
}
}
}
float
p_sum
[
2
*
M
];
softmax
.
template
reduce
<
fmha
::
Sum_
>(
p_sum
);
const
float
scalef
=
reinterpret_cast
<
const
float
&>
(
params
.
scale_softmax
);
#pragma unroll
for
(
int
mi
=
0
;
mi
<
M
;
mi
++
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
2
;
ii
++
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
#pragma unroll
for
(
int
jj
=
0
;
jj
<
4
;
jj
++
)
{
softmax
.
elt_
[
2
*
mi
+
ii
][
4
*
ni
+
jj
]
-=
p_sum
[
2
*
mi
+
ii
]
*
(
s_mat
[
2
*
mi
+
ii
][
4
*
ni
+
jj
])
;
softmax
.
elt_
[
2
*
mi
+
ii
][
4
*
ni
+
jj
]
*=
scalef
;
}
}
}
}
typename
Smem_tile_st
::
Fragment
frag_s
[
Mma_tile_dv
::
MMAS_K
][
Mma_tile_dv
::
MMAS_M
];
smem_s
.
load
(
frag_s
);
for
(
int
ki
=
0
;
ki
<
Mma_tile_dv
::
MMAS_K
;
ki
++
)
{
for
(
int
mi
=
0
;
mi
<
Mma_tile_dv
::
MMAS_M
;
mi
++
)
{
for
(
int
ii
=
0
;
ii
<
Smem_tile_st
::
Fragment
::
NUM_REGS
;
ii
++
)
{
frag_s
[
ki
][
mi
].
reg
(
ii
)
=
fmha
::
hmul2
(
frag_s
[
ki
][
mi
].
reg
(
ii
),
params
.
scale_dropout
);
frag_s
[
ki
][
mi
].
reg
(
ii
)
=
fmha
::
hrelu2
(
frag_s
[
ki
][
mi
].
reg
(
ii
));
}
}
}
gmem_s
.
store
(
softmax
.
elt_
,
mask
);
gmem_s
.
move
();
static_assert
(
Mma_tile_dv
::
MMAS_K
==
1
);
// DEBUG
#pragma unroll
for
(
int
ki
=
1
;
ki
<
Mma_tile_dv
::
MMAS_K
;
++
ki
)
{
// Trigger the load from shared memory for the next series of Q values.
smem_qt
.
load
(
frag_qt
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_dv
::
MMAS_K
;
fmha
::
gemm
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
}
// Commit the values for Q into shared memory.
if
(
l
<
nl_traits
.
num_steps_
-
1
)
{
gmem_q
.
commit
(
smem_q
);
}
// Make sure we are reading from the correct buffer.
smem_q
.
move_to_next_read_buffer
();
smem_qt
.
move_to_next_read_buffer
();
// Make sure the data is in shared memory.
__syncthreads
();
// Trigger the loads for the values of Q for the next iteration.
smem_q
.
load
(
frag_q
[
0
],
0
);
smem_k
.
load
(
frag_k
[
0
],
0
);
smem_qt
.
load
(
frag_qt
[
0
],
0
);
}
// Outer loop over the sequence length.
// Epilogue for dV = (S * D)' * dout'. We're fully exposed to this!
// Epilogue swizzle for dV
Smem_tile_dv
smem_dv
(
&
smem_
[
Kernel_traits
::
Smem_tile_q
::
BYTES_PER_TILE
],
tidx
);
smem_dv
.
store
(
acc_dv
);
__syncthreads
();
uint4
dv_out
[
Smem_tile_dv
::
NUM_LDS
];
smem_dv
.
load
(
dv_out
);
Qkv_params
dv_params
;
dv_params
.
qkv_ptr
=
params
.
dkv_ptr
;
dv_params
.
qkv_stride_in_bytes
=
params
.
h
*
2
*
CHUNKS
*
params
.
d
*
sizeof
(
half
);
dv_params
.
h
=
params
.
h
;
Gmem_tile_dv
gmem_dv
(
dv_params
,
nl_traits
.
get_idx_dv
(),
binfo
,
tidx
);
gmem_dv
.
store
(
dv_out
);
}
template
<
int
CHUNKS
,
typename
Kernel_traits
,
typename
Params
>
inline
__device__
void
compute_dq_dk_1xN_nl
(
const
Params
&
params
)
{
// The description of the CTA tile for the 1st batched GEMM.
using
Cta_tile_p
=
typename
Kernel_traits
::
Cta_tile_p
;
using
Cta_tile_o
=
typename
Kernel_traits
::
Cta_tile_o
;
// The description of the CTA tile for the 2nd batched GEMM.
using
Cta_tile_dk
=
fmha
::
Cta_tile_extd
<
Cta_tile_p
::
N
,
Cta_tile_p
::
K
,
Cta_tile_p
::
M
,
Cta_tile_p
::
WARPS_N
,
1
,
Cta_tile_p
::
WARPS_M
>
;
static_assert
(
Cta_tile_dk
::
M
==
512
||
Cta_tile_dk
::
M
==
384
||
Cta_tile_dk
::
M
==
256
||
Cta_tile_dk
::
M
==
128
);
static_assert
(
Cta_tile_dk
::
N
==
64
);
static_assert
(
Cta_tile_dk
::
K
==
16
);
// The MMA tile for the 1st GEMM.
using
Mma_tile_p
=
fmha
::
Hmma_tile
<
Cta_tile_p
>
;
using
Mma_tile_o
=
fmha
::
Hmma_tile
<
Cta_tile_o
>
;
// The MMA tile for the 2nd GEMM.
using
Mma_tile_dk
=
fmha
::
Hmma_tile
<
Cta_tile_dk
>
;
// The global memory tile to load Q.
using
Gmem_tile_q
=
typename
Kernel_traits
::
Gmem_tile_q
;
// The shared memory tile to swizzle Q.
using
Smem_tile_q
=
typename
Kernel_traits
::
Smem_tile_q
;
// The global memory tile to load K.
using
Gmem_tile_k
=
typename
Kernel_traits
::
Gmem_tile_v
;
// The shared memory tile to swizzle K.
using
Smem_tile_k
=
typename
Kernel_traits
::
Smem_tile_v
;
// K is used like V in fprop
// The global memory tile to load V.
using
Gmem_tile_v
=
typename
Kernel_traits
::
Gmem_tile_v
;
// The shared memory tile to swizzle V.
using
Smem_tile_v
=
typename
Kernel_traits
::
Smem_tile_v
;
// The global memory tile to store O.
using
Gmem_tile_o
=
Gmem_tile_dq
<
Cta_tile_o
>
;
// The shared memory tile to swizzle O.
using
Smem_tile_o
=
typename
Kernel_traits
::
Smem_tile_o
;
// The global memory tile to store dK.
using
Gmem_tile_dk
=
fmha
::
Gmem_tile_qkv
<
typename
Kernel_traits
::
Cta_tile_o
,
fmha
::
BITS_PER_ELEMENT_B
,
Cta_tile_p
::
N
,
//S,
Cta_tile_p
::
K
,
//D,
2
*
CHUNKS
>
;
// The shared memory tile to swizzle dK.
using
Smem_tile_dk
=
fmha
::
Smem_tile_mma_epilogue
<
Cta_tile_dk
>
;
static_assert
(
Smem_tile_dk
::
NUM_LDS
==
Gmem_tile_dk
::
LDGS
);
static_assert
(
Smem_tile_dk
::
THREADS_PER_ROW
==
Gmem_tile_dk
::
THREADS_PER_ROW
);
// The shared memory tile to reload Q transposed.
using
Smem_tile_qt
=
fmha
::
Smem_tile_b
<
Cta_tile_dk
,
fmha
::
Row
,
Gmem_tile_q
::
BYTES_PER_LDG
,
1
>
;
// The global memory tile to load dP, stored in S
using
Gmem_tile_s
=
Gmem_tile_mma_s
<
Cta_tile_p
>
;
// The shared memory tile to transpose dP.
using
Smem_tile_st
=
Smem_tile_mma_transposed
<
Cta_tile_p
>
;
using
Noloop
=
Noloop_traits
<
CHUNKS
,
Cta_tile_p
>
;
enum
{
M
=
Mma_tile_p
::
MMAS_M
};
enum
{
N
=
Mma_tile_p
::
MMAS_N
};
static_assert
(
M
==
Mma_tile_o
::
MMAS_M
);
static_assert
(
N
==
Mma_tile_o
::
MMAS_K
);
// Shared memory.
extern
__shared__
char
smem_
[];
const
int
bidc
=
blockIdx
.
z
;
// The block index for the batch.
const
int
bidb
=
blockIdx
.
y
;
// The block index for the head.
const
int
bidh
=
blockIdx
.
x
;
// The thread index.
const
int
tidx
=
threadIdx
.
x
;
const
BlockInfoPadded
<
Kernel_traits
::
THREADS
>
binfo
(
params
,
bidb
,
bidh
,
tidx
);
if
(
binfo
.
stop_early
()
)
return
;
fmha
::
Mask
<
Cta_tile_p
>
mask
(
params
,
binfo
,
tidx
);
// Allocate the global memory tile loader for Q.
Gmem_tile_q
gmem_q
(
params
,
0
,
binfo
,
tidx
);
// Allocate the shared memory tile loader for Q (as B).
Smem_tile_qt
smem_qt
(
&
smem_
[
0
],
tidx
);
// Allocate the global memory tile loader for dP.
Gmem_tile_s
gmem_s
(
params
.
s_ptr
,
params
,
tidx
);
// Allocate the shared memory tile loader for dP.
Smem_tile_st
smem_s
(
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
+
Smem_tile_k
::
BYTES_PER_TILE
+
Smem_tile_o
::
BYTES_PER_TILE
],
tidx
);
// Allocate the global memory tile loader for K.
Gmem_tile_k
gmem_k
(
params
,
1
,
binfo
,
tidx
);
// Allocate the shared memory tile loader for K.
Smem_tile_k
smem_k
(
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
],
tidx
);
// Allocate the global memory tile loader for O.
Gmem_tile_o
gmem_o
(
params
,
binfo
,
tidx
);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o
smem_o
(
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
+
Smem_tile_k
::
BYTES_PER_TILE
],
tidx
);
Noloop
nl_traits
(
bidc
);
nl_traits
.
move_all
(
gmem_q
,
gmem_o
,
gmem_s
);
// Trigger the loads for Q.
gmem_q
.
load
(
smem_qt
);
// Trigger the loads for K.
gmem_k
.
load
(
smem_k
);
uint4
s_regs
[
M
][
N
];
gmem_s
.
load
(
s_regs
,
mask
);
// Commit the data for Q and K to shared memory.
gmem_q
.
commit
(
smem_qt
);
gmem_k
.
commit
(
smem_k
);
// Make sure the data is in shared memory.
__syncthreads
();
typename
Smem_tile_qt
::
Fragment
frag_qt
[
2
][
Mma_tile_dk
::
MMAS_N
];
smem_qt
.
load
(
frag_qt
[
0
],
0
);
typename
Smem_tile_k
::
Fragment
frag_k
[
2
][
Mma_tile_o
::
MMAS_N
];
smem_k
.
load
(
frag_k
[
0
],
0
);
enum
{
BITS_PER_ELT_S
=
sizeof
(
fmha
::
A_type
)
*
8
};
enum
{
THREADS_PER_ROW
=
32
};
// Declare the accumulators for the 2nd gemm.
fmha
::
Fragment_accumulator
acc_dk
[
Mma_tile_dk
::
MMAS_M
][
Mma_tile_dk
::
MMAS_N
];
    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk);

    // Load over the entire sequence length.
    for( int l = 0; l < nl_traits.num_steps_; l++ ) {

        // Pack dP as Fragment_a
        fmha::Fragment_a<fmha::Row> frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        #pragma unroll
        for( int mi = 0; mi < M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < N; ni++ ) {
                uint4 &dst = s_regs[mi][ni];
                frag_p[ni][mi].reg(0) = dst.x;
                frag_p[ni][mi].reg(1) = dst.z;
                frag_p[ni][mi].reg(2) = dst.y;
                frag_p[ni][mi].reg(3) = dst.w;
            }
        }

        smem_s.store(s_regs);

        if( l < nl_traits.num_steps_ - 1 ) {
            // Load next part of S
            gmem_s.move();
            gmem_s.load(s_regs, mask);
            // Trigger the load for the next Q values.
            smem_qt.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load(smem_qt);
        }

        // Declare the accumulators for the 2nd gemm.
        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
        fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);

        // Do this part of O = P^T * V^T. dQ = dP x dK
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of K values.
            smem_k.load(frag_k[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_o::MMAS_K;
            fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
        }

        static_assert(Gmem_tile_o::LOOPS == 1);  //DEBUG

        // Loop over MMAS_M.
        #pragma unroll
        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
            // Swizzle the elements and do the final reduction.
            smem_o.store(acc_o, ii);
            // Make sure the data is in shared memory.
            __syncthreads();
            // Load from shared memory.
            uint4 out[Gmem_tile_o::STGS_PER_LOOP];
            smem_o.load(out);
            // Make sure the data was read from shared memory.
            if( ii < Gmem_tile_o::LOOPS - 1 ) {
                __syncthreads();
            }
            // Output the values.
            gmem_o.store(out, ii);
        }

        // Move to the next part of the output.
        gmem_o.move();

        typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];
        smem_s.load(frag_s);

        static_assert(Mma_tile_dk::MMAS_K == 1);  // DEBUG
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_qt.load(frag_qt[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_dk::MMAS_K;
            fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
        }

        // Commit the values for Q into shared memory.
        if( l < nl_traits.num_steps_ - 1 ) {
            gmem_q.commit(smem_qt);
            __syncthreads();
            // Trigger the loads for the values of Q for the next iteration.
            smem_qt.load(frag_qt[0], 0);
            smem_k.load(frag_k[0], 0);
        }

    }  // Outer loop over the sequence length.

    // Epilogue for dK = dP' * dq. We're fully exposed to this!

    // Epilogue swizzle for dK
    Smem_tile_dk smem_dk(&smem_[0], tidx);
    smem_dk.store(acc_dk);
    __syncthreads();
    uint4 dk_out[Smem_tile_dk::NUM_LDS];
    smem_dk.load(dk_out);

    Qkv_params dk_params;
    dk_params.qkv_ptr = params.dkv_ptr;
    dk_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half);
    dk_params.h = params.h;

    Gmem_tile_dk gmem_dk(dk_params, nl_traits.get_idx_dk(), binfo, tidx);
    gmem_dk.store(dk_out);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
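Editor's note on the epilogue above: each of the CHUNKS CTAs along blockIdx.z writes its own, non-interleaved partial dK/dV slice, which is why dk_params.qkv_stride_in_bytes is params.h * 2 * CHUNKS * params.d * sizeof(half); a later reduction pass sums the chunks. A minimal host-side sketch of the size of that temporary buffer, assuming `total_s` stands for the sum of all sequence lengths (cu_seqlens[batch_size]); the function and parameter names are illustrative, not taken from this commit.

#include <cuda_fp16.h>
#include <cstddef>

// Size in bytes of the temporary per-chunk dK/dV buffer written by the noloop dgrad epilogue.
// Layout assumption (taken from the stride above): one row per token, each row holding CHUNKS
// non-interleaved (dK, dV) pairs of h heads with head dimension d, stored in fp16.
std::size_t dkv_tmp_bytes(int total_s, int h, int d, int chunks) {
    return static_cast<std::size_t>(total_s) * h * 2 * chunks * d * sizeof(half);
}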
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu
0 → 100644
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;

extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN<Kernel_traits, true>(params);
}

extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN<Kernel_traits, false>(params);
}

void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {

    auto kernel = is_training ? &fmha_fprop_fp16_128_64_sm80_train_kernel : &fmha_fprop_fp16_128_64_sm80_predict_kernel;

    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);

    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }

    dim3 grid(params.h, params.b);
    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
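The launchers in this and the following files all follow the same pattern: pick the train or predict instantiation, opt in to more than 48 KB of dynamic shared memory when the tile sizes require it, and launch one CTA per (head, batch) pair. A hedged call-site sketch follows, assuming `params` has already been fully populated by the caller (filling Fused_multihead_attention_fprop_params is outside this file); the wrapper name is illustrative only.

// Minimal call-site sketch for the 128x64 forward kernel; not part of the commit.
void launch_fmha_128_example(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
    const bool is_training = true;  // true also writes the S/dropout tensor consumed by the backward pass
    run_fmha_fp16_128_64_sm80(params, is_training, stream);
}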
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu
0 → 100644
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;

extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN<Kernel_traits, true>(params);
}

extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN<Kernel_traits, false>(params);
}

void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {

    auto kernel = is_training ? &fmha_fprop_fp16_256_64_sm80_train_kernel : &fmha_fprop_fp16_256_64_sm80_predict_kernel;

    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);

    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }

    dim3 grid(params.h, params.b);
    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu
0 → 100644
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN_reload_v.h"
using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x08u>;

extern "C" __global__ void fmha_fprop_fp16_384_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN<Kernel_traits, true>(params);
}

extern "C" __global__ void fmha_fprop_fp16_384_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN<Kernel_traits, false>(params);
}

void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {

    auto kernel = is_training ? &fmha_fprop_fp16_384_64_sm80_train_kernel : &fmha_fprop_fp16_384_64_sm80_predict_kernel;

    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

    constexpr int smem_size = smem_size_v + smem_size_o + smem_size_softmax;

    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }

    dim3 grid(params.h, params.b);
    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu
0 → 100644
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
#include "fmha_fprop_kernel_1xN_nl.h"
using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;

extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN<Kernel_traits, true>(params);
}

extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN<Kernel_traits, false>(params);
}

template<int CHUNKS>
__global__ void fmha_fprop_fp16_512_64_sm80_train_nl_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN_nl<CHUNKS, Kernel_traits, true>(params);
}

template<int CHUNKS>
__global__ void fmha_fprop_fp16_512_64_sm80_predict_nl_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN_nl<CHUNKS, Kernel_traits, false>(params);
}

void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {

    auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_kernel : &fmha_fprop_fp16_512_64_sm80_predict_kernel;

    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);

    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }

    dim3 grid(params.h, params.b);
    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}

void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream) {

    auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;

    if( num_chunks == 2 ) {
        kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
    } else if( num_chunks == 3 ) {
        kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<3> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<3>;
    } else if( num_chunks == 4 ) {
        kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<4> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<4>;
    } else {
        assert(false && "Unsupported num_chunks");
    }

    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);

    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }

    dim3 grid(params.h, params.b, num_chunks);
    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
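Only CHUNKS = 2, 3 and 4 are instantiated for the noloop path, and the chunk index is carried in blockIdx.z, so the grid grows to h x b x num_chunks. A hedged call-site sketch follows; how num_chunks is chosen (for example from the batch size and the number of SMs) is an assumption and not specified in this commit.

#include <cassert>

// Minimal sketch: validate num_chunks before dispatching to the noloop forward kernel.
void launch_fmha_512_nl_example(const Fused_multihead_attention_fprop_params &params,
                                bool is_training, int num_chunks, cudaStream_t stream) {
    assert(num_chunks >= 2 && num_chunks <= 4);  // only these template instantiations exist above
    run_fmha_fp16_512_64_sm80_nl(params, is_training, num_chunks, stream);
}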
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h
0 → 100644
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN(const Params &params) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
    // The shared memory tile to swizzle Q.
    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle K.
    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;

    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

    // The global memory tile to store O.
    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
    // The shared memory tile to swizzle O.
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    // Shared memory.
    extern __shared__ char smem_[];

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    if( binfo.stop_early() ) return;

    auto seeds = at::cuda::philox::unpack(params.philox_args);
    Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));

    Mask<Cta_tile_p> mask(params, binfo, tidx);

    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the shared memory tile loader for Q.
    Smem_tile_q smem_q(&smem_[0], tidx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the shared memory tile loader for K.
    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // The base pointer of smem_v;
    char *smem_v_ = nullptr;
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];
    } else {
        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
    }
    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);

    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params, binfo, tidx);
    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);

    // Trigger the loads for Q.
    gmem_q.load(smem_q);
    // Trigger the loads for K.
    gmem_k.load(smem_k);
    // Trigger the loads for V.
    gmem_v.load(smem_v);

    // Commit the data for Q and K to shared memory.
    gmem_q.commit(smem_q);
    gmem_k.commit(smem_k);

    // Commit the data for V to shared memory.
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        gmem_v.commit(smem_v);
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // Load the fragments for Q.
    typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
    smem_q.load(frag_q[0], 0);

    // Load the fragments for K. We keep the data in registers during the entire kernel.
    typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
        smem_k.load(frag_k[ki], ki);
    }

    // Commit the data for V to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();
        // Commit the data to shared memory for V.
        gmem_v.commit(smem_v);
        // Make sure the data is in shared memory.
        __syncthreads();
    }

    // Load the fragments for V. We keep the data in registers during the entire kernel.
    typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
        smem_v.load(frag_v[ki], ki);
    }

    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };

    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);

    // Create the object to do the softmax.
    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
    Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);

    enum { THREADS_PER_ROW = 32 };
    enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };

    // Load over the entire sequence length.
    for( int l = 0; l < STEPS; l++ ) {
        const int loop = l * Cta_tile_p::M;
        if( loop >= binfo.actual_seqlen )
            break;

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        // Do this part of P^T = (Q * K^T)^T.
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_q.load(frag_q[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_p::MMAS_K;
            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
        }

        // Load the mask for that iteration.
        mask.load(l);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack(acc_p);

        // Apply the mask.
        softmax.apply_mask(mask);

        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
            // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
            __syncthreads();
        }

        // Compute the max.
        float p_max[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Max_>(p_max);

        // Make sure we are done reading shared memory.
        __syncthreads();

        // Compute the exponential value.
        softmax.apply_exp(p_max);

        // Compute the sum.
        float p_sum[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Sum_>(p_sum);

        // Finalize softmax on the accumulators of P^T.
        softmax.scale(p_sum);

        if( Is_training ) {
            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < 2; ii++ ) {
                    #pragma unroll
                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                        float4 tmp = uniform4(ph());
                        // We encode the dropout pattern in the sign bit of the non-negative softmax to
                        // distinguish from pre-existing zeros.
                        softmax.elt_[2 * mi + ii][4 * ni + 0] =
                            encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
                        softmax.elt_[2 * mi + ii][4 * ni + 1] =
                            encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
                        softmax.elt_[2 * mi + ii][4 * ni + 2] =
                            encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
                        softmax.elt_[2 * mi + ii][4 * ni + 3] =
                            encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
                    }
                }
            }
            gmem_s.store(softmax.elt_, mask);
            gmem_s.move();
        }

        // Trigger the load for the next Q values.
        if( l < STEPS - 1 ) {
            smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load(smem_q);
        }

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        softmax.pack(frag_p);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
                    // "Apply" the dropout.
                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
                }
            }
        }

        // Declare the accumulators for the 2nd gemm.
        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);

        // Do this part of O = P^T * V^T.
        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
            fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
        }

        // Loop over MMAS_M.
        #pragma unroll
        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
            // Swizzle the elements and do the final reduction.
            smem_o.store(acc_o, ii);
            // Make sure the data is in shared memory.
            __syncthreads();
            // Load from shared memory.
            uint4 out[Gmem_tile_o::STGS_PER_LOOP];
            smem_o.load(out);
            // Make sure the data was read from shared memory.
            if( ii < Gmem_tile_o::LOOPS - 1 ) {
                __syncthreads();
            }
            // Output the values.
            gmem_o.store(out, ii);
        }

        // Move to the next part of the output.
        gmem_o.move();

        // Commit the values for Q into shared memory.
        if( l < STEPS - 1 ) {
            gmem_q.commit(smem_q);
        }

        // Make sure the data is in shared memory.
        __syncthreads();

        // Trigger the loads for the values of Q for the next iteration.
        smem_q.load(frag_q[0], 0);

    }  // Outer loop over the sequence length.
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
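The dropout handling above avoids storing a separate mask: dropped softmax probabilities are written with their sign bit flipped (the values are non-negative after the softmax, so the sign is free), and the consumer recovers the mask by rescaling and clamping negatives to zero, which is what the hmul2/hrelu2 pair does on the packed half2 registers. A scalar float illustration of the same idea follows, a sketch only: the kernel works on half2 lanes and `scale` here stands in for params.scale_dropout.

// Encode: keep -> +p, drop -> -p. Pre-existing zeros stay zero and are therefore unambiguous.
inline float encode_dropout_demo(bool keep, float p) { return keep ? p : -p; }

// Decode: rescale the kept values and turn the negated (dropped) ones into exact zeros.
inline float apply_dropout_demo(float encoded, float scale) {
    float scaled = encoded * scale;      // analogous to fmha::hmul2(reg, params.scale_dropout)
    return scaled > 0.f ? scaled : 0.f;  // analogous to fmha::hrelu2
}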
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_nl.h
0 → 100644
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int CHUNKS, typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN_nl(const Params &params) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
    // The shared memory tile to swizzle Q.
    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle K.
    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;

    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

    // The global memory tile to store O.
    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
    // The shared memory tile to swizzle O.
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

    // The global memory tile to store S/D.
    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;

    // Shared memory.
    extern __shared__ char smem_[];

    const int bidc = blockIdx.z;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    Noloop nl_traits(bidc);

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    if( binfo.stop_early() ) return;

    auto seeds = at::cuda::philox::unpack(params.philox_args);
    Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));

    fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);

    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the shared memory tile loader for Q.
    Smem_tile_q smem_q(&smem_[0], tidx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the shared memory tile loader for K.
    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // The base pointer of smem_v;
    char *smem_v_ = nullptr;
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];
    } else {
        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
    }
    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);

    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params, binfo, tidx);
    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);

    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);

    nl_traits.move_all(gmem_q, gmem_o, gmem_s);

    // Trigger the loads for Q.
    gmem_q.load(smem_q);
    // Trigger the loads for K.
    gmem_k.load(smem_k);
    // Trigger the loads for V.
    gmem_v.load(smem_v);

    // Commit the data for Q and K to shared memory.
    gmem_q.commit(smem_q);
    gmem_k.commit(smem_k);

    // Commit the data for V to shared memory.
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        gmem_v.commit(smem_v);
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // Load the fragments for Q.
    typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
    smem_q.load(frag_q[0], 0);

    // Load the fragments for K. We keep the data in registers during the entire kernel.
    typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
        smem_k.load(frag_k[ki], ki);
    }

    // Commit the data for V to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();
        // Commit the data to shared memory for V.
        gmem_v.commit(smem_v);
        // Make sure the data is in shared memory.
        __syncthreads();
    }

    // Load the fragments for V. We keep the data in registers during the entire kernel.
    typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
        smem_v.load(frag_v[ki], ki);
    }

    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };

    // Create the object to do the softmax.
    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
    Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);

    // The number of threads per row.
    enum { THREADS_PER_ROW = 32 };

    // Load over the entire sequence length.
    for( int l = 0; l < nl_traits.num_steps_; l++ ) {

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        // Do this part of P^T = (Q * K^T)^T.
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_q.load(frag_q[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_p::MMAS_K;
            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
        }

        // Trigger the load for the next Q values.
        if( l < nl_traits.num_steps_ - 1 ) {
            smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load(smem_q);
        }

        // Load the mask for that iteration.
        mask.load(nl_traits.loop_offset_ + l);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack(acc_p);

        // Apply the mask.
        softmax.apply_mask(mask);

        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
            // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
            __syncthreads();
        }

        // Compute the max.
        float p_max[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Max_>(p_max);

        // Make sure we are done reading shared memory.
        __syncthreads();

        // Compute the exponential value.
        softmax.apply_exp(p_max);

        // Compute the sum.
        float p_sum[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Sum_>(p_sum);

        // Finalize softmax on the accumulators of P^T.
        softmax.scale(p_sum);

        if( Is_training ) {
            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < 2; ii++ ) {
                    #pragma unroll
                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                        float4 tmp = uniform4(ph());
                        // We encode the dropout pattern in the sign bit of the non-negative softmax to
                        // distinguish from pre-existing zeros.
                        softmax.elt_[2 * mi + ii][4 * ni + 0] =
                            encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
                        softmax.elt_[2 * mi + ii][4 * ni + 1] =
                            encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
                        softmax.elt_[2 * mi + ii][4 * ni + 2] =
                            encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
                        softmax.elt_[2 * mi + ii][4 * ni + 3] =
                            encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
                    }
                }
            }
            gmem_s.store(softmax.elt_, mask);
            gmem_s.move();
        }

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        softmax.pack(frag_p);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
                    // "Apply" the dropout.
                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
                }
            }
        }

        // Declare the accumulators for the 2nd gemm.
        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);

        // Do this part of O = P^T * V^T.
        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
            fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
        }

        // Loop over MMAS_M.
        #pragma unroll
        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
            // Swizzle the elements and do the final reduction.
            smem_o.store(acc_o, ii);
            // Make sure the data is in shared memory.
            __syncthreads();
            // Load from shared memory.
            uint4 out[Gmem_tile_o::STGS_PER_LOOP];
            smem_o.load(out);
            // Make sure the data was read from shared memory.
            if( ii < Gmem_tile_o::LOOPS - 1 ) {
                __syncthreads();
            }
            // Output the values.
            gmem_o.store(out, ii);
        }

        // Move to the next part of the output.
        gmem_o.move();

        // Commit the values for Q into shared memory.
        if( l < nl_traits.num_steps_ - 1 ) {
            gmem_q.commit(smem_q);
            __syncthreads();
            smem_q.load(frag_q[0], 0);
        }

    }  // Outer loop over the sequence length.
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_reload_v.h
0 → 100644
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN(const Params &params) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
    // The shared memory tile to swizzle Q.
    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle K.
    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;

    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

    // The global memory tile to store O.
    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
    // The shared memory tile to swizzle O.
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    // Shared memory.
    extern __shared__ char smem_[];

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    if( binfo.stop_early() ) return;

    Mask<Cta_tile_p> mask(params, binfo, tidx);

    auto seeds = at::cuda::philox::unpack(params.philox_args);
    Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));

    static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the shared memory tile loader for K.
    Smem_tile_k smem_k(&smem_[0], tidx);

    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // The base pointer of smem_v;
    char *smem_v_ = nullptr;
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        smem_v_ = &smem_[0];
    } else {
        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
    }

    static_assert(Kernel_traits::SHARE_SMEM_FOR_K_AND_V);
    static_assert(Smem_tile_k::BYTES_PER_TILE == Smem_tile_v::BYTES_PER_TILE);

    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);

    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the shared memory tile loader for Q.
    Smem_tile_q smem_q(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params, binfo, tidx);
    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_o smem_o(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);

    // Trigger the loads for Q.
    gmem_q.load(smem_q);
    // Trigger the loads for K.
    gmem_k.load(smem_k);
    // Trigger the loads for V.
    gmem_v.load(smem_v);

    // Commit the data for Q and K to shared memory.
    gmem_q.commit(smem_q);
    gmem_k.commit(smem_k);

    // Commit the data for V to shared memory.
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        gmem_v.commit(smem_v);
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // Load the fragments for Q.
    typename Smem_tile_q::Fragment frag_q[1][Mma_tile_p::MMAS_M];

    // Load the fragments for K. We keep the data in registers during the entire kernel.
    typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
        smem_k.load(frag_k[ki], ki);
    }

    // Commit the data for V to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();
        // Commit the data to shared memory for V.
        gmem_v.commit(smem_v);
    }

    enum { BITS_PER_ELT_S = sizeof(typename fmha::A_type) * 8 };

    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);

    // Create the object to do the softmax.
    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
    Softmax softmax(params, &smem_[Smem_tile_v::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);

    constexpr int SMEM_BYTES_SOFTMAX = Softmax::ELEMENTS * sizeof(float);
    static_assert(SMEM_BYTES_SOFTMAX == Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float));

    enum { THREADS_PER_ROW = 32 };

    const float pinv = 1.f / params.p_dropout;

    // Load over the entire sequence length.
    for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
        if( loop >= binfo.actual_seqlen )
            break;

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_q.load(frag_q[0], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_p, frag_q[0], frag_k[ki]);
        }

        // Load the mask for that iteration.
        mask.load(outer);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack(acc_p);

        // Apply the mask.
        softmax.apply_mask(mask);

        static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);

        // Compute the max.
        float p_max[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Max_>(p_max);

        // Make sure we are done reading shared memory.
        __syncthreads();

        // Compute the exponential value.
        softmax.apply_exp(p_max);

        // Compute the sum.
        float p_sum[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Sum_>(p_sum);

        // Finalize softmax on the accumulators of P^T.
        softmax.scale(p_sum);

        __syncthreads();

        if( Is_training ) {
            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < 2; ii++ ) {
                    #pragma unroll
                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                        float4 tmp = uniform4(ph());
                        // We encode the dropout pattern in the sign bit of the non-negative softmax to
                        // distinguish from pre-existing zeros.
                        softmax.elt_[2 * mi + ii][4 * ni + 0] =
                            encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
                        softmax.elt_[2 * mi + ii][4 * ni + 1] =
                            encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
                        softmax.elt_[2 * mi + ii][4 * ni + 2] =
                            encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
                        softmax.elt_[2 * mi + ii][4 * ni + 3] =
                            encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
                    }
                }
            }
            gmem_s.store(softmax.elt_, mask);
            gmem_s.move();
        }

        // Trigger the load for the next Q values.
        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
            smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load(smem_q);
        }

        typename Smem_tile_v::Fragment frag_v[1][Mma_tile_o::MMAS_N];

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        softmax.pack(frag_p);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
                    // "Apply" the dropout.
                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
                }
            }
        }

        // Declare the accumulators for the 2nd gemm.
        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of V values.
            smem_v.load(frag_v[0], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_o, frag_p[ki], frag_v[0]);
        }

        // Loop over MMAS_M.
        #pragma unroll
        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
            // Swizzle the elements and do the final reduction.
            smem_o.store(acc_o, ii);
            // Make sure the data is in shared memory.
            __syncthreads();
            // Load from shared memory.
            uint4 out[Gmem_tile_o::STGS_PER_LOOP];
            smem_o.load(out);
            // Always sync after last iter: shared smem_q and smem_o!
            __syncthreads();
            // Output the values.
            gmem_o.store(out, ii);
        }
        // same smem as o

        // Move to the next part of the output.
        gmem_o.move();

        // Commit the values for Q into shared memory.
        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
            gmem_q.commit(smem_q);
        }

        // Make sure the data is in shared memory.
        __syncthreads();

    }  // Outer loop over the sequence length.
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
apex/contrib/csrc/fmha/src/fmha_kernel.h
0 → 100644
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <multihead_attn/philox.h>
#include <fmha.h>
#include <fmha/utils.h>
#include <fmha/smem_tile.h>
#include <fmha/gmem_tile.h>
#include <fmha/mask.h>
#include <fmha/softmax.h>
namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS_PER_CTA>
struct BlockInfoPadded {

    template<typename Params>
    __device__ BlockInfoPadded(const Params &params, const int bidb, const int bidh, const int tidx)
        : bidb(bidb), bidh(bidh), h(params.h) {
        // The block index.
        sum_s = params.cu_seqlens[bidb];
        actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s;
        bidx = sum_s * params.h + bidh;

        tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
    }

    __device__ bool stop_early() const {
        return actual_seqlen == 0;
    }

    int actual_seqlen;
    int bidx;
    int sum_s;
    int bidh;
    int bidb;
    int tidx_global;
    int h;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int CHUNKS, typename Cta_tile>
struct Noloop_traits {
    // Interpretation of Cta_tile dims, i.e. Cta_tile_p:
    enum { STEP = Cta_tile::M };
    enum { SEQLEN = Cta_tile::N };

    // The size of the subsequence this CTA is processing
    enum { SUBSEQ = SEQLEN / CHUNKS };
    static_assert(SUBSEQ * CHUNKS == SEQLEN);

    // The number of steps to process the subsequence
    enum { NUM_STEPS = SUBSEQ / STEP };
    static_assert(NUM_STEPS * Cta_tile::M == SUBSEQ);

    inline __device__ Noloop_traits(const int bidc)
        : loop_offset_(NUM_STEPS * bidc), bidc_(bidc) {
    }

    template<typename... Tiles>
    inline __device__ void move_all(Tiles &... tiles) const {
        using expand_type = int[];
        for( int s = 0; s < loop_offset_; s++ ) {
            expand_type{ (tiles.move(), 0)... };
        }
    }

    inline __device__ int get_idx_dk() const {
        //return bidc_;
        return bidc_ * 2 + 0;
    }

    inline __device__ int get_idx_dv() const {
        //return CHUNKS + bidc_;
        return bidc_ * 2 + 1;
    }

    inline __device__ int offset_loop_count(const int l) {
        // convert loop counter to position in the outer sequence
        return (loop_offset_ + l) * STEP;
    }

    const int loop_offset_;
    const uint32_t bidc_;
    const int num_steps_ = NUM_STEPS;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile>
struct Noloop_traits<3, Cta_tile> {
    // Interpretation of Cta_tile dims, i.e. Cta_tile_p:
    enum { STEP = Cta_tile::M };
    enum { SEQLEN = Cta_tile::N };

    static_assert(STEP == 16 && SEQLEN == 512);

    inline __device__ Noloop_traits(const int bidc)
        : bidc_(bidc), num_steps_(bidc < 2 ? 11 : 10), loop_offset_(bidc * 11) {
    }

    template<typename... Tiles>
    inline __device__ void move_all(Tiles &... tiles) const {
        using expand_type = int[];
        for( int s = 0; s < loop_offset_; s++ ) {
            expand_type{ (tiles.move(), 0)... };
        }
    }

    inline __device__ int get_idx_dk() const {
        //return bidc_;
        return bidc_ * 2 + 0;
    }

    inline __device__ int get_idx_dv() const {
        //return CHUNKS + bidc_;
        return bidc_ * 2 + 1;
    }

    inline __device__ int offset_loop_count(const int l) {
        // convert loop counter to position in the outer sequence
        return (loop_offset_ + l) * STEP;
    }

    const int loop_offset_;
    const uint32_t bidc_;
    const int num_steps_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
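The CHUNKS == 3 specialization cannot split the 512-token sequence evenly into 16-row steps (32 steps in total), so chunks 0 and 1 take 11 steps each and chunk 2 takes 10, with chunk c starting at step c * 11. A small host-side sanity check of that split, a sketch only and not part of the commit:

#include <cassert>

void check_noloop_chunks3_split() {
    const int steps_per_chunk[3] = {11, 11, 10};  // num_steps_ for bidc = 0, 1, 2
    int total_steps = 0;
    for( int bidc = 0; bidc < 3; ++bidc ) {
        const int loop_offset = bidc * 11;        // matches Noloop_traits<3>::loop_offset_
        assert(loop_offset == total_steps);       // chunks tile the sequence back to back
        total_steps += steps_per_chunk[bidc];
    }
    assert(total_steps * 16 == 512);              // 32 steps of 16 rows cover SEQLEN
}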
apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu
0 → 100644
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
inline __device__ float4 ldg128(const void *ptr) {
    return *static_cast<const float4 *>(ptr);
}

inline __device__ void stg128(void *ptr, const float4 &data) {
    *static_cast<float4 *>(ptr) = data;
}

template<typename T, int THREADS, int HIDDEN_SIZE, int CHUNKS>
__global__ __launch_bounds__(THREADS) void fmha_noloop_reduce_kernel(void *__restrict__ out,
                                                                     const void *__restrict__ in,
                                                                     const int *__restrict__ cu_seqlens,
                                                                     const int batch_size) {

    enum { BYTES_PER_LDG = 16 };
    enum { NUM_ELTS = BYTES_PER_LDG / sizeof(T) };

    // One CTA hidden vector for K and V
    enum { BYTES_PER_ROW = HIDDEN_SIZE * sizeof(T) * 2 };

    // The stride in bytes in dQKV
    enum { OUT_STRIDE_BYTES = 3 * HIDDEN_SIZE * sizeof(T) };

    // The offset in bytes in dQKV to the dKV part for non-interleaved heads
    enum { OUT_OFFSET_KV_BYTES = HIDDEN_SIZE * sizeof(T) };

    static_assert(BYTES_PER_ROW == HIDDEN_SIZE * 2 * sizeof(T));

    // Size in bytes of the input tile
    enum { BYTES_PER_TILE = CHUNKS * BYTES_PER_ROW };

    enum { BYTES_PER_CTA = THREADS * BYTES_PER_LDG };
    enum { LDGS = BYTES_PER_ROW / BYTES_PER_CTA };
    static_assert(BYTES_PER_CTA * LDGS == BYTES_PER_ROW);

    union Vec_t {
        float4 raw;
        T elt[NUM_ELTS];
    };

    // ZERO-OUT invalid positions in dQKV
    const int total = cu_seqlens[batch_size];
    if( blockIdx.x >= total ) {
        enum { BYTES_PER_QKV_ROW = 3 * HIDDEN_SIZE * sizeof(T) };
        enum { STGS = BYTES_PER_QKV_ROW / BYTES_PER_LDG };
        const float4 zeros = make_float4(0.f, 0.f, 0.f, 0.f);
        char *base_ptr = static_cast<char *>(out) + blockIdx.x * OUT_STRIDE_BYTES;
        for( int tidx = threadIdx.x; tidx < STGS; tidx += THREADS ) {
            stg128(base_ptr + tidx * BYTES_PER_LDG, zeros);
        }
        return;
    }

    // SETUP
    const int offset_in = blockIdx.x * BYTES_PER_TILE + threadIdx.x * BYTES_PER_LDG;
    const char *ptr_in = static_cast<const char *>(in) + offset_in;

    const int offset_out = blockIdx.x * OUT_STRIDE_BYTES + threadIdx.x * BYTES_PER_LDG;
    char *ptr_out = static_cast<char *>(out) + OUT_OFFSET_KV_BYTES + offset_out;

    // LOAD
    Vec_t local_in[CHUNKS][LDGS];
    #pragma unroll
    for( int c = 0; c < CHUNKS; c++ ) {
        #pragma unroll
        for( int l = 0; l < LDGS; l++ ) {
            int offset = c * BYTES_PER_ROW + l * BYTES_PER_CTA;
            local_in[c][l].raw = ldg128(ptr_in + offset);
        }
    }

    // UNPACK
    float acc[LDGS][NUM_ELTS];
    #pragma unroll
    for( int l = 0; l < LDGS; l++ ) {
        #pragma unroll
        for( int e = 0; e < NUM_ELTS; e++ ) {
            acc[l][e] = float(local_in[0][l].elt[e]);
        }
    }

    // COMPUTE
    #pragma unroll
    for( int c = 1; c < CHUNKS; c++ ) {
        #pragma unroll
        for( int l = 0; l < LDGS; l++ ) {
            #pragma unroll
            for( int e = 0; e < NUM_ELTS; e++ ) {
                acc[l][e] += float(local_in[c][l].elt[e]);
            }
        }
    }

    // PACK
    Vec_t local_out[LDGS];
    #pragma unroll
    for( int l = 0; l < LDGS; l++ ) {
        #pragma unroll
        for( int e = 0; e < NUM_ELTS; e++ ) {
            local_out[l].elt[e] = T(acc[l][e]);
        }
    }

    // STORE
    #pragma unroll
    for( int l = 0; l < LDGS; l++ ) {
        const int offset = l * BYTES_PER_CTA;
        stg128(ptr_out + offset, local_out[l].raw);
    }
}
void fmha_run_noloop_reduce(void *out,
                            const void *in,
                            const int *cu_seqlens,
                            const int hidden_size,
                            const int batch_size,
                            const int total,
                            const int num_chunks,
                            cudaStream_t stream) {

    const int blocks = total;

    if( hidden_size == 1024 ) {
        constexpr int HIDDEN_SIZE = 1024;
        constexpr int THREADS = 256;
        if( num_chunks == 2 ) {
            fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 2>
                <<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);
        } else if( num_chunks == 3 ) {
            fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 3>
                <<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);
        } else {
            assert(false && "Unsupported num_chunks");
        }
    } else {
        assert(false && "Unsupported hidden_size");
    }

    FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
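// Added usage sketch (not part of the original commit; buffer names and the
// scratch layout are assumptions inferred from the kernel above): the no-loop
// backward pass writes num_chunks partial dK/dV tiles per row into a scratch
// buffer of [rows, num_chunks, 2, hidden_size] halves, and this launcher folds
// them into the dK/dV slots of the packed [rows, 3, hidden_size] dQKV gradient,
// zeroing any rows past cu_seqlens[batch_size]:
//
//   fmha_run_noloop_reduce(dqkv,          // fp16 dQKV gradient, dQ already written
//                          dkv_scratch,   // fp16 per-chunk partial dK/dV
//                          cu_seqlens,    // [batch_size + 1] prefix sums on the device
//                          /*hidden_size=*/1024,
//                          batch_size,
//                          /*total=*/rows,   // one CTA per row of dqkv
//                          /*num_chunks=*/3,
//                          stream);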
apex/contrib/csrc/fmha/src/fmha_utils.h 0 → 100644
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
#define FMHA_CHECK_CUDA( call ) \
do { \
cudaError_t status_ = call; \
if( status_ != cudaSuccess ) { \
fprintf( stderr, \
"CUDA error (%s:%d): %s\n", \
__FILE__, \
__LINE__, \
cudaGetErrorString( status_ ) ); \
exit( 1 ); \
} \
} while( 0 )
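// Added usage note (not part of the original header): the macro wraps any call
// that returns cudaError_t and aborts with file/line context on failure, e.g.
//   FMHA_CHECK_CUDA(cudaMemsetAsync(dqkv_ptr, 0, dqkv_bytes, stream));
// where dqkv_ptr and dqkv_bytes are placeholders for the caller's buffer.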
////////////////////////////////////////////////////////////////////////////////////////////////////
enum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline void set_alpha(uint32_t &alpha, float norm, Data_type dtype) {
    if( dtype == DATA_TYPE_FP16 ) {
        half x = __float2half_rn(norm);
        uint16_t h = reinterpret_cast<const uint16_t &>(x);
        ushort2 h2 = { h, h };
        alpha = reinterpret_cast<const uint32_t &>(h2);
    } else if( dtype == DATA_TYPE_FP32 ) {
        alpha = reinterpret_cast<const uint32_t &>(norm);
    } else if( dtype == DATA_TYPE_INT32 ) {
        int32_t inorm = static_cast<int32_t>(norm);
        alpha = reinterpret_cast<const uint32_t &>(inorm);
    } else {
        assert(false);
    }
}
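// Added usage note (not part of the original header): set_alpha packs a float
// scale into the 32-bit register format the GEMM epilogues consume -- for FP16
// the value is replicated into both 16-bit lanes of the word. A typical call
// for a head dimension of 64 (an assumed value, for illustration) would be:
//   uint32_t scale_bmm1;
//   set_alpha(scale_bmm1, 1.f / sqrtf(64.f), DATA_TYPE_FP16);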
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline size_t get_size_in_bytes(size_t n, Data_type dtype) {
    switch( dtype ) {
    case DATA_TYPE_FP32:
        return n * 4;
    case DATA_TYPE_FP16:
        return n * 2;
    case DATA_TYPE_INT32:
        return n * 4;
    case DATA_TYPE_INT8:
        return n;
    default:
        assert(false);
        return 0;
    }
}
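// Added usage note (not part of the original header): for example, sizing the
// per-chunk dK/dV scratch consumed by fmha_run_noloop_reduce (layout assumed,
// not taken from this commit) could look like
//   size_t dkv_bytes = get_size_in_bytes(size_t(rows) * num_chunks * 2 * hidden_size, DATA_TYPE_FP16);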
////////////////////////////////////////////////////////////////////////////////////////////////////
apex/contrib/csrc/groupbn/batch_norm.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
-#include <THC/THCNumerics.cuh>
+#include <c10/cuda/CUDACachingAllocator.h>
#include "THC/THC.h"
...
...
@@ -26,23 +26,20 @@ static size_t round_up_to_multiple(size_t x, int multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}
// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
  Workspace(size_t size) : size(size), data(NULL) {
-    data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);
+    auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
+    dataPtr = allocator.allocate(size);
+    data = dataPtr.get();
  }
  Workspace(const Workspace&) = delete;
  Workspace(Workspace&&) = default;
  Workspace& operator=(Workspace&&) = default;
-  ~Workspace() {
-    if (data) {
-      THCudaFree(at::globalContext().lazyInitCUDA(), data);
-    }
-  }
+  ~Workspace() = default;
  size_t size;
  void* data;
+  c10::DataPtr dataPtr;
};
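// Added note (not part of the commit): the change routes the workspace through
// PyTorch's caching allocator instead of raw THCudaMalloc/THCudaFree. The
// c10::DataPtr returned by allocator.allocate() releases the memory when it is
// destroyed, which is why the explicit destructor becomes "= default". The
// same pattern in isolation:
//   auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
//   c10::DataPtr buf = allocator.allocate(nbytes);  // freed when buf goes out of scope
//   void* raw = buf.get();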
// Return {y}
...
...
apex/contrib/csrc/groupbn/batch_norm.h
...
...
@@ -31,6 +31,7 @@
#include <algorithm>
#include <vector>
#include <string>
#include <iostream>
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
...
...
apex/contrib/csrc/groupbn/batch_norm_add_relu.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
-#include <THC/THCNumerics.cuh>
+#include <c10/cuda/CUDACachingAllocator.h>
#include "THC/THC.h"
...
...
@@ -27,23 +27,20 @@ static size_t round_up_to_multiple(size_t x, int multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}
// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
  Workspace(size_t size) : size(size), data(NULL) {
-    data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);
+    auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
+    dataPtr = allocator.allocate(size);
+    data = dataPtr.get();
  }
  Workspace(const Workspace&) = delete;
  Workspace(Workspace&&) = default;
  Workspace& operator=(Workspace&&) = default;
-  ~Workspace() {
-    if (data) {
-      THCudaFree(at::globalContext().lazyInitCUDA(), data);
-    }
-  }
+  ~Workspace() = default;
  size_t size;
  void* data;
+  c10::DataPtr dataPtr;
};
// Return {y}
...
...
apex/contrib/csrc/groupbn/batch_norm_add_relu.h
...
...
@@ -31,6 +31,7 @@
#include <algorithm>
#include <vector>
#include <string>
#include <iostream>
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
...
...
apex/contrib/csrc/groupbn/ipc.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
...
...