OpenDAS / apex / Commits / 96850dfa

Unverified commit 96850dfa, authored Aug 15, 2022 by Jithun Nair; committed by GitHub on Aug 15, 2022.
Merge pull request #80 from ROCmSoftwarePlatform/IFU-master-2022-07-29

IFU-master-2022-07-29

Parents: 87fc4125, cc5f83b5

Changes: 235
Showing 20 changed files with 1155 additions and 1157 deletions (+1155, −1157).
apex/contrib/csrc/fmha/src/fmha/mask.h                             +5    −0
apex/contrib/csrc/fmha/src/fmha/smem_tile.h                        +0    −2
apex/contrib/csrc/fmha/src/fmha/softmax.h                          +144  −227
apex/contrib/csrc/fmha/src/fmha/utils.h                            +85   −0
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu   +1    −1
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu   +1    −1
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu   +1    −1
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu   +1    −1
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h          +3    −3
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h       +5    −7
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu   +42   −16
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu   +42   −16
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu   +43   −16
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu   +86   −47
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h                 +301  −106
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_nl.h              +0    −343
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_reload_v.h        +0    −322
apex/contrib/csrc/fmha/src/fmha_kernel.h                           +58   −48
apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp                   +70   −0
apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu             +267  −0
apex/contrib/csrc/fmha/src/fmha/mask.h

@@ -63,6 +63,11 @@ struct Mask {
         // return row_valid && col_valid;
     }
 
+    // BERT Mask: if upper left is invalid, none are valid.
+    inline __device__ bool any_valid(int mi, int ni) const {
+        return is_valid(mi, ni, 0, 0);
+    }
+
     inline __device__ void load(int it) {
         row_offset = it * Cta_tile::M + row;
     }
...
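The new Mask::any_valid helper only checks the upper-left element of an MMA tile: per the comment, a BERT-style padding mask trims trailing rows/columns, so an invalid upper-left corner implies the whole tile is masked, and a caller can skip such tiles entirely. The standalone sketch below is only an illustration of that idea and not code from this repository; SimpleMask, its fields, and the driver loop are assumptions, with is_valid reduced to a plain sequence-length bound check.

#include <cstdio>

// Toy stand-in for the kernel's Mask (hypothetical; the real class is templated
// on the CTA tile and reads the actual sequence length from the block info).
struct SimpleMask {
    int actual_seqlen;   // number of valid rows/columns for this head
    int tile_m, tile_n;  // tile extent in elements

    // Element-level check: valid while inside the un-padded region.
    bool is_valid(int mi, int ni, int ii, int jj) const {
        return (mi * tile_m + ii) < actual_seqlen && (ni * tile_n + jj) < actual_seqlen;
    }
    // Mirrors Mask::any_valid: padding only trims the tail, so a valid
    // upper-left element (0, 0) means the tile is not completely masked.
    bool any_valid(int mi, int ni) const { return is_valid(mi, ni, 0, 0); }
};

int main() {
    SimpleMask mask{40, 16, 16};  // 40 valid tokens, 16x16 tiles
    for (int mi = 0; mi < 4; ++mi) {
        for (int ni = 0; ni < 4; ++ni) {
            if (!mask.any_valid(mi, ni)) {
                continue;  // skip tiles that are entirely padding
            }
            std::printf("process tile (%d, %d)\n", mi, ni);
        }
    }
    return 0;
}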
apex/contrib/csrc/fmha/src/fmha/smem_tile.h

@@ -1266,8 +1266,6 @@ struct Smem_tile_mma_epilogue : public Base {
         }
     }
 
     template<int M, int N>
     inline __device__ void store(const uint4 (&regs)[M][N]) {
         for( int mi = 0; mi < M; mi++ ) {
...
apex/contrib/csrc/fmha/src/fmha/softmax.h

@@ -55,6 +55,88 @@ inline __device__ float apply_exp_(float x, float max) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template<int COLS> struct ReadType {};
+template<> struct ReadType<4> { using T = float; };
+template<> struct ReadType<8> { using T = float2; };
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Cta_tile, typename Kernel_traits>
+struct Smem_tile_reduce {
+    // Helper class to distribute MMA tiles reduced over rows per warp over quads.
+
+    // The Mma tile.
+    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
+
+    // The number of MMAs in M/N dimensions.
+    enum { MMAS_M = Mma_tile::MMAS_M };
+    enum { MMAS_N = Mma_tile::MMAS_N };
+    enum { WARPS_M = Cta_tile::WARPS_M };
+    enum { WARPS_N = Cta_tile::WARPS_N };
+
+    static constexpr int ROWS = WARPS_M * MMAS_M * 16;
+    static constexpr int COLS = WARPS_N;
+    static_assert(COLS == 4 || COLS == 8);
+    static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
+    static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
+    static constexpr int ELTS_PER_TILE = ROWS * COLS;
+
+    static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW;
+    static_assert(THREADS_PER_GROUP == 16);  // DEBUG
+    static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP;
+    static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
+    static_assert(LOOPS == 1);
+
+    using read_t = typename ReadType<COLS>::T;
+
+    __device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
+
+        int lane = tidx % 32;
+        int warp = tidx / 32;
+
+        int warp_m = warp % WARPS_M;
+        int warp_n = warp / WARPS_M;
+
+        qid_ = lane % 4;
+        int qp = lane / 4;
+
+        // Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps.
+        // This won't affect reading as we assume commutative reduction ops.
+        const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
+        smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
+        smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
+    }
+
+    __device__ inline void store(float (&frag)[2 * MMAS_M]) {
+        if( qid_ == 0 ) {
+            #pragma unroll
+            for( int mi = 0; mi < MMAS_M; mi++ ) {
+                int offset = mi * 16 * WARPS_N;
+                smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0];
+                smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1];
+            }
+        }
+    }
+
+    __device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
+        #pragma unroll
+        for( int mi = 0; mi < MMAS_M; mi++ ) {
+            int offset = mi * 16 * 4;
+            frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4];
+            frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4];
+        }
+    }
+
+    int qid_;
+    float *smem_write_;
+    read_t *smem_read_;
+};
+
 template<typename Cta_tile, typename Kernel_traits>
 struct Softmax_base {
...

@@ -136,201 +218,6 @@ struct Softmax_base {
             }
         }
     }
 
-    // Do a CTA-wide reduction.
-    template<typename Functor>
-    inline __device__ void reduce_1x4(float (&dst)[MMAS_M * 2]) {
-#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
-        if( Functor::IS_SUM ) {
-            // Apply the summation inside the thread.
-            float tmp[MMAS_M * 2][2];
-            #pragma unroll
-            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
-                tmp[mi][0] = 0.f;
-                tmp[mi][1] = 0.f;
-                #pragma unroll
-                for( int ni = 0; ni < MMAS_N; ++ni ) {
-                    tmp[mi][0] += elt_[mi][4 * ni + 0];
-                    tmp[mi][0] += elt_[mi][4 * ni + 1];
-                    tmp[mi][1] += elt_[mi][4 * ni + 2];
-                    tmp[mi][1] += elt_[mi][4 * ni + 3];
-                }
-                dst[mi] = tmp[mi][0] + tmp[mi][1];
-            }
-        } else
-#endif  // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
-        {
-            // Apply the functor for each row inside a thread.
-            #pragma unroll
-            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
-                dst[mi] = elt_[mi][0];
-                #pragma unroll
-                for( int ni = 1; ni < MMAS_N * 4; ++ni ) {
-                    dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);
-                }
-            }
-        }
-
-        // Apply the functor for each row inside each group of 4 threads.
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
-            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));
-            __syncwarp();
-            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));
-            __syncwarp();
-        }
-
-        // Store the different values.
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M; ++mi ) {
-            if( tidx_ % 4 == 0 ) {
-                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];
-                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];
-            }
-        }
-
-        // Make sure the values are in shared memory.
-        __syncthreads();
-
-        // Load 8 values (one for each warp). The /8 corresponds to /(4*2) where 4 is from the float4.
-        float4 tmp[1];
-        if( tidx_ < Cta_tile::M ) {
-            tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];
-        }
-
-        // Compute the reduction of those 8 values in a binary-tree fashion.
-        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);
-        tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);
-        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);
-
-        // Make sure we can write to shared memory.
-        __syncthreads();
-
-        // Store the value back to shared memory.
-        if( tidx_ < Cta_tile::M ) {
-            smem_[tidx_] = tmp[0].x;
-        }
-
-        // Make sure the data is in shared memory.
-        __syncthreads();
-
-        // Finally read the values.
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M; ++mi ) {
-            dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];
-            dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];
-        }
-    }
-
-    // Do a CTA-wide reduction.
-    template<typename Functor>
-    inline __device__ void reduce_1x8(float (&dst)[MMAS_M * 2]) {
-#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
-        if( Functor::IS_SUM ) {
-            // Apply the summation inside the thread.
-            float tmp[MMAS_M * 2][2];
-            #pragma unroll
-            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
-                tmp[mi][0] = 0.f;
-                tmp[mi][1] = 0.f;
-                #pragma unroll
-                for( int ni = 0; ni < MMAS_N; ++ni ) {
-                    tmp[mi][0] += elt_[mi][4 * ni + 0];
-                    tmp[mi][0] += elt_[mi][4 * ni + 1];
-                    tmp[mi][1] += elt_[mi][4 * ni + 2];
-                    tmp[mi][1] += elt_[mi][4 * ni + 3];
-                }
-                dst[mi] = tmp[mi][0] + tmp[mi][1];
-            }
-        } else
-#endif  // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
-        {
-            // Apply the functor for each row inside a thread.
-            #pragma unroll
-            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
-                dst[mi] = elt_[mi][0];
-                #pragma unroll
-                for( int ni = 1; ni < MMAS_N * 4; ++ni ) {
-                    dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);
-                }
-            }
-        }
-
-        // Apply the functor for each row inside each group of 4 threads.
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
-            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));
-            __syncwarp();
-            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));
-            __syncwarp();
-        }
-
-        // Store the different values.
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M; ++mi ) {
-            if( tidx_ % 4 == 0 ) {
-                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];
-                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];
-            }
-        }
-
-        // Make sure the values are in shared memory.
-        __syncthreads();
-
-        // Load 8 values (one for each warp). The /8 corresponds to /(4*2) where 4 is from the float4.
-        float4 tmp[2];
-        if( tidx_ < Cta_tile::M ) {
-            tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];
-            tmp[1] = reinterpret_cast<const float4 *>(&smem_[1 * ELEMENTS / 2])[tidx_];
-        }
-
-        // Compute the reduction of those 8 values in a binary-tree fashion.
-        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);
-        tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);
-        tmp[1].x = Functor::apply(tmp[1].x, tmp[1].y);
-        tmp[1].z = Functor::apply(tmp[1].z, tmp[1].w);
-        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);
-        tmp[1].x = Functor::apply(tmp[1].x, tmp[1].z);
-        tmp[0].x = Functor::apply(tmp[0].x, tmp[1].x);
-
-        // Make sure we can write to shared memory.
-        __syncthreads();
-
-        // Store the value back to shared memory.
-        if( tidx_ < Cta_tile::M ) {
-            smem_[tidx_] = tmp[0].x;
-        }
-
-        // Make sure the data is in shared memory.
-        __syncthreads();
-
-        // Finally read the values.
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M; ++mi ) {
-            dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];
-            dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];
-        }
-    }
-
-    // Do a CTA-wide reduction.
-    template<typename Functor>
-    inline __device__ void reduce(float (&dst)[MMAS_M * 2]) {
-        static_assert(Cta_tile::WARPS_M == 1 && (Cta_tile::WARPS_N == 4 || Cta_tile::WARPS_N == 8));
-        if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 4 ) {
-            reduce_1x4<Functor>(dst);
-        } else if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 8 ) {
-            reduce_1x8<Functor>(dst);
-        } else {
-            assert(false);
-        }
-
-        // Make sure we are done reading from shared memory.
-        __syncthreads();
-    }
-
     // Scale all the elements.
     inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
         // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal.
...

@@ -372,6 +259,8 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
     static_assert(Fragment_a::NUM_REGS == 4);
 
+    enum { WARPS_M = Cta_tile::WARPS_M };
+    enum { WARPS_N = Cta_tile::WARPS_N };
     // The MMAs.
     enum { MMAS_M = Base::MMAS_M };
     enum { MMAS_N = Base::MMAS_N };
...

@@ -383,41 +272,15 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
     static_assert(std::is_same<Accumulator::Data_type, float>::value);
 
+    using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
+    static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);
+
     // Ctor.
     template<typename Params>
     inline __device__ Softmax(const Params &params, void *smem, int bidb, int tidx)
-        : Base(params, smem, bidb, tidx), params_scale_bmm1_(params.scale_bmm1) {
-    }
-
-    // Store the tile after softmax.
-    template<typename Gmem_tile>
-    inline __device__ void store(Gmem_tile &gmem_tile) {
-        Accumulator_out acc[MMAS_M][MMAS_N];
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M; ++mi ) {
-            #pragma unroll
-            for( int ni = 0; ni < MMAS_N; ++ni ) {
-
-                // The elements.
-                float tmp_00 = this->elt_[2 * mi + 0][4 * ni + 0];
-                float tmp_01 = this->elt_[2 * mi + 0][4 * ni + 1];
-                float tmp_02 = this->elt_[2 * mi + 0][4 * ni + 2];
-                float tmp_03 = this->elt_[2 * mi + 0][4 * ni + 3];
-                float tmp_10 = this->elt_[2 * mi + 1][4 * ni + 0];
-                float tmp_11 = this->elt_[2 * mi + 1][4 * ni + 1];
-                float tmp_12 = this->elt_[2 * mi + 1][4 * ni + 2];
-                float tmp_13 = this->elt_[2 * mi + 1][4 * ni + 3];
-
-                // Transform to accumulators.
-                acc[mi][ni].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
-                acc[mi][ni].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
-                acc[mi][ni].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
-                acc[mi][ni].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
-            }
-        }
-
-        // Delegate to the gmem tile to store.
-        gmem_tile.store(acc);
+        : Base(params, smem, bidb, tidx)
+        , params_scale_bmm1_(params.scale_bmm1)
+        , smem_sum_(static_cast<float *>(smem), tidx)
+        , smem_max_(static_cast<float *>(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {
     }
 
     // Pack the data to a fragment for the next GEMM.
...

@@ -470,7 +333,61 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
             }
         }
     }
 
+    // Scale FP32 fragments
+    inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
+        #pragma unroll
+        for( int mi = 0; mi < MMAS_M; ++mi ) {
+            #pragma unroll
+            for( int ni = 0; ni < MMAS_N; ++ni ) {
+                // 1st row - 4 elements per row.
+                this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);
+                this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);
+                this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4);
+                this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5);
+                // 2nd row - 4 elements per row.
+                this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2);
+                this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);
+                this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);
+                this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);
+            }
+        }
+    }
+
+    template<typename Operator>
+    __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red &smem_red) {
+        for( int mi = 0; mi < 2 * MMAS_M; mi++ ) {
+            frag[mi] = this->elt_[mi][0];
+            for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
+                frag[mi] = op(frag[mi], this->elt_[mi][ni]);
+            }
+        }
+        quad_reduce(frag, frag, op);
+
+        smem_red.store(frag);
+        __syncthreads();
+        typename Smem_tile_red::read_t tmp[2 * MMAS_M];
+        smem_red.load(tmp);
+        quad_allreduce(frag, tmp, op);
+    }
+
+    __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){
+        MaxOp<float> max;
+        reduce_(frag, max, smem_max_);
+    }
+
+    __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
+        SumOp<float> sum;
+        reduce_(frag, sum, smem_sum_);
+    }
+
     const uint32_t params_scale_bmm1_;
+
+    Smem_tile_red smem_max_;
+    Smem_tile_red smem_sum_;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
...
apex/contrib/csrc/fmha/src/fmha/utils.h

@@ -950,4 +950,89 @@ inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template<typename T>
+struct MaxOp {
+    __device__ inline T operator()(T const &x, T const &y) { return x > y ? x : y; }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct SumOp {
+    __device__ inline T operator()(T const &x, T const &y) { return x + y; }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<int THREADS>
+struct Allreduce {
+    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
+    template<typename T, typename Operator>
+    static __device__ inline T run(T x, Operator &op) {
+        constexpr int OFFSET = THREADS / 2;
+        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
+        return Allreduce<OFFSET>::run(x, op);
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+struct Allreduce<2> {
+    template<typename T, typename Operator>
+    static __device__ inline T run(T x, Operator &op) {
+        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
+        return x;
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Operator, int M>
+__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) {
+    #pragma unroll
+    for( int mi = 0; mi < M; mi++ ){
+        dst[mi] = src[mi];
+        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
+        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Operator, int M>
+__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
+    float tmp[M];
+    #pragma unroll
+    for( int mi = 0; mi < M; mi++ ){
+        tmp[mi] = op(src[mi].x, src[mi].y);
+    }
+    quad_reduce(dst, tmp, op);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Operator, int M>
+__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator &op) {
+    #pragma unroll
+    for( int mi = 0; mi < M; mi++ ){
+        dst[mi] = src[mi];
+        dst[mi] = Allreduce<4>::run(dst[mi], op);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Operator, int M>
+__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
+    float tmp[M];
+    #pragma unroll
+    for( int mi = 0; mi < M; mi++ ){
+        tmp[mi] = op(src[mi].x, src[mi].y);
+    }
+    quad_allreduce(dst, tmp, op);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 }  // namespace fmha
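The helpers added to utils.h reduce within a quad (four consecutive lanes of a warp) using shuffles: quad_reduce folds values into the quad's first lane with __shfl_down_sync, while quad_allreduce goes through Allreduce<4>, whose butterfly __shfl_xor_sync exchanges leave the result in all four lanes. The program below is a minimal, self-contained sketch of that same pattern against the plain CUDA runtime; it does not reuse the fmha:: templates, hard-codes a max reduction, and assumes a full-warp participation mask (0xffffffff) as in the diff.

#include <cstdio>
#include <cuda_runtime.h>

// Butterfly all-reduce over the 4 lanes of a quad, mirroring Allreduce<4>::run
// from the diff: two XOR exchanges with offsets 2 and 1.
__device__ float quad_allreduce_max(float x) {
    x = fmaxf(x, __shfl_xor_sync(0xffffffffu, x, 2));
    x = fmaxf(x, __shfl_xor_sync(0xffffffffu, x, 1));
    return x;
}

__global__ void quad_max_kernel(const float *in, float *out, int n) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= n) return;
    // After the call, every thread of a quad holds the max of its quad's four inputs.
    out[tid] = quad_allreduce_max(in[tid]);
}

int main() {
    const int n = 32;  // one full warp so all lanes participate in the shuffles
    float h_in[n], h_out[n];
    for (int i = 0; i < n; ++i) h_in[i] = float(i % 7);

    float *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

    quad_max_kernel<<<1, n>>>(d_in, d_out, n);
    cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);

    for (int q = 0; q < n / 4; ++q) {
        printf("quad %d -> %g\n", q, h_out[4 * q]);
    }

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}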
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu

@@ -28,7 +28,7 @@
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_reload.h"
 
 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
 
 extern "C" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
     fmha::compute_dv_1xN<Kernel_traits>(params);
...
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu

@@ -28,7 +28,7 @@
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_reload.h"
 
 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
 
 extern "C" __global__ void fmha_dgrad_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
     fmha::compute_dv_1xN<Kernel_traits>(params);
...
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu

@@ -28,7 +28,7 @@
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_reload.h"
 
 using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 8, 0x08u>;
 
 extern "C" __global__ void fmha_dgrad_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
     fmha::compute_dv_1xN<Kernel_traits>(params);
...
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu

@@ -29,7 +29,7 @@
 #include "fmha_dgrad_kernel_1xN_reload.h"
 #include "fmha_dgrad_kernel_1xN_reload_nl.h"
 
 using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
 
 extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
     fmha::compute_dv_1xN<Kernel_traits>(params);
...
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h

@@ -141,7 +141,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
     enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
 
-    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
+    Gmem_tile_s gmem_s(params, binfo, tidx);
 
     // Create the object to do the softmax.
     using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
...

@@ -231,7 +231,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
         }
 
         float p_sum[2 * M];
-        softmax.template reduce<fmha::Sum_>(p_sum);
+        softmax.reduce_sum(p_sum);
 
         const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
 
         #pragma unroll
...

@@ -406,7 +406,7 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
     // Trigger the loads for K.
     gmem_k.load(smem_k);
 
-    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
+    Gmem_tile_s gmem_s(params, binfo, tidx);
 
     // Load dP
     uint4 s_regs[M][N];
     gmem_s.load(s_regs, mask);
...
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h

@@ -114,11 +114,11 @@ inline __device__ void compute_dv_1xN_nl(const Params &params) {
     // Allocate the shared memory tile loader for K.
     Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
 
-    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
+    Gmem_tile_s gmem_s(params, binfo, tidx);
 
     using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;
 
-    Noloop nl_traits(bidc);
+    Noloop nl_traits(bidc, binfo);
 
     nl_traits.move_all(gmem_q, gmem_s);
 
     // Trigger the loads for Q.
...

@@ -163,8 +163,6 @@ inline __device__ void compute_dv_1xN_nl(const Params &params) {
     // Load over the entire sequence length.
     for( int l = 0; l < nl_traits.num_steps_; l++ ) {
-        const int loop = nl_traits.offset_loop_count(l);
-        if( loop >= binfo.actual_seqlen ) break;
 
         uint4 s_regs[M][N];
         gmem_s.load(s_regs, mask);
...

@@ -230,7 +228,7 @@ inline __device__ void compute_dv_1xN_nl(const Params &params) {
         }
 
         float p_sum[2 * M];
-        softmax.template reduce<fmha::Sum_>(p_sum);
+        softmax.reduce_sum(p_sum);
 
         const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
 
         #pragma unroll
...

@@ -400,7 +398,7 @@ inline __device__ void compute_dq_dk_1xN_nl(const Params &params) {
     // Allocate the shared memory tile loader for Q (as B).
     Smem_tile_qt smem_qt(&smem_[0], tidx);
 
     // Allocate the global memory tile loader for dP.
-    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
+    Gmem_tile_s gmem_s(params, binfo, tidx);
 
     // Allocate the shared memory tile loader for dP.
     Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);
...

@@ -414,7 +412,7 @@ inline __device__ void compute_dq_dk_1xN_nl(const Params &params) {
     // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
     Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);
 
-    Noloop nl_traits(bidc);
+    Noloop nl_traits(bidc, binfo);
 
     nl_traits.move_all(gmem_q, gmem_o, gmem_s);
...
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu

@@ -28,31 +28,57 @@
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"
 
 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
 
-extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN<Kernel_traits, true>(params);
-}
-
-extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN<Kernel_traits, false>(params);
-}
+template<bool Is_training>
+__global__
+void fmha_fprop_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
+                                        const int num_full_heads,
+                                        const int num_main_groups,
+                                        const int main_group_size,
+                                        const int main_steps,
+                                        const int rest_steps) {
+    fmha::device_1xN<Kernel_traits, Is_training>(
+        params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
+}
 
-void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
-    auto kernel = is_training ? &fmha_fprop_fp16_128_64_sm80_train_kernel : &fmha_fprop_fp16_128_64_sm80_predict_kernel;
-    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
-    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
-    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
-    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
+void run_fmha_fp16_128_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
+    auto kernel = launch_params.is_training ? &fmha_fprop_fp16_128_64_sm80_kernel<true>
+                                            : &fmha_fprop_fp16_128_64_sm80_kernel<false>;
+    constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
 
     if( smem_size >= 48 * 1024 ) {
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     }
 
-    dim3 grid(params.h, params.b);
-    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
+    const int sm_count = launch_params.props->multiProcessorCount;
+    int ctas_per_sm;
+    FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
+    int total_ctas = sm_count * ctas_per_sm;
+
+    if( configure ) {
+        const int heads_total = launch_params.params.b * launch_params.params.h;
+        std::tie(launch_params.num_full_heads,
+                 launch_params.num_main_groups,
+                 launch_params.heads_last_wave,
+                 launch_params.main_steps,
+                 launch_params.rest_steps,
+                 launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
+        return;
+    }
+
+    dim3 grid(total_ctas);
+    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
+        launch_params.params,
+        launch_params.num_full_heads,
+        launch_params.num_main_groups,
+        launch_params.heads_last_wave,
+        launch_params.main_steps,
+        launch_params.rest_steps);
+
+    FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }
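The reworked launchers above no longer launch one CTA per (batch, head): they query occupancy, size the grid as SM count x CTAs-per-SM, and in the configure pass let fmha::work_dist split the b*h heads across those persistent CTAs. The program below sketches only the occupancy/grid-sizing part with a dummy kernel; it is a generic illustration of cudaOccupancyMaxActiveBlocksPerMultiprocessor and a grid-stride loop over heads, not the apex launcher itself, and the head count and thread count are made-up illustrative numbers.

#include <cstdio>
#include <cuda_runtime.h>

// Dummy stand-in for the FMHA kernel: each persistent CTA walks its share of heads.
__global__ void persistent_kernel(int heads_total) {
    for (int head = blockIdx.x; head < heads_total; head += gridDim.x) {
        if (threadIdx.x == 0 && head < 4) {
            printf("cta %d processes head %d\n", blockIdx.x, head);
        }
    }
}

int main() {
    cudaDeviceProp props;
    cudaGetDeviceProperties(&props, 0);

    const int threads = 128;  // plays the role of Kernel_traits::THREADS
    const int smem_size = 0;  // the real launcher passes its dynamic smem footprint here

    // Same sizing logic as the new run_fmha_* functions: CTAs per SM from the
    // occupancy API, times the number of SMs, gives the persistent grid size.
    int ctas_per_sm = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, persistent_kernel, threads, smem_size);
    int total_ctas = props.multiProcessorCount * ctas_per_sm;

    const int heads_total = 16 * 12;  // e.g. batch 16 with 12 heads (illustrative)
    printf("%d SMs x %d CTAs/SM = %d persistent CTAs for %d heads\n",
           props.multiProcessorCount, ctas_per_sm, total_ctas, heads_total);

    persistent_kernel<<<total_ctas, threads>>>(heads_total);
    cudaDeviceSynchronize();
    return 0;
}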
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu

@@ -28,31 +28,57 @@
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"
 
 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
 
-extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN<Kernel_traits, true>(params);
-}
-
-extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN<Kernel_traits, false>(params);
-}
+template<bool Is_training>
+__global__
+void fmha_fprop_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
+                                        const int num_full_heads,
+                                        const int num_main_groups,
+                                        const int main_group_size,
+                                        const int main_steps,
+                                        const int rest_steps) {
+    fmha::device_1xN<Kernel_traits, Is_training>(
+        params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
+}
 
-void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
-    auto kernel = is_training ? &fmha_fprop_fp16_256_64_sm80_train_kernel : &fmha_fprop_fp16_256_64_sm80_predict_kernel;
-    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
-    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
-    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
-    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
+void run_fmha_fp16_256_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
+    auto kernel = launch_params.is_training ? &fmha_fprop_fp16_256_64_sm80_kernel<true>
+                                            : &fmha_fprop_fp16_256_64_sm80_kernel<false>;
+    constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
 
     if( smem_size >= 48 * 1024 ) {
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     }
 
-    dim3 grid(params.h, params.b);
-    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
+    const int sm_count = launch_params.props->multiProcessorCount;
+    int ctas_per_sm;
+    FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
+    int total_ctas = sm_count * ctas_per_sm;
+
+    if( configure ) {
+        const int heads_total = launch_params.params.b * launch_params.params.h;
+        std::tie(launch_params.num_full_heads,
+                 launch_params.num_main_groups,
+                 launch_params.heads_last_wave,
+                 launch_params.main_steps,
+                 launch_params.rest_steps,
+                 launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
+        return;
+    }
+
+    dim3 grid(total_ctas);
+    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
+        launch_params.params,
+        launch_params.num_full_heads,
+        launch_params.num_main_groups,
+        launch_params.heads_last_wave,
+        launch_params.main_steps,
+        launch_params.rest_steps);
+
+    FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu

@@ -26,32 +26,59 @@
  ******************************************************************************/
 
 #include "fmha.h"
-#include "fmha_fprop_kernel_1xN_reload_v.h"
+#include "fmha_fprop_kernel_1xN.h"
 
-using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x08u>;
+using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x18u>;
 
-extern "C" __global__ void fmha_fprop_fp16_384_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN<Kernel_traits, true>(params);
-}
-
-extern "C" __global__ void fmha_fprop_fp16_384_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN<Kernel_traits, false>(params);
-}
+template<bool Is_training>
+__global__
+void fmha_fprop_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
                                        const int num_full_heads,
+                                        const int num_main_groups,
+                                        const int main_group_size,
+                                        const int main_steps,
+                                        const int rest_steps) {
+    fmha::device_1xN<Kernel_traits, Is_training>(
+        params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
+}
 
-void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
-    auto kernel = is_training ? &fmha_fprop_fp16_384_64_sm80_train_kernel : &fmha_fprop_fp16_384_64_sm80_predict_kernel;
-    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
-    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
-    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-    constexpr int smem_size = smem_size_v + smem_size_o + smem_size_softmax;
+void run_fmha_fp16_384_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
+    auto kernel = launch_params.is_training ? &fmha_fprop_fp16_384_64_sm80_kernel<true>
+                                            : &fmha_fprop_fp16_384_64_sm80_kernel<false>;
+    constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
 
     if( smem_size >= 48 * 1024 ) {
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     }
 
-    dim3 grid(params.h, params.b);
-    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
+    const int sm_count = launch_params.props->multiProcessorCount;
+    int ctas_per_sm;
+    FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
+    int total_ctas = sm_count * ctas_per_sm;
+
+    if( configure ) {
+        const int heads_total = launch_params.params.b * launch_params.params.h;
+        std::tie(launch_params.num_full_heads,
+                 launch_params.num_main_groups,
+                 launch_params.heads_last_wave,
+                 launch_params.main_steps,
+                 launch_params.rest_steps,
+                 launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
+        return;
+    }
+
+    dim3 grid(total_ctas);
+    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
+        launch_params.params,
+        launch_params.num_full_heads,
+        launch_params.num_main_groups,
+        launch_params.heads_last_wave,
+        launch_params.main_steps,
+        launch_params.rest_steps);
+
+    FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu

@@ -27,72 +27,111 @@
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"
-#include "fmha_fprop_kernel_1xN_nl.h"
 
-using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
+using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x00u>;
 
-extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN<Kernel_traits, true>(params);
-}
-
-extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN<Kernel_traits, false>(params);
-}
-
-template<int CHUNKS>
-__global__ void fmha_fprop_fp16_512_64_sm80_train_nl_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN_nl<CHUNKS, Kernel_traits, true>(params);
-}
-
-template<int CHUNKS>
-__global__ void fmha_fprop_fp16_512_64_sm80_predict_nl_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN_nl<CHUNKS, Kernel_traits, false>(params);
-}
+template<bool Is_training>
+__global__
+void fmha_fprop_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params, const int total_heads) {
+    fmha::device_1xN<Kernel_traits, Is_training>(params, total_heads);
+}
+
+template<bool Is_training>
+__global__
+void fmha_fprop_fp16_512_64_sm80_kernel_nl(Fused_multihead_attention_fprop_params params,
+                                           const int num_full_heads,
+                                           const int num_main_groups,
+                                           const int main_group_size,
+                                           const int main_steps,
+                                           const int rest_steps) {
+    fmha::device_1xN<Kernel_traits, Is_training>(
+        params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
+}
 
-void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
-    auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_kernel : &fmha_fprop_fp16_512_64_sm80_predict_kernel;
-    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
-    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
-    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
-    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
+void run_fmha_fp16_512_64_sm80_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
+    auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel<true>
+                                            : &fmha_fprop_fp16_512_64_sm80_kernel<false>;
+    constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
 
     if( smem_size >= 48 * 1024 ) {
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     }
 
-    dim3 grid(params.h, params.b);
-    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
-}
+    const int sm_count = launch_params.props->multiProcessorCount;
+    int ctas_per_sm;
+    FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
+    int total_ctas = sm_count * ctas_per_sm;
+
+    const int heads_total = launch_params.params.b * launch_params.params.h;
+
+    if( configure ) {
+        using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
+        constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
+        constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
+        constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
+        size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas);
+        size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8;
+        launch_params.elts_per_thread = heads_per_cta * elts_per_head;
+        return;
+    }
+
+    dim3 grid(total_ctas);
+    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
+        launch_params.params, heads_total);
+    FMHA_CHECK_CUDA(cudaPeekAtLastError());
+}
 
-void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream) {
-    auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
-    if( num_chunks == 2 ) {
-        kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
-    } else if( num_chunks == 3 ) {
-        kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<3> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<3>;
-    } else if( num_chunks == 4 ) {
-        kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<4>
-                             : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<4>;
-    } else {
-        assert(false && "Unsupported num_chunks");
-    }
-
-    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
-    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
-    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
-    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
+void run_fmha_fp16_512_64_sm80_nl_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
+    auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel_nl<true>
+                                            : &fmha_fprop_fp16_512_64_sm80_kernel_nl<false>;
+    constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
 
     if( smem_size >= 48 * 1024 ) {
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     }
 
-    dim3 grid(params.h, params.b, num_chunks);
-    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
-}
+    const int sm_count = launch_params.props->multiProcessorCount;
+    int ctas_per_sm;
+    FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
+    int total_ctas = sm_count * ctas_per_sm;
+
+    if( configure ) {
+        const int heads_total = launch_params.params.b * launch_params.params.h;
+        std::tie(launch_params.num_full_heads,
+                 launch_params.num_main_groups,
+                 launch_params.heads_last_wave,
+                 launch_params.main_steps,
+                 launch_params.rest_steps,
+                 launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
+        return;
+    }
+
+    dim3 grid(total_ctas);
+    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
+        launch_params.params,
+        launch_params.num_full_heads,
+        launch_params.num_main_groups,
+        launch_params.heads_last_wave,
+        launch_params.main_steps,
+        launch_params.rest_steps);
+    FMHA_CHECK_CUDA(cudaPeekAtLastError());
+}
+
+void run_fmha_fp16_512_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
+    if( launch_params.is_nl ) {
+        run_fmha_fp16_512_64_sm80_nl_(launch_params, configure);
+    } else {
+        run_fmha_fp16_512_64_sm80_(launch_params, configure);
+    }
+}
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h
View file @
96850dfa
/******************************************************************************
/******************************************************************************
*********************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
*
* Redistribution and use in source and binary forms, with or without
* Redistribution and use in source and binary forms, with or without
...
@@ -35,7 +35,159 @@ namespace fmha {
...
@@ -35,7 +35,159 @@ namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Kernel_traits
,
bool
Is_training
,
typename
Params
>
inline
__device__
void
device_1xN
(
const
Params
&
params
)
{
template
<
typename
Kernel_traits
>
struct
Gemm_Q_K_base
{
using
Smem_tile_o
=
typename
Kernel_traits
::
Smem_tile_o
;
using
Smem_tile_q
=
typename
Kernel_traits
::
Smem_tile_q
;
using
Smem_tile_k
=
typename
Kernel_traits
::
Smem_tile_k
;
using
Fragment_q
=
typename
Smem_tile_q
::
Fragment
;
using
Fragment_k
=
typename
Smem_tile_k
::
Fragment
;
// The description of the CTA tile for the 1st batched GEMM.
using
Cta_tile_p
=
typename
Kernel_traits
::
Cta_tile_p
;
// The MMA tile for the 1st GEMM.
using
Mma_tile_p
=
fmha
::
Hmma_tile
<
Cta_tile_p
>
;
static
constexpr
int
SMEM_BYTES_SOFTMAX
=
Cta_tile_p
::
M
*
Cta_tile_p
::
WARPS_N
*
sizeof
(
float
)
*
2
;
__device__
inline
Gemm_Q_K_base
(
char
*
smem_ptr_q
,
char
*
smem_ptr_k
,
const
int
tidx
)
:
smem_q
(
smem_ptr_q
,
tidx
)
,
smem_k
(
smem_ptr_k
,
tidx
)
{
}
__device__
inline
void
load_q
()
{
smem_q
.
load
(
frag_q
[
0
],
0
);
}
__device__
inline
void
reload_q
()
{
smem_q
.
load
(
frag_q
[
0
],
0
);
}
Fragment_q
frag_q
[
2
][
Mma_tile_p
::
MMAS_M
];
Smem_tile_q
smem_q
;
Smem_tile_k
smem_k
;
};
template
<
typename
Kernel_traits
,
bool
K_in_regs
>
struct
Gemm_Q_K
:
public
Gemm_Q_K_base
<
Kernel_traits
>
{
using
Base
=
Gemm_Q_K_base
<
Kernel_traits
>
;
using
Smem_tile_o
=
typename
Base
::
Smem_tile_o
;
using
Smem_tile_q
=
typename
Base
::
Smem_tile_q
;
using
Smem_tile_k
=
typename
Base
::
Smem_tile_k
;
using
Fragment_k
=
typename
Base
::
Fragment_k
;
using
Mma_tile_p
=
typename
Base
::
Mma_tile_p
;
enum
{
SHARE_SMEM_FOR_K_AND_V
=
Kernel_traits
::
SHARE_SMEM_FOR_K_AND_V
};
enum
{
SMEM_OFFSET_O
=
Smem_tile_q
::
BYTES_PER_TILE
};
enum
{
SMEM_OFFSET_V
=
Smem_tile_q
::
BYTES_PER_TILE
+
(
SHARE_SMEM_FOR_K_AND_V
?
0
:
Smem_tile_k
::
BYTES_PER_TILE
)
};
// Q | K / V
// | O | SOFTMAX
static
constexpr
int
SMEM_BYTES
=
Smem_tile_q
::
BYTES_PER_TILE
+
std
::
max
((
SHARE_SMEM_FOR_K_AND_V
?
1
:
2
)
*
Smem_tile_k
::
BYTES_PER_TILE
,
Smem_tile_o
::
BYTES_PER_TILE
+
Base
::
SMEM_BYTES_SOFTMAX
);
__device__
inline
Gemm_Q_K
(
char
*
smem_
,
const
int
tidx
)
:
Base
(
smem_
,
smem_
+
Smem_tile_q
::
BYTES_PER_TILE
,
tidx
)
{
}
__device__
inline
void
load_k
(){
#pragma unroll
for
(
int
ki
=
0
;
ki
<
Mma_tile_p
::
MMAS_K
;
++
ki
)
{
Base
::
smem_k
.
load
(
frag_k
[
ki
],
ki
);
}
}
template
<
typename
Acc
,
int
M
,
int
N
>
__device__
inline
void
operator
()(
Acc
(
&
acc_p
)[
M
][
N
]){
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for
(
int
ki
=
1
;
ki
<
Mma_tile_p
::
MMAS_K
;
++
ki
)
{
// Trigger the load from shared memory for the next series of Q values.
Base
::
smem_q
.
load
(
Base
::
frag_q
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_p
::
MMAS_K
;
fmha
::
gemm
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)]);
}
}
__device__
inline
void
reload_k
(){
// Noop.
}
Fragment_k
frag_k
[
Mma_tile_p
::
MMAS_K
][
Mma_tile_p
::
MMAS_N
];
};
template
<
typename
Kernel_traits
>
struct
Gemm_Q_K
<
Kernel_traits
,
false
>
:
public
Gemm_Q_K_base
<
Kernel_traits
>
{
using
Base
=
Gemm_Q_K_base
<
Kernel_traits
>
;
using
Smem_tile_o
=
typename
Base
::
Smem_tile_o
;
using
Smem_tile_q
=
typename
Base
::
Smem_tile_q
;
using
Smem_tile_k
=
typename
Base
::
Smem_tile_k
;
using
Smem_tile_v
=
typename
Kernel_traits
::
Smem_tile_v
;
using
Fragment_k
=
typename
Base
::
Fragment_k
;
using
Mma_tile_p
=
typename
Base
::
Mma_tile_p
;
Fragment_k
frag_k
[
2
][
Mma_tile_p
::
MMAS_N
];
enum
{
SHARE_SMEM_FOR_K_AND_V
=
Kernel_traits
::
SHARE_SMEM_FOR_K_AND_V
};
enum
{
SMEM_OFFSET_V
=
Smem_tile_q
::
BYTES_PER_TILE
+
(
SHARE_SMEM_FOR_K_AND_V
?
0
:
Smem_tile_k
::
BYTES_PER_TILE
)
};
static_assert
(
Smem_tile_v
::
BYTES_PER_TILE
==
(
int
)
Smem_tile_k
::
BYTES_PER_TILE
);
enum
{
SMEM_OFFSET_O
=
SMEM_OFFSET_V
+
Smem_tile_v
::
BYTES_PER_TILE
};
// Q | K/V + O + SOFTMAX
static
constexpr
int
SMEM_BYTES
=
Smem_tile_q
::
BYTES_PER_TILE
+
(
SHARE_SMEM_FOR_K_AND_V
?
1
:
2
)
*
Smem_tile_k
::
BYTES_PER_TILE
+
Smem_tile_o
::
BYTES_PER_TILE
+
Base
::
SMEM_BYTES_SOFTMAX
;
__device__
inline
Gemm_Q_K
(
char
*
smem_
,
const
int
tidx
)
:
Base
(
smem_
,
smem_
+
Smem_tile_q
::
BYTES_PER_TILE
,
tidx
)
{
}
__device__
inline
void
load_k
(){
Base
::
smem_k
.
load
(
frag_k
[
0
],
0
);
}
template
<
typename
Acc
,
int
M
,
int
N
>
__device__
inline
void
operator
()(
Acc
(
&
acc_p
)[
M
][
N
]){
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for
(
int
ki
=
1
;
ki
<
Mma_tile_p
::
MMAS_K
;
++
ki
)
{
// Trigger the load from shared memory for the next series of Q values.
Base
::
smem_q
.
load
(
Base
::
frag_q
[
ki
&
1
],
ki
);
Base
::
smem_k
.
load
(
frag_k
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)
&
1
]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_p
::
MMAS_K
;
fmha
::
gemm
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)
&
1
]);
}
}
__device__
inline
void
reload_k
(){
Base
::
smem_k
.
load
(
frag_k
[
0
],
0
);
}
};
template
<
typename
Kernel_traits
>
constexpr
size_t
get_dynamic_smem_size
(){
return
Gemm_Q_K
<
Kernel_traits
,
Kernel_traits
::
K_IN_REGS
>::
SMEM_BYTES
;
}
template<typename Kernel_traits, bool Is_training, typename Params, typename Prng>
inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, const int begin, const int steps, Prng &ph) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
...
@@ -49,13 +201,9 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
-    // The shared memory tile to swizzle Q.
-    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
-    // The shared memory tile to swizzle K.
-    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
...
@@ -69,81 +217,88 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

+    using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
+
+    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
+
+    // The number of threads per row.
+    enum { THREADS_PER_ROW = 32 };
+
+    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
+
    // Shared memory.
    extern __shared__ char smem_[];

-    // The block index for the batch.
-    const int bidb = blockIdx.y;
-    // The block index for the head.
-    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    if( binfo.stop_early() ) return;

-    auto seeds = at::cuda::philox::unpack(params.philox_args);
-    Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));
-    Mask<Cta_tile_p> mask(params, binfo, tidx);
+    Gemm1 gemm_q_k(smem_, tidx);

    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
-    // Allocate the shared memory tile loader for Q.
-    Smem_tile_q smem_q(&smem_[0], tidx);
+    // Allocate the global memory tile loader for O.
+    Gmem_tile_o gmem_o(params, binfo, tidx);
+    // Allocate the global memory tile loader for S.
+    Gmem_tile_s gmem_s(params, binfo, tidx);
+
+    // Wind gmem tiles to the correct position.
+    for( int it = 0; it < begin; it++ ) {
+        gmem_q.move();
+        gmem_s.move();
+        gmem_o.move();
+    }
+
+    fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
-    // Allocate the shared memory tile loader for K.
-    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // The base pointer of smem_v;
-    char *smem_v_ = nullptr;
-    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
-        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];
-    } else {
-        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
-    }
+    char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];

    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);

-    // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params, binfo, tidx);
    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
-    Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
+    Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx);

-    // Trigger the loads for Q.
-    gmem_q.load(smem_q);
    // Trigger the loads for K.
-    gmem_k.load(smem_k);
-    // Trigger the loads for K.
+    gmem_k.load(gemm_q_k.smem_k);
+    // Trigger the loads for Q.
+    gmem_q.load(gemm_q_k.smem_q);
+    // Trigger the loads for V.
    gmem_v.load(smem_v);

-    // Commit the data for Q and K to shared memory.
-    gmem_q.commit(smem_q);
-    gmem_k.commit(smem_k);
+    const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
+    #pragma unroll
+    for(int it=0; it < Gmem_tile_k::LDGS; it++){
+        gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
+    }

-    // Commit the data for V to shared memory.
+    // Commit the data for Q and V to shared memory.
+    gmem_q.commit(gemm_q_k.smem_q);
+    gmem_v.commit(smem_v);
+
+    // Commit the data for K to shared memory.
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
-        gmem_v.commit(smem_v);
+        gmem_k.commit(gemm_q_k.smem_k);
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // Load the fragments for Q.
-    typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
-    smem_q.load(frag_q[0], 0);
+    gemm_q_k.load_q();

-    // Load the fragments for K. We keep the data in registers during the entire kernel.
-    typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
+    // Load the fragments for V. We keep the data in registers during the entire kernel.
+    typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
    #pragma unroll
-    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
-        smem_k.load(frag_k[ki], ki);
+    for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
+        smem_v.load(frag_v[ki], ki);
    }

    // Commit the data for V to shared memory if it has not been done already.
...
@@ -152,61 +307,41 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
        __syncthreads();

        // Commit the data to shared memory for V.
-        gmem_v.commit(smem_v);
+        gmem_k.commit(gemm_q_k.smem_k);

        // Make sure the data is in shared memory.
        __syncthreads();
    }

-    // Load the fragments for V. We keep the data in registers during the entire kernel.
-    typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
-    #pragma unroll
-    for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
-        smem_v.load(frag_v[ki], ki);
-    }
-
-    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
-
-    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
+    // Load the fragments for K.
+    gemm_q_k.load_k();

    // Create the object to do the softmax.
-    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
-    Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
-
-    enum { THREADS_PER_ROW = 32 };
-    enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
+    Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);

    // Load over the entire sequence length.
-    for( int l = 0; l < STEPS; l++ ) {
-        const int loop = l * Cta_tile_p::M;
-        if( loop >= binfo.actual_seqlen ) break;
+    for( int l = 0; l < steps; l++ ) {
+        if( begin + l * Cta_tile_p::M >= binfo.actual_seqlen ) break;

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        // Do this part of P^T = (Q * K^T)^T.
-        #pragma unroll
-        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
-            // Trigger the load from shared memory for the next series of Q values.
-            smem_q.load(frag_q[ki & 1], ki);
-            // Do the math for the values already in registers.
-            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
-        }
-        // Do the final stage of math.
-        {
-            int ki = Mma_tile_p::MMAS_K;
-            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
-        }
+        gemm_q_k(acc_p);
+
+        // Trigger the load for the next Q values.
+        if( l < steps - 1 ) {
+            gemm_q_k.smem_q.move_to_next_write_buffer();
+            gmem_q.move();
+            gmem_q.load(gemm_q_k.smem_q);
+        }

        // Load the mask for that iteration.
-        mask.load(l);
+        mask.load(begin + l);

        // Convert from the accumulator type to FP32 for Softmax.
-        softmax.unpack(acc_p);
+        softmax.unpack_noscale(acc_p);

        // Apply the mask.
        softmax.apply_mask(mask);
...
@@ -217,21 +352,21 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
        }

        // Compute the max.
        float p_max[Mma_tile_p::MMAS_M * 2];
-        softmax.template reduce<fmha::Max_>(p_max);
-
-        // Make sure we are done reading shared memory.
-        __syncthreads();
+        //softmax.template reduce<fmha::Max_>(p_max);
+        softmax.reduce_max(p_max);

        // Compute the exponential value.
        softmax.apply_exp(p_max);

        // Compute the sum.
        float p_sum[Mma_tile_p::MMAS_M * 2];
-        softmax.template reduce<fmha::Sum_>(p_sum);
+        softmax.reduce_sum(p_sum);

        // Finalize softmax on the accumulators of P^T.
        softmax.scale(p_sum);

+        using Frag_p = fmha::Fragment_a<fmha::Row>;
+        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        if( Is_training ) {
            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
            #pragma unroll
...
@@ -241,8 +376,7 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
                #pragma unroll
                for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                    float4 tmp = uniform4(ph());
-                    // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from
-                    // pre-existing zeros
+                    // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from pre-existing zeros
                    softmax.elt_[2 * mi + ii][4 * ni + 0] =
                        encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
                    softmax.elt_[2 * mi + ii][4 * ni + 1] =
...
@@ -254,20 +388,18 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
                    }
                }
            }
-            gmem_s.store(softmax.elt_, mask);
+            softmax.pack(frag_p);
+            gmem_s.store(frag_p, mask);
            gmem_s.move();
+        } else {
+            softmax.pack(frag_p);
        }

-        // Trigger the load for the next Q values.
-        if( l < STEPS - 1 ) {
-            smem_q.move_to_next_write_buffer();
-            gmem_q.move();
-            gmem_q.load(smem_q);
+        // Commit the values for Q into shared memory.
+        if( l < steps - 1 ) {
+            gmem_q.commit(gemm_q_k.smem_q);
        }

-        using Frag_p = fmha::Fragment_a<fmha::Row>;
-        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
-        softmax.pack(frag_p);
-
        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
            #pragma unroll
...
@@ -316,21 +448,84 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
        // Move to the next part of the output.
        gmem_o.move();
+        gemm_q_k.reload_k();

        // Commit the values for Q into shared memory.
-        if( l < STEPS - 1 ) {
-            gmem_q.commit(smem_q);
-            // Make sure the data is in shared memory.
-            __syncthreads();
-            // Trigger the loads for the values of Q for the next iteration.
-            smem_q.load(frag_q[0], 0);
+        if( l < steps - 1 ) {
+            gemm_q_k.reload_q();
        }
    }  // Outer loop over the sequence length.
}

////////////////////////////////////////////////////////////////////////////////////////////////////

+template<typename Kernel_traits, bool Is_training, typename Params>
+inline __device__ void device_1xN(const Params &params, const int num_full_heads, const int num_main_groups,
+                                  const int main_group_size, const int main_steps, const int rest_steps) {
+
+    constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
+    const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x;
+    auto seeds = at::cuda::philox::unpack(params.philox_args);
+    Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
+    for( int it = 0; it < num_full_heads; it++ ) {
+        const int bidx = it * gridDim.x + blockIdx.x;
+        const int bidh = bidx % params.h;
+        const int bidb = bidx / params.h;
+        fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, 0, STEPS, ph);
+        __syncthreads();
+    }
+    if( main_group_size == 0 ) return;
+
+    const int head_offset = num_full_heads * gridDim.x;
+
+    if( blockIdx.x < main_group_size * num_main_groups ) {
+        // process within heads
+        const int group = blockIdx.x % num_main_groups;
+        const int bidx = blockIdx.x / num_main_groups;
+        const int bidh = (head_offset + bidx) % params.h;
+        const int bidb = (head_offset + bidx) / params.h;
+        const int offset = group * main_steps;
+        fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, offset, main_steps, ph);
+    } else {
+        if( rest_steps == 0 ) return;
+        // process across heads
+        const int bidx = blockIdx.x - main_group_size * num_main_groups;
+        const int offset = num_main_groups * main_steps;
+        const int total_heads = params.b * params.h;
+        const int rest_ctas = gridDim.x - main_group_size * num_main_groups;
+        for( int it = head_offset + bidx; it < total_heads; it += rest_ctas ) {
+            const int bidh = it % params.h;
+            const int bidb = it / params.h;
+            fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, offset, rest_steps, ph);
+            __syncthreads();
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Kernel_traits, bool Is_training, typename Params>
+inline __device__ void device_1xN(const Params &params, const int total_heads) {
+
+    const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x;
+    auto seeds = at::cuda::philox::unpack(params.philox_args);
+    Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
+    constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
+    for( int bidx = blockIdx.x; bidx < total_heads; bidx += gridDim.x ){
+        const int bidh = bidx % params.h;
+        const int bidb = bidx / params.h;
+        fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, 0, STEPS, ph);
+        __syncthreads();
+    }
+}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
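Editor's note, not part of the diff: the shape of the __global__ wrapper and launch that the per-sequence-length .cu files are expected to build around device_1xN. The kernel name, the Params type and the grid sizing are illustrative placeholders; only the call into fmha::device_1xN and the use of get_dynamic_smem_size come from the header above.

template<typename Kernel_traits, bool Is_training, typename Params>
__global__ void fmha_fprop_sketch_kernel(Params params, const int total_heads) {
    // One persistent CTA walks over heads with a grid stride; device_1xN derives
    // (bidb, bidh) from the flat head index.
    fmha::device_1xN<Kernel_traits, Is_training>(params, total_heads);
}

template<typename Kernel_traits, bool Is_training, typename Params>
void launch_fmha_fprop_sketch(const Params &params, int total_heads, int num_ctas, cudaStream_t stream) {
    constexpr size_t smem = fmha::get_dynamic_smem_size<Kernel_traits>();
    fmha_fprop_sketch_kernel<Kernel_traits, Is_training, Params>
        <<<num_ctas, Kernel_traits::THREADS, smem, stream>>>(params, total_heads);
}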
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_nl.h
deleted 100644 → 0
View file @ 87fc4125
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int CHUNKS, typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN_nl(const Params &params) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
    // The shared memory tile to swizzle Q.
    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle K.
    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;

    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

    // The global memory tile to store O.
    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
    // The shared memory tile to swizzle O.
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

    // The global memory tile to store S/D.
    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;

    // Shared memory.
    extern __shared__ char smem_[];

    const int bidc = blockIdx.z;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    Noloop nl_traits(bidc);

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    if( binfo.stop_early() ) return;

    auto seeds = at::cuda::philox::unpack(params.philox_args);
    Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));

    fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);

    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the shared memory tile loader for Q.
    Smem_tile_q smem_q(&smem_[0], tidx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the shared memory tile loader for K.
    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // The base pointer of smem_v;
    char *smem_v_ = nullptr;
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];
    } else {
        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
    }
    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);

    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params, binfo, tidx);
    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);

    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);

    nl_traits.move_all(gmem_q, gmem_o, gmem_s);

    // Trigger the loads for Q.
    gmem_q.load(smem_q);
    // Trigger the loads for K.
    gmem_k.load(smem_k);
    // Trigger the loads for K.
    gmem_v.load(smem_v);

    // Commit the data for Q and K to shared memory.
    gmem_q.commit(smem_q);
    gmem_k.commit(smem_k);

    // Commit the data for V to shared memory.
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        gmem_v.commit(smem_v);
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // Load the fragments for Q.
    typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
    smem_q.load(frag_q[0], 0);

    // Load the fragments for K. We keep the data in registers during the entire kernel.
    typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
        smem_k.load(frag_k[ki], ki);
    }

    // Commit the data for V to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();
        // Commit the data to shared memory for V.
        gmem_v.commit(smem_v);
        // Make sure the data is in shared memory.
        __syncthreads();
    }

    // Load the fragments for V. We keep the data in registers during the entire kernel.
    typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
        smem_v.load(frag_v[ki], ki);
    }

    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };

    // Create the object to do the softmax.
    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
    Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);

    // The number of threads per row.
    enum { THREADS_PER_ROW = 32 };

    // Load over the entire sequence length.
    for( int l = 0; l < nl_traits.num_steps_; l++ ) {

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        // Do this part of P^T = (Q * K^T)^T.
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_q.load(frag_q[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
        }
        // Do the final stage of math.
        {
            int ki = Mma_tile_p::MMAS_K;
            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
        }

        // Trigger the load for the next Q values.
        if( l < nl_traits.num_steps_ - 1 ) {
            smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load(smem_q);
        }

        // Load the mask for that iteration.
        mask.load(nl_traits.loop_offset_ + l);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack(acc_p);

        // Apply the mask.
        softmax.apply_mask(mask);

        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
            // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
            __syncthreads();
        }

        // Compute the max.
        float p_max[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Max_>(p_max);

        // Make sure we are done reading shared memory.
        __syncthreads();

        // Compute the exponential value.
        softmax.apply_exp(p_max);

        // Compute the sum.
        float p_sum[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Sum_>(p_sum);

        // Finalize softmax on the accumulators of P^T.
        softmax.scale(p_sum);

        if( Is_training ) {
            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < 2; ii++ ) {
                    #pragma unroll
                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                        float4 tmp = uniform4(ph());
                        // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from pre-existing zeros
                        softmax.elt_[2 * mi + ii][4 * ni + 0] = encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
                        softmax.elt_[2 * mi + ii][4 * ni + 1] = encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
                        softmax.elt_[2 * mi + ii][4 * ni + 2] = encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
                        softmax.elt_[2 * mi + ii][4 * ni + 3] = encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
                    }
                }
            }
            gmem_s.store(softmax.elt_, mask);
            gmem_s.move();
        }

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        softmax.pack(frag_p);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
                    //"Apply" the dropout.
                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
                }
            }
        }

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);

        // Do this part of O = P^T * V^T.
        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
            fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
        }

        // Loop over MMAS_M.
        #pragma unroll
        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {

            // Swizzle the elements and do the final reduction.
            smem_o.store(acc_o, ii);

            // Make sure the data is in shared memory.
            __syncthreads();

            // Load from shared memory.
            uint4 out[Gmem_tile_o::STGS_PER_LOOP];
            smem_o.load(out);

            // Make sure the data was read from shared memory.
            if( ii < Gmem_tile_o::LOOPS - 1 ) {
                __syncthreads();
            }

            // Output the values.
            gmem_o.store(out, ii);
        }

        // Move to the next part of the output.
        gmem_o.move();

        // Commit the values for Q into shared memory.
        if( l < nl_traits.num_steps_ - 1 ) {
            gmem_q.commit(smem_q);
            __syncthreads();
            smem_q.load(frag_q[0], 0);
        }

    }  // Outer loop over the sequence length.
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_reload_v.h
deleted 100644 → 0
View file @ 87fc4125
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN(const Params &params) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
    // The shared memory tile to swizzle Q.
    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle K.
    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;

    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

    // The global memory tile to store O.
    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
    // The shared memory tile to swizzle O.
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    // Shared memory.
    extern __shared__ char smem_[];

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    if( binfo.stop_early() ) return;

    Mask<Cta_tile_p> mask(params, binfo, tidx);

    auto seeds = at::cuda::philox::unpack(params.philox_args);
    Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));

    static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the shared memory tile loader for K.
    Smem_tile_k smem_k(&smem_[0], tidx);

    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // The base pointer of smem_v;
    char *smem_v_ = nullptr;
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        smem_v_ = &smem_[0];
    } else {
        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
    }
    static_assert(Kernel_traits::SHARE_SMEM_FOR_K_AND_V);
    static_assert(Smem_tile_k::BYTES_PER_TILE == Smem_tile_v::BYTES_PER_TILE);
    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);

    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the shared memory tile loader for Q.
    Smem_tile_q smem_q(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params, binfo, tidx);
    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_o smem_o(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);

    // Trigger the loads for Q.
    gmem_q.load(smem_q);
    // Trigger the loads for K.
    gmem_k.load(smem_k);
    // Trigger the loads for K.
    gmem_v.load(smem_v);

    // Commit the data for Q and K to shared memory.
    gmem_q.commit(smem_q);
    gmem_k.commit(smem_k);

    // Commit the data for V to shared memory.
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        gmem_v.commit(smem_v);
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // Load the fragments for Q.
    typename Smem_tile_q::Fragment frag_q[1][Mma_tile_p::MMAS_M];

    // Load the fragments for K. We keep the data in registers during the entire kernel.
    typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
        smem_k.load(frag_k[ki], ki);
    }

    // Commit the data for V to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();
        // Commit the data to shared memory for V.
        gmem_v.commit(smem_v);
    }

    enum { BITS_PER_ELT_S = sizeof(typename fmha::A_type) * 8 };

    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);

    // Create the object to do the softmax.
    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
    Softmax softmax(params, &smem_[Smem_tile_v::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);

    constexpr int SMEM_BYTES_SOFTMAX = Softmax::ELEMENTS * sizeof(float);
    static_assert(SMEM_BYTES_SOFTMAX == Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float));

    enum { THREADS_PER_ROW = 32 };

    const float pinv = 1.f / params.p_dropout;

    // Load over the entire sequence length.
    for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
        if( loop >= binfo.actual_seqlen ) break;

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_q.load(frag_q[0], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_p, frag_q[0], frag_k[ki]);
        }

        // Load the mask for that iteration.
        mask.load(outer);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack(acc_p);

        // Apply the mask.
        softmax.apply_mask(mask);

        static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);

        // Compute the max.
        float p_max[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Max_>(p_max);

        // Make sure we are done reading shared memory.
        __syncthreads();

        // Compute the exponential value.
        softmax.apply_exp(p_max);

        // Compute the sum.
        float p_sum[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Sum_>(p_sum);

        // Finalize softmax on the accumulators of P^T.
        softmax.scale(p_sum);

        __syncthreads();

        if( Is_training ) {
            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < 2; ii++ ) {
                    #pragma unroll
                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                        float4 tmp = uniform4(ph());
                        // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from
                        // pre-existing zeros
                        softmax.elt_[2 * mi + ii][4 * ni + 0] = encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
                        softmax.elt_[2 * mi + ii][4 * ni + 1] = encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
                        softmax.elt_[2 * mi + ii][4 * ni + 2] = encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
                        softmax.elt_[2 * mi + ii][4 * ni + 3] = encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
                    }
                }
            }
            gmem_s.store(softmax.elt_, mask);
            gmem_s.move();
        }

        // Trigger the load for the next Q values.
        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
            smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load(smem_q);
        }

        typename Smem_tile_v::Fragment frag_v[1][Mma_tile_o::MMAS_N];

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        softmax.pack(frag_p);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
                    //"Apply" the dropout.
                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
                }
            }
        }

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of V values.
            smem_v.load(frag_v[0], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_o, frag_p[ki], frag_v[0]);
        }

        // Loop over MMAS_M.
        #pragma unroll
        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {

            // Swizzle the elements and do the final reduction.
            smem_o.store(acc_o, ii);

            // Make sure the data is in shared memory.
            __syncthreads();

            // Load from shared memory.
            uint4 out[Gmem_tile_o::STGS_PER_LOOP];
            smem_o.load(out);

            // Always sync after last iter: shared smem_q and smem_o!
            __syncthreads();

            // Output the values.
            gmem_o.store(out, ii);
        }  // same smem as o

        // Move to the next part of the output.
        gmem_o.move();

        // Commit the values for Q into shared memory.
        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
            gmem_q.commit(smem_q);
        }

        // Make sure the data is in shared memory.
        __syncthreads();

    }  // Outer loop over the sequence length.
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
apex/contrib/csrc/fmha/src/fmha_kernel.h
View file @ 96850dfa
...
@@ -79,17 +79,19 @@ struct Noloop_traits{
    enum { STEP = Cta_tile::M };
    enum { SEQLEN = Cta_tile::N };

-    // The size of the subsequence this CTA is processing
-    enum { SUBSEQ = SEQLEN / CHUNKS };
-    static_assert(SUBSEQ * CHUNKS == SEQLEN);
-
-    // The number of steps to process the subsequence
-    enum { NUM_STEPS = SUBSEQ / STEP };
-    static_assert(NUM_STEPS * Cta_tile::M == SUBSEQ);
-
-    inline __device__ Noloop_traits(const int bidc)
-        : loop_offset_(NUM_STEPS * bidc)
-        , bidc_(bidc) {
+    template<typename Block_info>
+    inline __device__ Noloop_traits(const int bidc, const Block_info &binfo)
+        : bidc_(bidc) {
+        const int seqlen = binfo.actual_seqlen;
+        const int steps = (seqlen + STEP - 1) / STEP;
+        const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS;
+
+        const int step_begin = bidc_ * steps_per_chunk;
+        const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk);
+        const int actual_steps = max(0, step_end - step_begin);
+        loop_offset_ = step_begin;
+        num_steps_ = actual_steps;
    }

    template<typename... Tiles>
...
@@ -115,54 +117,62 @@ struct Noloop_traits{
        return (loop_offset_ + l) * STEP;
    }

-    const int loop_offset_;
    const uint32_t bidc_;
-    const int num_steps_ = NUM_STEPS;
+    int loop_offset_;
+    int num_steps_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Cta_tile>
-struct Noloop_traits<3, Cta_tile> {
-    // Interpretation of Cta_tile dims, i.e. Cta_tile_p:
-    enum { STEP = Cta_tile::M };
-    enum { SEQLEN = Cta_tile::N };
-    static_assert(STEP == 16 && SEQLEN == 512);
-
-    inline __device__ Noloop_traits(const int bidc)
-        : bidc_(bidc)
-        , num_steps_(bidc < 2 ? 11 : 10)
-        , loop_offset_(bidc * 11) {
-    }
-
-    template<typename... Tiles>
-    inline __device__ void move_all(Tiles &... tiles) const {
-        using expand_type = int[];
-        for( int s = 0; s < loop_offset_; s++ ) {
-            expand_type{ (tiles.move(), 0)... };
-        }
-    }
-
-    inline __device__ int get_idx_dk() const {
-        //return bidc_;
-        return bidc_ * 2 + 0;
-    }
-
-    inline __device__ int get_idx_dv() const {
-        //return CHUNKS + bidc_;
-        return bidc_ * 2 + 1;
-    }
-
-    inline __device__ int offset_loop_count(const int l) {
-        // convert loop counter to position in the outer sequence
-        return (loop_offset_ + l) * STEP;
-    }
-
-    const int loop_offset_;
-    const uint32_t bidc_;
-    const int num_steps_;
-};
+template<typename Kernel_traits>
+std::tuple<int, int, int, int, int, int> work_dist(const int total_ctas, const int heads_total) {
+
+    constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
+
+    const int num_full_heads = heads_total / total_ctas;
+    const int heads_last_wave = heads_total % total_ctas;
+
+    int num_main_groups = 0;
+    int main_steps = 0;
+    int rest_steps = 0;
+    if( heads_last_wave > 0 ) {
+        // Number of CTA groups that process within heads.
+        num_main_groups = total_ctas / heads_last_wave;
+        // Remaining CTAs that process between heads.
+        const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups);
+        if( rest_ctas == 0 ) {
+            // We have exactly "num_main_groups" CTAs to process each of the remaining heads.
+            main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups;
+            num_main_groups = STEPS_PER_HEAD / main_steps;
+            // Here: main_step > 0
+            rest_steps = STEPS_PER_HEAD % main_steps;
+        } else {
+            // Ideal number of steps if we could load-balance as evenly as possible.
+            const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas;
+            // Iterations that a "rest" CTA has to do at most.
+            const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas;
+            // Find the first step distribution, s.t. the maximum work of the "rest" CTAs is less than the work of the main CTAs.
+            main_steps = steps_ideal;
+            rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
+            for( ; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++ ) {
+                rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
+                const int max_rest_total_steps = rest_steps * max_rest_iters;
+                if( max_rest_total_steps < main_steps ) break;
+            }
+            rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
+        }
+    }
+
+    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
+    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
+
+    const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps);
+    const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8;
+    const int elts_per_thread = max_steps * elts_per_thread_per_step;
+
+    return { num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread };
+}

////////////////////////////////////////////////////////////////////////////////////////////////////
...
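Editor's note, not part of the diff: the new templated Noloop_traits constructor above splits the sequence into CHUNKS contiguous blocks of STEP-sized steps. A minimal host-side restatement of that arithmetic, with concrete numbers in the trailing comment, assuming STEP = Cta_tile::M = 16 and CHUNKS = 3; the helper name and struct are for illustration only.

#include <algorithm>

struct Chunk { int step_begin; int num_steps; };

inline Chunk chunk_for_cta(int bidc, int actual_seqlen, int STEP, int CHUNKS) {
    const int steps           = (actual_seqlen + STEP - 1) / STEP;   // ceil-div over the sequence
    const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS;       // ceil-div over the chunks
    const int step_begin      = bidc * steps_per_chunk;
    const int step_end        = std::min(steps, (bidc + 1) * steps_per_chunk);
    return { step_begin, std::max(0, step_end - step_begin) };
}
// e.g. actual_seqlen = 384, STEP = 16, CHUNKS = 3 -> steps = 24, steps_per_chunk = 8,
// so bidc = 0, 1, 2 process step ranges [0,8), [8,16), [16,24).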
apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp
0 → 100644
View file @ 96850dfa
#include <torch/torch.h>
#include <vector>
#include <cstdint>
// CUDA forward declarations
std::vector<at::Tensor> focal_loss_forward_cuda(
    const at::Tensor &cls_output, const at::Tensor &cls_targets_at_level,
    const at::Tensor &num_positives_sum, const int64_t num_real_classes,
    const float alpha, const float gamma, const float smoothing_factor);

at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output,
                                    const at::Tensor &partial_grad,
                                    const at::Tensor &num_positives_sum);

// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> focal_loss_forward(
    const at::Tensor &cls_output, const at::Tensor &cls_targets_at_level,
    const at::Tensor &num_positives_sum, const int64_t num_real_classes,
    const float alpha, const float gamma, const float smoothing_factor) {
  CHECK_INPUT(cls_output);
  CHECK_INPUT(cls_targets_at_level);
  CHECK_INPUT(num_positives_sum);

  return focal_loss_forward_cuda(cls_output, cls_targets_at_level,
                                 num_positives_sum, num_real_classes, alpha,
                                 gamma, smoothing_factor);
}

at::Tensor focal_loss_backward(const at::Tensor &grad_output,
                               const at::Tensor &partial_grad,
                               const at::Tensor &num_positives_sum) {
  CHECK_INPUT(grad_output);
  CHECK_INPUT(partial_grad);

  return focal_loss_backward_cuda(grad_output, partial_grad, num_positives_sum);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &focal_loss_forward, "Focal loss calculation forward (CUDA)");
  m.def("backward", &focal_loss_backward, "Focal loss calculation backward (CUDA)");
}
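Editor's note, not part of the new file above: a small, illustrative C++ smoke test for the two bindings, assuming the translation unit can see focal_loss_forward/focal_loss_backward. Shapes and values are arbitrary; the class dimension (80) is padded to a multiple of the uint4 vector width for half precision (8 elements) and num_positives_sum is a one-element float32 CUDA tensor, matching the checks in the kernel file below.

#include <ATen/ATen.h>

void focal_loss_smoke_test() {
  const auto half_opts  = at::TensorOptions().device(at::kCUDA).dtype(at::kHalf);
  const auto long_opts  = at::TensorOptions().device(at::kCUDA).dtype(at::kLong);
  const auto float_opts = at::TensorOptions().device(at::kCUDA).dtype(at::kFloat);

  auto cls_output = at::randn({8, 100, 80}, half_opts);        // [batch, anchors, padded classes]
  auto targets    = at::randint(-2, 79, {8, 100}, long_opts);  // -2 = ignore, otherwise match label
  auto npos       = at::ones({1}, float_opts);                 // num_positives_sum

  auto outs = focal_loss_forward(cls_output, targets, npos, /*num_real_classes=*/79,
                                 /*alpha=*/0.25f, /*gamma=*/2.0f, /*smoothing_factor=*/0.0f);
  // outs[0] is the scalar loss, outs[1] the partial gradient that backward rescales in place.
  auto grad = focal_loss_backward(at::ones({}, float_opts), outs[1], npos);
}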
apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu
0 → 100644
View file @ 96850dfa
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#define ASSERT_UINT4_ALIGNED(PTR) \
TORCH_INTERNAL_ASSERT(is_aligned<uint4>(PTR), "Tensor " #PTR " is not uint4 aligned")
template <class T> bool is_aligned(const void *ptr) noexcept {
  auto iptr = reinterpret_cast<std::uintptr_t>(ptr);
  return !(iptr % alignof(T));
}

template <bool SMOOTHING, int ILP, typename scalar_t, typename labelscalar_t,
          typename accscalar_t, typename outscalar_t>
__global__ void focal_loss_forward_cuda_kernel(
    outscalar_t *loss, scalar_t *partial_grad,
    const scalar_t *__restrict__ cls_output,
    const labelscalar_t *__restrict__ cls_targets_at_level,
    const float *__restrict__ num_positives_sum, const int64_t num_examples,
    const int64_t num_classes, const int64_t num_real_classes,
    const float alpha, const float gamma, const float smoothing_factor) {
  extern __shared__ unsigned char shm[];
  accscalar_t *loss_shm = reinterpret_cast<accscalar_t *>(shm);
  loss_shm[threadIdx.x] = 0;
  accscalar_t loss_acc = 0;

  accscalar_t one = accscalar_t(1.0);
  accscalar_t K = accscalar_t(2.0);
  accscalar_t normalizer = one / static_cast<accscalar_t>(num_positives_sum[0]);
  accscalar_t nn_norm, np_norm, pn_norm, pp_norm;

  // *_norm is used for label smoothing only
  if (SMOOTHING) {
    nn_norm = one - smoothing_factor / K;
    np_norm = smoothing_factor / K;
    pn_norm = smoothing_factor - smoothing_factor / K;
    pp_norm = one - smoothing_factor + smoothing_factor / K;
  }

  uint4 p_vec, grad_vec;

  // Accumulate loss on each thread
  for (int64_t i = (blockIdx.x * blockDim.x + threadIdx.x) * ILP;
       i < num_examples * num_classes; i += gridDim.x * blockDim.x * ILP) {
    int64_t idy = i / num_classes;
    labelscalar_t y = cls_targets_at_level[idy];
    int64_t base_yid = i % num_classes;
    int64_t pos_idx = idy * num_classes + y;
    p_vec = *(uint4 *)&cls_output[i];

    // Skip ignored matches
    if (y == -2) {
#pragma unroll
      for (int j = 0; j < ILP; j++) {
        *((scalar_t *)(&grad_vec) + j) = 0;
      }
      *(uint4 *)&partial_grad[i] = grad_vec;
      continue;
    }

#pragma unroll
    for (int j = 0; j < ILP; j++) {
      // Skip the pad classes
      if (base_yid + j >= num_real_classes) {
        *((scalar_t *)(&grad_vec) + j) = 0;
        continue;
      }

      accscalar_t p = static_cast<accscalar_t>(*((scalar_t *)(&p_vec) + j));
      accscalar_t exp_np = ::exp(-p);
      accscalar_t exp_pp = ::exp(p);
      accscalar_t sigma = one / (one + exp_np);
      accscalar_t logee = (p >= 0) ? exp_np : exp_pp;
      accscalar_t addee = (p >= 0) ? 0 : -p;
      accscalar_t off_a = addee + ::log(one + logee);

      // Negative matches
      accscalar_t base = SMOOTHING ? nn_norm * p : p;
      accscalar_t off_b = (SMOOTHING ? np_norm : 0) - sigma;
      accscalar_t coeff_f1 = one - alpha;
      accscalar_t coeff_f2 = sigma;
      accscalar_t coeff_b1 = gamma;
      accscalar_t coeff_b2 = one - sigma;

      // Positive matches
      if (y >= 0 && (i + j == pos_idx)) {
        base = SMOOTHING ? pn_norm * p : 0;
        off_b = (SMOOTHING ? pp_norm : one) - sigma;
        coeff_f1 = alpha;
        coeff_f2 = one - sigma;
        coeff_b1 = -gamma;
        coeff_b2 = sigma;
      }

      accscalar_t coeff_f = coeff_f1 * ::pow(coeff_f2, gamma);
      accscalar_t coeff_b = coeff_b1 * coeff_b2;

      accscalar_t loss_t = coeff_f * (base + off_a);
      accscalar_t grad = coeff_f * (coeff_b * (base + off_a) - off_b);

      // Delay the normalize of partial gradient by num_positives_sum to back
      // propagation because scalar_t reduces precision. Focal loss is very
      // sensitive to the small gradient. No worry on overflow here since
      // gradient has relative smaller range than input.
      loss_acc += loss_t;
      *((scalar_t *)(&grad_vec) + j) = static_cast<scalar_t>(grad);
    }
    // This can't ensure to generate stg.128 and may be two stg.64.
    *(uint4 *)&partial_grad[i] = grad_vec;
  }
  loss_shm[threadIdx.x] = loss_acc;

  // Intra-CTA reduction
  __syncthreads();
  for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) {
      loss_shm[threadIdx.x] += loss_shm[threadIdx.x + s];
    }
    __syncthreads();
  }

  // Inter-CTA reduction
  if (threadIdx.x == 0) {
    loss_acc = loss_shm[0] * normalizer;
    atomicAdd(loss, loss_acc);
  }
}
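Editor's note, not part of the new file: a scalar reference of the per-element math the forward kernel evaluates (no label smoothing, no vectorization), useful for checking the stable log-sigmoid trick used for off_a above. The function name is illustrative; p is the logit and is_positive marks the target class.

#include <cmath>

inline float focal_loss_ref(float p, bool is_positive, float alpha, float gamma) {
  const float sigma = 1.0f / (1.0f + std::exp(-p));
  // Stable -log(sigmoid(p)) = max(-p, 0) + log(1 + exp(-|p|)), as in the kernel's off_a.
  const float off_a = std::fmax(-p, 0.0f) + std::log1p(std::exp(-std::fabs(p)));
  if (is_positive) {
    return alpha * std::pow(1.0f - sigma, gamma) * off_a;          // -log(sigma)
  }
  return (1.0f - alpha) * std::pow(sigma, gamma) * (p + off_a);    // -log(1 - sigma)
}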
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__global__ void focal_loss_backward_cuda_kernel(
    scalar_t *partial_grad, const outscalar_t *__restrict__ grad_output,
    const float *__restrict__ num_positives_sum, const uint64_t numel) {
  int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * ILP;
  accscalar_t normalizer = static_cast<accscalar_t>(grad_output[0]) /
                           static_cast<accscalar_t>(num_positives_sum[0]);

  // The input is enforced to pad to use vector load, thus there's no need to
  // check whether the last element of ILP can out of bound.
  if (idx >= numel)
    return;

  uint4 grad_vec;
  grad_vec = *(uint4 *)&partial_grad[idx];
#pragma unroll(ILP)
  for (int i = 0; i < ILP; i++) {
    auto grad = static_cast<accscalar_t>(*((scalar_t *)(&grad_vec) + i));
    grad *= normalizer;
    *((scalar_t *)(&grad_vec) + i) = static_cast<scalar_t>(grad);
  }
  *(uint4 *)&partial_grad[idx] = grad_vec;
}
std::vector<at::Tensor> focal_loss_forward_cuda(
    const at::Tensor &cls_output, const at::Tensor &cls_targets_at_level,
    const at::Tensor &num_positives_sum, const int64_t num_real_classes,
    const float alpha, const float gamma, const float smoothing_factor) {
  // Checks required for correctness
  TORCH_INTERNAL_ASSERT(cls_output.size(-1) >= num_real_classes,
                        "Incorrect number of real classes.");
  TORCH_INTERNAL_ASSERT(cls_targets_at_level.scalar_type() == at::kLong,
                        "Invalid label type.");
  TORCH_INTERNAL_ASSERT(
      (num_positives_sum.numel() == 1) &&
          (num_positives_sum.scalar_type() == at::kFloat),
      "Expect num_positives_sum to be a float32 tensor with only one element.");
  TORCH_INTERNAL_ASSERT(cls_output.dim() == cls_targets_at_level.dim() + 1,
                        "Mis-matched dimensions between class output and label.");
  for (int64_t i = 0; i < cls_targets_at_level.dim(); i++)
    TORCH_INTERNAL_ASSERT(cls_output.size(i) == cls_targets_at_level.size(i),
                          "Mis-matched shape between class output and label.");

  // Checks required for better performance
  const int ILP = sizeof(uint4) / cls_output.element_size();
  ASSERT_UINT4_ALIGNED(cls_output.data_ptr());
  TORCH_INTERNAL_ASSERT(
      cls_output.size(-1) % ILP == 0,
      "Pad number of classes first to take advantage of 128 bit load.");
  TORCH_INTERNAL_ASSERT(num_real_classes >= ILP, "Too few classes.");

  int64_t num_classes = cls_output.size(-1);
  int64_t num_examples = cls_output.numel() / num_classes;
  at::Tensor loss = at::zeros({}, cls_output.options().dtype(at::kFloat));

  // Compute the incomplete gradient during fprop since most of the heavy
  // functions of bprop are the same as fprop, thus trade memory for compute
  // helps with focal loss.
  at::Tensor partial_grad = at::empty_like(cls_output);

  // The grid contains 2 CTA per SM, each CTA loop on input with stride till the
  // last item.
  cudaDeviceProp props;
  cudaGetDeviceProperties(&props, at::cuda::current_device());
  dim3 block(512);
  dim3 grid(2 * props.multiProcessorCount);

  // Specialize on label smoothing or not to reduce redundant operations
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (smoothing_factor == 0.0f) {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        cls_output.scalar_type(), "focal_loss_fprop", [&] {
          using accscalar_t = at::acc_type<scalar_t, true>;
          using labelscalar_t = int64_t;
          using outscalar_t = float;
          const int ILP = sizeof(uint4) / sizeof(scalar_t);
          focal_loss_forward_cuda_kernel<false, ILP, scalar_t, labelscalar_t,
                                         accscalar_t, outscalar_t>
              <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
                  loss.data_ptr<outscalar_t>(), partial_grad.data_ptr<scalar_t>(),
                  cls_output.data_ptr<scalar_t>(),
                  cls_targets_at_level.data_ptr<labelscalar_t>(),
                  num_positives_sum.data_ptr<float>(), num_examples, num_classes,
                  num_real_classes, alpha, gamma, smoothing_factor);
        });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        cls_output.scalar_type(), "focal_loss_fprop", [&] {
          using accscalar_t = at::acc_type<scalar_t, true>;
          using labelscalar_t = int64_t;
          using outscalar_t = float;
          const int ILP = sizeof(uint4) / sizeof(scalar_t);
          focal_loss_forward_cuda_kernel<true, ILP, scalar_t, labelscalar_t,
                                         accscalar_t, outscalar_t>
              <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
                  loss.data_ptr<outscalar_t>(), partial_grad.data_ptr<scalar_t>(),
                  cls_output.data_ptr<scalar_t>(),
                  cls_targets_at_level.data_ptr<labelscalar_t>(),
                  num_positives_sum.data_ptr<float>(), num_examples, num_classes,
                  num_real_classes, alpha, gamma, smoothing_factor);
        });
  }

  AT_CUDA_CHECK(cudaGetLastError());
  return {loss, partial_grad};
}
at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output,
                                    const at::Tensor &partial_grad,
                                    const at::Tensor &num_positives_sum) {
  // Each thread process ILP elements
  const int ILP = sizeof(uint4) / partial_grad.element_size();
  dim3 block(512);
  dim3 grid((partial_grad.numel() + block.x * ILP - 1) / (block.x * ILP));

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      partial_grad.scalar_type(), "focal_loss_bprop", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;
        using outscalar_t = float;
        const int ILP = sizeof(uint4) / sizeof(scalar_t);
        focal_loss_backward_cuda_kernel<ILP, scalar_t, accscalar_t, outscalar_t>
            <<<grid, block, 0, stream>>>(
                partial_grad.data_ptr<scalar_t>(),
                grad_output.data_ptr<outscalar_t>(),
                num_positives_sum.data_ptr<float>(), partial_grad.numel());
      });
  AT_CUDA_CHECK(cudaGetLastError());

  return partial_grad;
}
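Editor's note, not part of the new file: the vector-access pattern both kernels rely on, isolated as a small device helper. A uint4 is 16 bytes, so one vector load moves ILP = 16 / sizeof(scalar_t) elements (8 halves or 4 floats), which is why the class dimension must be padded to a multiple of ILP and the tensor pointer must be 16-byte aligned. The helper name is illustrative.

template <typename scalar_t>
__device__ void scale_ilp_chunk(scalar_t *dst, const scalar_t *src, float scale) {
  constexpr int ILP = sizeof(uint4) / sizeof(scalar_t);
  uint4 v = *reinterpret_cast<const uint4 *>(src);      // one 128-bit load
  scalar_t *elems = reinterpret_cast<scalar_t *>(&v);
#pragma unroll
  for (int i = 0; i < ILP; ++i) {
    elems[i] = static_cast<scalar_t>(static_cast<float>(elems[i]) * scale);
  }
  *reinterpret_cast<uint4 *>(dst) = v;                  // one 128-bit store
}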