OpenDAS / apex / Commits / 3c88451a

Commit 3c88451a (unverified), authored Mar 25, 2022 by yjk21, committed by GitHub on Mar 25, 2022.
Parent: a0ed4151

    update fmha (#1344)

Changes: 26 files in the commit; this page shows 6 changed files, with 81 additions and 727 deletions:

    +0   -343   apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_nl.h
    +0   -322   apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_reload_v.h
    +58  -48    apex/contrib/csrc/fmha/src/fmha_kernel.h
    +17  -8     apex/contrib/csrc/multihead_attn/philox.cuh
    +4   -3     apex/contrib/fmha/fmha.py
    +2   -3     apex/contrib/test/fmha/test_fmha.py
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_nl.h (deleted, 100644 → 0)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int CHUNKS, typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN_nl(const Params &params) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
    // The shared memory tile to swizzle Q.
    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle K.
    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;

    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

    // The global memory tile to store O.
    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
    // The shared memory tile to swizzle O.
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

    // The global memory tile to store S/D.
    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;

    // Shared memory.
    extern __shared__ char smem_[];

    // The block index for the chunk.
    const int bidc = blockIdx.z;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    Noloop nl_traits(bidc);

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    if( binfo.stop_early() ) return;

    auto seeds = at::cuda::philox::unpack(params.philox_args);
    Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));

    fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);

    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the shared memory tile loader for Q.
    Smem_tile_q smem_q(&smem_[0], tidx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the shared memory tile loader for K.
    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // The base pointer of smem_v.
    char *smem_v_ = nullptr;
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];
    } else {
        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
    }

    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);

    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params, binfo, tidx);
    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);

    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);

    nl_traits.move_all(gmem_q, gmem_o, gmem_s);

    // Trigger the loads for Q.
    gmem_q.load(smem_q);
    // Trigger the loads for K.
    gmem_k.load(smem_k);
    // Trigger the loads for V.
    gmem_v.load(smem_v);

    // Commit the data for Q and K to shared memory.
    gmem_q.commit(smem_q);
    gmem_k.commit(smem_k);

    // Commit the data for V to shared memory.
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        gmem_v.commit(smem_v);
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // Load the fragments for Q.
    typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
    smem_q.load(frag_q[0], 0);

    // Load the fragments for K. We keep the data in registers during the entire kernel.
    typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
        smem_k.load(frag_k[ki], ki);
    }

    // Commit the data for V to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();

        // Commit the data to shared memory for V.
        gmem_v.commit(smem_v);

        // Make sure the data is in shared memory.
        __syncthreads();
    }

    // Load the fragments for V. We keep the data in registers during the entire kernel.
    typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
        smem_v.load(frag_v[ki], ki);
    }

    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };

    // Create the object to do the softmax.
    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
    Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);

    // The number of threads per row.
    enum { THREADS_PER_ROW = 32 };

    // Load over the entire sequence length.
    for( int l = 0; l < nl_traits.num_steps_; l++ ) {

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        // Do this part of P^T = (Q * K^T)^T.
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_q.load(frag_q[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_p::MMAS_K;
            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
        }

        // Trigger the load for the next Q values.
        if( l < nl_traits.num_steps_ - 1 ) {
            smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load(smem_q);
        }

        // Load the mask for that iteration.
        mask.load(nl_traits.loop_offset_ + l);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack(acc_p);

        // Apply the mask.
        softmax.apply_mask(mask);

        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
            // If we share K and V, it could be that V was not fully read yet but we write into smem for reduction.
            __syncthreads();
        }

        // Compute the max.
        float p_max[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Max_>(p_max);

        // Make sure we are done reading shared memory.
        __syncthreads();

        // Compute the exponential value.
        softmax.apply_exp(p_max);

        // Compute the sum.
        float p_sum[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Sum_>(p_sum);

        // Finalize softmax on the accumulators of P^T.
        softmax.scale(p_sum);

        if( Is_training ) {
            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < 2; ii++ ) {
                    #pragma unroll
                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                        float4 tmp = uniform4(ph());
                        // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from pre-existing zeros.
                        softmax.elt_[2 * mi + ii][4 * ni + 0] =
                            encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
                        softmax.elt_[2 * mi + ii][4 * ni + 1] =
                            encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
                        softmax.elt_[2 * mi + ii][4 * ni + 2] =
                            encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
                        softmax.elt_[2 * mi + ii][4 * ni + 3] =
                            encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
                    }
                }
            }
            gmem_s.store(softmax.elt_, mask);
            gmem_s.move();
        }

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        softmax.pack(frag_p);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
                    // "Apply" the dropout: rescale and zero out the sign-encoded (dropped) values.
                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
                }
            }
        }

        // Declare the accumulators for the 2nd gemm.
        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);

        // Do this part of O = P^T * V^T.
        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
            fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
        }

        // Loop over MMAS_M.
        #pragma unroll
        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
            // Swizzle the elements and do the final reduction.
            smem_o.store(acc_o, ii);

            // Make sure the data is in shared memory.
            __syncthreads();

            // Load from shared memory.
            uint4 out[Gmem_tile_o::STGS_PER_LOOP];
            smem_o.load(out);

            // Make sure the data was read from shared memory.
            if( ii < Gmem_tile_o::LOOPS - 1 ) {
                __syncthreads();
            }

            // Output the values.
            gmem_o.store(out, ii);
        }

        // Move to the next part of the output.
        gmem_o.move();

        // Commit the values for Q into shared memory.
        if( l < nl_traits.num_steps_ - 1 ) {
            gmem_q.commit(smem_q);
            __syncthreads();
            smem_q.load(frag_q[0], 0);
        }

    }  // Outer loop over the sequence length.
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
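The deleted kernel above folds the dropout decision into the sign bit of the (non-negative) softmax probabilities via encode_dropout, and later recovers the kept, rescaled values with hmul2/hrelu2 per register. Below is a minimal host-side C++ sketch of that trick; it is not part of the commit, the helper names are hypothetical, and it assumes params.p_dropout is the keep-probability with scale_dropout its reciprocal.

// Hypothetical stand-alone illustration of the sign-bit dropout encoding used by the kernel above.
#include <algorithm>
#include <cassert>

// Encode: keep the (non-negative) probability as-is, or flip its sign to mark "dropped".
// A sign flip stays distinguishable from a probability that was already exactly zero.
inline float encode_dropout(bool keep, float val) { return keep ? val : -val; }

// Decode: rescale by 1/p_keep and clamp negatives to zero -- the per-register role that
// fmha::hmul2(reg, params.scale_dropout) followed by fmha::hrelu2(reg) plays in the kernel.
inline float apply_dropout(float encoded, float p_keep) {
    return std::max(encoded * (1.0f / p_keep), 0.0f);
}

int main() {
    const float p_keep = 0.9f;   // assumed meaning of params.p_dropout (keep-probability)
    const float p = 0.25f;       // one softmax probability, always >= 0

    const float kept    = encode_dropout(true,  p);   // stays +0.25
    const float dropped = encode_dropout(false, p);   // becomes -0.25

    assert(apply_dropout(kept, p_keep) > 0.0f);     // survives, rescaled by 1/p_keep
    assert(apply_dropout(dropped, p_keep) == 0.0f); // zeroed out
    return 0;
}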
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_reload_v.h (deleted, 100644 → 0)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN(const Params &params) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
    // The shared memory tile to swizzle Q.
    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle K.
    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;

    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

    // The global memory tile to store O.
    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
    // The shared memory tile to swizzle O.
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    // Shared memory.
    extern __shared__ char smem_[];

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    if( binfo.stop_early() ) return;

    Mask<Cta_tile_p> mask(params, binfo, tidx);

    auto seeds = at::cuda::philox::unpack(params.philox_args);
    Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));

    static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the shared memory tile loader for K.
    Smem_tile_k smem_k(&smem_[0], tidx);

    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // The base pointer of smem_v.
    char *smem_v_ = nullptr;
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        smem_v_ = &smem_[0];
    } else {
        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
    }
    static_assert(Kernel_traits::SHARE_SMEM_FOR_K_AND_V);
    static_assert(Smem_tile_k::BYTES_PER_TILE == Smem_tile_v::BYTES_PER_TILE);

    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);

    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the shared memory tile loader for Q.
    Smem_tile_q smem_q(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params, binfo, tidx);
    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_o smem_o(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);

    // Trigger the loads for Q.
    gmem_q.load(smem_q);
    // Trigger the loads for K.
    gmem_k.load(smem_k);
    // Trigger the loads for V.
    gmem_v.load(smem_v);

    // Commit the data for Q and K to shared memory.
    gmem_q.commit(smem_q);
    gmem_k.commit(smem_k);

    // Commit the data for V to shared memory.
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        gmem_v.commit(smem_v);
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // Load the fragments for Q.
    typename Smem_tile_q::Fragment frag_q[1][Mma_tile_p::MMAS_M];

    // Load the fragments for K. We keep the data in registers during the entire kernel.
    typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
        smem_k.load(frag_k[ki], ki);
    }

    // Commit the data for V to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();

        // Commit the data to shared memory for V.
        gmem_v.commit(smem_v);
    }

    enum { BITS_PER_ELT_S = sizeof(typename fmha::A_type) * 8 };

    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);

    // Create the object to do the softmax.
    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
    Softmax softmax(params, &smem_[Smem_tile_v::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);

    constexpr int SMEM_BYTES_SOFTMAX = Softmax::ELEMENTS * sizeof(float);
    static_assert(SMEM_BYTES_SOFTMAX == Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float));

    enum { THREADS_PER_ROW = 32 };

    const float pinv = 1.f / params.p_dropout;

    // Load over the entire sequence length.
    for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
        if( loop >= binfo.actual_seqlen )
            break;

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_q.load(frag_q[0], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_p, frag_q[0], frag_k[ki]);
        }

        // Load the mask for that iteration.
        mask.load(outer);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack(acc_p);

        // Apply the mask.
        softmax.apply_mask(mask);

        static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);

        // Compute the max.
        float p_max[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Max_>(p_max);

        // Make sure we are done reading shared memory.
        __syncthreads();

        // Compute the exponential value.
        softmax.apply_exp(p_max);

        // Compute the sum.
        float p_sum[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Sum_>(p_sum);

        // Finalize softmax on the accumulators of P^T.
        softmax.scale(p_sum);

        __syncthreads();

        if( Is_training ) {
            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < 2; ii++ ) {
                    #pragma unroll
                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                        float4 tmp = uniform4(ph());
                        // We encode the dropout pattern in the sign bit of the non-negative softmax to
                        // distinguish from pre-existing zeros.
                        softmax.elt_[2 * mi + ii][4 * ni + 0] =
                            encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
                        softmax.elt_[2 * mi + ii][4 * ni + 1] =
                            encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
                        softmax.elt_[2 * mi + ii][4 * ni + 2] =
                            encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
                        softmax.elt_[2 * mi + ii][4 * ni + 3] =
                            encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
                    }
                }
            }
            gmem_s.store(softmax.elt_, mask);
            gmem_s.move();
        }

        // Trigger the load for the next Q values.
        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
            smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load(smem_q);
        }

        typename Smem_tile_v::Fragment frag_v[1][Mma_tile_o::MMAS_N];

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        softmax.pack(frag_p);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
                    // "Apply" the dropout: rescale and zero out the sign-encoded (dropped) values.
                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
                }
            }
        }

        // Declare the accumulators for the 2nd gemm.
        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of V values.
            smem_v.load(frag_v[0], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_o, frag_p[ki], frag_v[0]);
        }

        // Loop over MMAS_M.
        #pragma unroll
        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
            // Swizzle the elements and do the final reduction.
            smem_o.store(acc_o, ii);

            // Make sure the data is in shared memory.
            __syncthreads();

            // Load from shared memory.
            uint4 out[Gmem_tile_o::STGS_PER_LOOP];
            smem_o.load(out);

            // Always sync after last iter: shared smem_q and smem_o!
            __syncthreads();

            // Output the values.
            gmem_o.store(out, ii);
        }

        // Move to the next part of the output.
        gmem_o.move();

        // Commit the values for Q into shared memory (same smem as O).
        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
            gmem_q.commit(smem_q);
        }

        // Make sure the data is in shared memory.
        __syncthreads();

    }  // Outer loop over the sequence length.
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
apex/contrib/csrc/fmha/src/fmha_kernel.h
@@ -79,17 +79,19 @@ struct Noloop_traits{
     enum{ STEP = Cta_tile::M };
     enum{ SEQLEN = Cta_tile::N };

-    // The size of the subsequence this CTA is processing
-    enum{ SUBSEQ = SEQLEN / CHUNKS };
-    static_assert(SUBSEQ * CHUNKS == SEQLEN);
-
-    // The number of steps to process the subsequence
-    enum{ NUM_STEPS = SUBSEQ / STEP };
-    static_assert(NUM_STEPS * Cta_tile::M == SUBSEQ);
-
-    inline __device__ Noloop_traits(const int bidc)
-        : loop_offset_(NUM_STEPS * bidc), bidc_(bidc) {
-    }
+    template<typename Block_info>
+    inline __device__ Noloop_traits(const int bidc, const Block_info& binfo)
+        : bidc_(bidc) {
+        const int seqlen = binfo.actual_seqlen;
+        const int steps = (seqlen + STEP - 1) / STEP;
+        const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS;
+        const int step_begin = bidc_ * steps_per_chunk;
+        const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk);
+        const int actual_steps = max(0, step_end - step_begin);
+        loop_offset_ = step_begin;
+        num_steps_ = actual_steps;
+    }

     template<typename... Tiles>

@@ -115,54 +117,62 @@ struct Noloop_traits{
         return (loop_offset_ + l) * STEP;
     }

-    const int loop_offset_;
     const uint32_t bidc_;
-    const int num_steps_ = NUM_STEPS;
+    int loop_offset_;
+    int num_steps_;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Cta_tile>
-struct Noloop_traits<3, Cta_tile> {
-    // Interpretation of Cta_tile dims, i.e. Cta_tile_p:
-    enum{ STEP = Cta_tile::M };
-    enum{ SEQLEN = Cta_tile::N };
-    static_assert(STEP == 16 && SEQLEN == 512);
-
-    inline __device__ Noloop_traits(const int bidc)
-        : bidc_(bidc)
-        , num_steps_(bidc < 2 ? 11 : 10)
-        , loop_offset_(bidc * 11) {
-    }
-
-    template<typename... Tiles>
-    inline __device__ void move_all(Tiles & ... tiles) const {
-        using expand_type = int[];
-        for( int s = 0; s < loop_offset_; s++ ) {
-            expand_type{ (tiles.move(), 0)... };
-        }
-    }
-
-    inline __device__ int get_idx_dk() const {
-        //return bidc_;
-        return bidc_ * 2 + 0;
-    }
-
-    inline __device__ int get_idx_dv() const {
-        //return CHUNKS + bidc_;
-        return bidc_ * 2 + 1;
-    }
-
-    inline __device__ int offset_loop_count(const int l) {
-        // convert loop counter to position in the outer sequence
-        return (loop_offset_ + l) * STEP;
-    }
-
-    const int loop_offset_;
-    const uint32_t bidc_;
-    const int num_steps_;
-};
+template<typename Kernel_traits>
+std::tuple<int, int, int, int, int, int> work_dist(const int total_ctas, const int heads_total) {
+
+    constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
+
+    const int num_full_heads = heads_total / total_ctas;
+    const int heads_last_wave = heads_total % total_ctas;
+
+    int num_main_groups = 0;
+    int main_steps = 0;
+    int rest_steps = 0;
+    if( heads_last_wave > 0 ) {
+        // Number of CTA groups that process within heads.
+        num_main_groups = total_ctas / heads_last_wave;
+        // Remaining CTAs that process between heads.
+        const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups);
+        if( rest_ctas == 0 ) {
+            // We have exactly "num_main_groups" CTAs to process each of the remaining heads.
+            main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups;
+            num_main_groups = STEPS_PER_HEAD / main_steps;
+            // Here: main_step > 0
+            rest_steps = STEPS_PER_HEAD % main_steps;
+        } else {
+            // Ideal number of steps if we could load-balance as evenly as possible.
+            const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas;
+            // Iterations that a "rest" CTA has to do at most.
+            const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas;
+            // Find the first step distribution, s.t. the maximum work of the "rest" CTAs is less than the work of the main CTAs.
+            main_steps = steps_ideal;
+            rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
+            for( ; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++ ) {
+                rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
+                const int max_rest_total_steps = rest_steps * max_rest_iters;
+                if( max_rest_total_steps < main_steps )
+                    break;
+            }
+            rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
+        }
+    }
+
+    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
+    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
+
+    const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps);
+    // Convert loop counter to position in the outer sequence.
+    const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8;
+    const int elts_per_thread = max_steps * elts_per_thread_per_step;
+
+    return { num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread };
+}

 ////////////////////////////////////////////////////////////////////////////////////////////////////
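The rewritten Noloop_traits constructor above splits the (now seqlen-dependent) number of STEP-row iterations across CHUNKS CTAs, instead of hard-coding NUM_STEPS from a fixed SEQLEN. A minimal host-side sketch of that split, not part of the commit, with hypothetical names and assuming STEP = Cta_tile::M = 16:

// Hypothetical host-side re-derivation of the per-chunk step split computed in the new constructor.
#include <algorithm>
#include <cstdio>

struct NoloopSplit { int loop_offset; int num_steps; };

NoloopSplit split(int bidc, int actual_seqlen, int chunks, int step = 16) {
    const int steps           = (actual_seqlen + step - 1) / step;   // total STEP-row iterations
    const int steps_per_chunk = (steps + chunks - 1) / chunks;       // ceil-divide across chunks
    const int step_begin      = bidc * steps_per_chunk;
    const int step_end        = std::min(steps, (bidc + 1) * steps_per_chunk);
    return { step_begin, std::max(0, step_end - step_begin) };
}

int main() {
    // E.g. a batch entry with actual_seqlen = 384 split over CHUNKS = 3 CTAs along blockIdx.z:
    // chunk 0 starts at step 0, chunk 1 at step 8, chunk 2 at step 16, each doing 8 steps.
    for (int bidc = 0; bidc < 3; ++bidc) {
        const NoloopSplit s = split(bidc, 384, 3);
        std::printf("bidc=%d loop_offset=%d num_steps=%d\n", bidc, s.loop_offset, s.num_steps);
    }
    return 0;
}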
apex/contrib/csrc/multihead_attn/philox.cuh
@@ -7,14 +7,19 @@ class Philox {
 public:
   __device__ inline Philox(unsigned long long seed,
                            unsigned long long subsequence,
-                           unsigned long long offset) {
-    key.x = (unsigned int)seed;
-    key.y = (unsigned int)(seed >> 32);
-    counter = make_uint4(0, 0, 0, 0);
-    counter.z = (unsigned int)(subsequence);
-    counter.w = (unsigned int)(subsequence >> 32);
-    STATE = 0;
-    incr_n(offset / 4);
+                           unsigned long long offset)
+      : STATE(0) {
+    //key.x = (unsigned int)seed;
+    //key.y = (unsigned int)(seed >> 32);
+    //counter = make_uint4(0, 0, 0, 0);
+    //counter.z = (unsigned int)(subsequence);
+    //counter.w = (unsigned int)(subsequence >> 32);
+    //STATE = 0;
+    //incr_n(offset / 4);
+    key = reinterpret_cast<const uint2&>(seed);
+    ull2 *tmp = reinterpret_cast<ull2*>(&counter);
+    tmp->x = offset / 4;
+    tmp->y = subsequence;
   }
   __device__ inline uint4 operator()() {
     if (STATE == 0) {

@@ -42,6 +47,10 @@ public:
   }

 private:
+  struct ull2 {
+    uint64_t x;
+    uint64_t y;
+  };
   uint4 counter;
   uint4 output;
   uint2 key;
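The new Philox constructor above initializes the 128-bit counter by viewing it as two 64-bit lanes (ull2) instead of filling the four 32-bit lanes one by one and calling incr_n(offset / 4). A minimal host-side sketch of that packing, not part of the commit, using memcpy in place of the kernel's reinterpret_cast and assuming the little-endian lane layout that cast relies on:

// Hypothetical stand-alone check that the 64-bit packing reproduces the old lane-wise initialization.
#include <cstdint>
#include <cstring>
#include <cassert>

struct uint4_ { uint32_t x, y, z, w; };   // stand-in for CUDA's uint4
struct ull2   { uint64_t x, y; };

uint4_ init_counter(uint64_t subsequence, uint64_t offset) {
    uint4_ counter{};
    ull2 packed{ offset / 4, subsequence };
    std::memcpy(&counter, &packed, sizeof(counter));  // same byte layout the kernel obtains via reinterpret_cast
    return counter;
}

int main() {
    const uint4_ c = init_counter(/*subsequence=*/0x1122334455667788ull, /*offset=*/40);
    assert(c.x == 10 && c.y == 0);                     // offset / 4 lands in the low 64-bit lane (x, y)
    assert(c.z == 0x55667788u && c.w == 0x11223344u);  // subsequence split across z/w, as in the old code
    return 0;
}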
apex/contrib/fmha/fmha.py
@@ -35,9 +35,10 @@ class FMHAFun(torch.autograd.Function):
     def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors):
         batch_size = cu_seqlens.numel() - 1
         if batch_size < 4:
-            context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors, None)
+            max_s = 512
+            context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, True, zero_tensors, None)
         else:
-            context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors, None)
+            context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, False, zero_tensors, None)
         ctx.save_for_backward(qkv, S_dmask)
         ctx.cu_seqlens = cu_seqlens
         ctx.p_dropout = p_dropout

@@ -54,7 +55,7 @@ class FMHAFun(torch.autograd.Function):
         else:
             dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
-        return dqkv, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None

 class FMHA(torch.nn.Module):
apex/contrib/test/fmha/test_fmha.py
@@ -25,7 +25,6 @@
 #
 ###############################################################################

 import sys
 import torch
 import numpy as np

@@ -77,9 +76,9 @@ class TestFMHA(unittest.TestCase):
         qkv.requires_grad = True

         if b < 4:
-            ctx, S_ = mha.fwd_nl(qkv_vs, cu_seqlens, 0.0, s, True, zero_tensors, None)
+            ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, True, zero_tensors, None)
         else:
-            ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, zero_tensors, None)
+            ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, False, zero_tensors, None)
         ctx = ctx.view(b, s, h, d)

         ctx_ref = py_mha(qkv, amask, b, s, h, d)