flash-attention · Commits · ed553e92

Commit ed553e92, authored Oct 23, 2022 by Tri Dao

    Add Megatron attention implementation for benchmarking

parent 50ca2348
Showing 9 changed files with 1730 additions and 3 deletions (+1730 −3)
    benchmarks/benchmark_causal.py                                  +32   −3
    csrc/fused_softmax/fused_softmax.cpp                           +148   −0
    csrc/fused_softmax/scaled_masked_softmax.h                     +528   −0
    csrc/fused_softmax/scaled_masked_softmax_cuda.cu               +121   −0
    csrc/fused_softmax/scaled_upper_triang_masked_softmax.h        +529   −0
    csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu   +98   −0
    csrc/fused_softmax/setup.py                                     +49   −0
    csrc/fused_softmax/type_shim.h                                  +20   −0
    flash_attn/fused_softmax.py                                    +205   −0
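Note: the fused softmax kernels added under csrc/fused_softmax are copied from NVIDIA's Megatron/apex code so the causal benchmark can compare against Megatron's attention path. As a rough, non-authoritative sketch of how the pieces fit together (assuming the CUDA extension is built via the setup.py added below, and that flash_attn.fused_softmax exposes the wrapper the benchmark imports):

    # Sketch only: build the extension added in this commit, then call the fused
    # causal softmax the same way benchmark_causal.py does further down.
    #   cd csrc/fused_softmax && pip install .
    import torch
    from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax

    batch_size, nheads, seqlen = 2, 12, 2048
    scores = torch.randn(batch_size, nheads, seqlen, seqlen, device='cuda', dtype=torch.float16)
    probs = scaled_upper_triang_masked_softmax(scores, None, scale=1.0)  # causal (upper-triangular) masking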
benchmarks/benchmark_causal.py

...
@@ -6,12 +6,17 @@ import torch.nn.functional as F
 from einops import rearrange, repeat

-from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
+from flash_attn.utils.benchmark import benchmark_all, pytorch_profiler
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
 from flash_attn.triton.fused_attention import attention as attention

+try:
+    from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
+except ImportError:
+    scaled_upper_triang_masked_softmax = None
+

-def attention_pytorch(qkv, dropout_p=0.0, causal=False):
+def attention_pytorch(qkv, dropout_p=0.0, causal=True):
     """
     Arguments:
         qkv: (batch_size, seqlen, 3, nheads, head_dim)
...
@@ -53,10 +58,31 @@ def attention_triton(q, k, v):
     return attention(q, k, v, softmax_scale)


+def attention_megatron(qkv):
+    """
+    Arguments:
+        qkv: (batch_size, seqlen, 3, nheads, head_dim)
+    Output:
+        output: (batch_size, seqlen, nheads, head_dim)
+    """
+    batch_size, seqlen, _, nheads, d = qkv.shape
+    q, k, v = qkv.unbind(dim=2)
+    q = rearrange(q, 'b t h d -> (b h) t d')
+    k = rearrange(k, 'b s h d -> (b h) d s')
+    softmax_scale = 1.0 / math.sqrt(d)
+    # Preallocate attn_weights for `baddbmm`
+    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
+    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
+                       '(b h) t s -> b h t s', h=nheads)
+    attention = scaled_upper_triang_masked_softmax(scores, None, scale=1.0)
+    output = torch.einsum('bhts,bshd->bthd', attention, v)
+    return output.to(dtype=qkv.dtype)
+
+
 torch.manual_seed(0)
 repeats = 30
 batch_size = 2
-seqlen = 2048
+seqlen = 4096
 nheads = 12
 headdim = 128
 dropout_p = 0.0
...
@@ -77,3 +103,6 @@ benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
 q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
                        requires_grad=True) for _ in range(3)]
 benchmark_all(attention_triton, q, k, v, repeats=repeats, desc='FlashAttention Triton')
+
+if scaled_upper_triang_masked_softmax is not None:
+    benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention')
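For reference, the fused kernel that attention_megatron relies on should agree with a plain PyTorch causal softmax; a minimal sanity-check sketch (the fp32 upcast and tolerances are assumptions, the fused call is exactly the one used above):

    import torch

    def causal_softmax_reference(scores):
        # scores: (batch, nheads, seqlen, seqlen); mask out keys j > query i before the softmax
        seqlen = scores.shape[-1]
        causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=scores.device), diagonal=1)
        return torch.softmax(scores.float().masked_fill(causal_mask, float('-inf')), dim=-1).to(scores.dtype)

    # fused = scaled_upper_triang_masked_softmax(scores, None, scale=1.0)
    # torch.testing.assert_close(fused, causal_softmax_reference(scores), atol=1e-3, rtol=1e-3)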
csrc/fused_softmax/fused_softmax.cpp (new file, mode 100644)

/* coding=utf-8
 * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

int get_batch_per_block_cuda(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads);

torch::Tensor fwd(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor) {
    AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
    AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                   (input.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");
    AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

    return fwd_cuda(input, mask, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {
    AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
    AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");
    AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                   (output_grads.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");
    AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                   (softmax_results.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");

    return bwd_cuda(output_grads, softmax_results, scale_factor);
}

int get_batch_per_block(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads) {
    return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}

}  // end namespace scaled_masked_softmax
}  // end namespace fused_softmax
}  // end namespace multihead_attn

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
    AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
    AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                   (input.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");

    return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {
    AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
    AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
    AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                   (output_grads.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");
    AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                   (softmax_results.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");

    return bwd_cuda(output_grads, softmax_results, scale_factor);
}

}  // end namespace scaled_upper_triang_masked_softmax
}  // end namespace fused_softmax
}  // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("scaled_masked_softmax_forward",
          &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
          "Self Multihead Attention scaled, time masked softmax -- Forward.");
    m.def("scaled_masked_softmax_backward",
          &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
          "Self Multihead Attention scaled, time masked softmax -- Backward.");
    m.def("scaled_masked_softmax_get_batch_per_block",
          &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
          "Return Batch per block size.");
    m.def("scaled_upper_triang_masked_softmax_forward",
          &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
          "Self Multihead Attention scaled, time masked softmax -- Forward.");
    m.def("scaled_upper_triang_masked_softmax_backward",
          &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
          "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
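Once compiled, the bindings registered above are plain functions on the extension module; the autograd wrapper in flash_attn/fused_softmax.py (diff not shown in this capture) sits on top of them. A hedged sketch of calling them directly — fused_softmax_lib is a placeholder module name, the real one is set in csrc/fused_softmax/setup.py:

    import torch
    import fused_softmax_lib  # placeholder module name (assumption)

    x = torch.randn(24, 1024, 1024, device='cuda', dtype=torch.float16)  # 3D: (attn_batches, seq_len, seq_len)
    y = fused_softmax_lib.scaled_upper_triang_masked_softmax_forward(x, 1.0)
    dy = torch.randn_like(y)
    dx = fused_softmax_lib.scaled_upper_triang_masked_softmax_backward(dy, y, 1.0)  # needs the saved softmax output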
csrc/fused_softmax/scaled_masked_softmax.h (new file, mode 100644)

/* coding=utf-8
 * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>

namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}

template <typename T>
struct Add {
    __device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
};

template <typename T>
struct Max {
    __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}

template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}

/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Explicit masking
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const acc_t scale,
    int micro_batch_size,
    int element_count,
    int pad_batches)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_forward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;
    int pad_first_batch = 0;
    if (pad_batches != 1) {
        // bert style
        pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
    } else {
        // gpt2 style
        pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
    }

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                int itr_idx = i * element_count + it * WARP_SIZE;
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (temp_mask[element] != 1) {
                        elements[i][it + element] = (acc_t)temp_data[element] * scale;
                    } else {
                        elements[i][it + element] = -10000.0;
                    }
                }
            } else {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    // compute scale value to account for full mask
    acc_t scale_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0;
    }

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
            sum[i] += elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = elements[i][it + element] * scale_value[i] / sum[i];
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
            } else {
                break;
            }
        }
    }
}

template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_backward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    output_reg[i][it + element] = (acc_t)temp_output[element];
                }
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
                }
            }
        }
    }

    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
            }
        }
    }
}
} // end of anonymous namespace

int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads) {
    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    constexpr int threads_per_block = 128;
    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;

    return batches_per_block;
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const input_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads,
    int pad_batches)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192);
    if (key_seq_len == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximimize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
        dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0:  scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;  // 1
            case 1:  scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;  // 2
            case 2:  scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;  // 4
            case 3:  scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;  // 8
            case 4:  scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;  // 16
            case 5:  scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;  // 32
            case 6:  scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;  // 64
            case 7:  scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;  // 128
            case 8:  scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;  // 256
            case 9:  scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;  // 512
            case 10: scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;  // 1024
            case 11: scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;  // 2048
            case 12: scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;  // 4096
            case 13: scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 13><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;  // 8192
            default:
                break;
        }
    }
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192);
    if (key_seq_len == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximimize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = batch_count / batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0:  scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;  // 1
            case 1:  scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;  // 2
            case 2:  scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;  // 4
            case 3:  scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;  // 8
            case 4:  scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;  // 16
            case 5:  scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;  // 32
            case 6:  scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;  // 64
            case 7:  scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;  // 128
            case 8:  scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;  // 256
            case 9:  scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;  // 512
            case 10: scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;  // 1024
            case 11: scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;  // 2048
            case 12: scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;  // 4096
            case 13: scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 13><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;  // 8192
            default:
                break;
        }
    }
}
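In PyTorch terms, the forward kernel above scales the input, writes -10000 wherever the uint8 mask equals 1, takes a numerically stable softmax over the last dimension, and zeroes any row that was entirely masked (the scale_value trick). A rough CPU reference of that contract, for illustration only (shapes follow the checks in the next file):

    import torch

    def scaled_masked_softmax_reference(x, mask, scale):
        # x: (batches, attn_heads, query_seq_len, key_seq_len); mask: (batches or 1, 1, q, k), 1 = masked out
        scores = x.float() * scale
        scores = scores.masked_fill(mask.bool(), -10000.0)
        probs = torch.softmax(scores, dim=-1)
        fully_masked = mask.bool().all(dim=-1, keepdim=True)  # fully-masked rows come out as zeros, not uniform
        return probs.masked_fill(fully_masked, 0.0).to(x.dtype)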
csrc/fused_softmax/scaled_masked_softmax_cuda.cu (new file, mode 100644)

/* coding=utf-8
 * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads) {
    return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor)
{
    // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
    const int batches = input.size(0);
    const int pad_batches = mask.size(0);
    const int attn_heads = input.size(1);
    const int query_seq_len = input.size(2);
    const int key_seq_len = input.size(3);
    TORCH_INTERNAL_ASSERT(key_seq_len <= 8192);
    TORCH_INTERNAL_ASSERT(query_seq_len > 1);
    TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
    TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
    TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
    TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

    // Output
    auto act_options = input.options().requires_grad(false);
    torch::Tensor softmax_results =
        torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

    // Softmax Intermediate Result Ptr
    void* input_ptr = static_cast<void*>(input.data_ptr());
    void* mask_ptr = static_cast<void*>(mask.data_ptr());
    void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

    DISPATCH_HALF_AND_BFLOAT(
        input.scalar_type(),
        "dispatch_scaled_masked_softmax_forward",
        dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t*>(softmax_results_ptr),
            reinterpret_cast<const scalar_t*>(input_ptr),
            reinterpret_cast<const uint8_t*>(mask_ptr),
            scale_factor,
            query_seq_len,
            key_seq_len,
            batches,
            attn_heads,
            pad_batches);
    );
    return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor)
{
    auto output_grads = output_grads_.contiguous();
    auto softmax_results = softmax_results_.contiguous();

    //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
    const int batches = output_grads.size(0);
    const int attn_heads = output_grads.size(1);
    const int query_seq_len = output_grads.size(2);
    const int key_seq_len = output_grads.size(3);

    auto act_options = output_grads.options().requires_grad(false);
    torch::Tensor input_grads =
        torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

    void* input_grads_ptr = static_cast<void*>(input_grads.data_ptr());
    void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

    //Softmax Grad
    DISPATCH_HALF_AND_BFLOAT(
        output_grads_.scalar_type(),
        "dispatch_scaled_masked_softmax_backward",
        dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t*>(input_grads_ptr),
            reinterpret_cast<scalar_t*>(output_grads_ptr),
            reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
            scale_factor,
            query_seq_len,
            key_seq_len,
            batches,
            attn_heads);
    );
    return input_grads;
}
}
}
}
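get_batch_per_block_cuda simply reports the kernel's row batching: 128 threads per block, a warp spans min(next_pow2(key_seq_len), 32) lanes, and a warp covers two rows when next_pow2(key_seq_len) <= 128, otherwise one. The same arithmetic written out in Python, for illustration:

    def get_batch_per_block(key_seq_len, c10_warp_size=32, threads_per_block=128):
        # mirrors get_batch_per_block() in scaled_masked_softmax.h
        log2_elements = 0
        while (1 << log2_elements) < key_seq_len:
            log2_elements += 1
        next_power_of_two = 1 << log2_elements
        warp_size = min(next_power_of_two, c10_warp_size)
        batches_per_warp = 2 if next_power_of_two <= 128 else 1
        return (threads_per_block // warp_size) * batches_per_warp

    assert get_batch_per_block(2048) == 4   # long rows: 4 warps/block, 1 row per warp
    assert get_batch_per_block(64) == 8     # short rows: 4 warps/block, 2 rows per warp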
csrc/fused_softmax/scaled_upper_triang_masked_softmax.h (new file, mode 100644)

/* coding=utf-8
 * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>

namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);

template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }

template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }

template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }

template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}

template <typename T>
struct Add {
    __device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
};

template <typename T>
struct Max {
    __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}

template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}

/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Implicit time (diagonal masking)
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const acc_t scale,
    int micro_batch_size,
    int stride,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_forward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1;
    int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
    dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i * element_count * stride + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if ((element_index + element) < batch_element_count) {
                        elements[i][it + element] = (acc_t)temp_data[element] * scale;
                    } else {
                        elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                    }
                }
            } else {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            if (it < warp_iteration_limit) {
                elements[i][it] = std::exp((elements[i][it] - max_value[i]));
                sum[i] += elements[i][it];
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < local_seq) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (element_index + element < local_seq) {
                        out[element] = elements[i][it + element] / sum[i];
                    } else {
                        out[element] = 0;
                    }
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
            } else if (element_index < element_count) {
                copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
            } else {
                break;
            }
        }
    }
}

template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
    int stride,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_backward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (element_index + element < batch_element_count) {
                        output_reg[i][it + element] = (acc_t)temp_output[element];
                    }
                }
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (element_index + element < batch_element_count) {
                        grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
                    }
                }
            }
        }
    }

    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
            }
        }
    }
}
} // end of anonymous namespace

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const input_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 8192);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;
        int seq_len = softmax_elements;
        int batch_count = attn_batches * seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximimize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0:  scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 1
            case 1:  scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 2
            case 2:  scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 4
            case 3:  scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 8
            case 4:  scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 16
            case 5:  scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 32
            case 6:  scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 64
            case 7:  scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 128
            case 8:  scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 256
            case 9:  scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 512
            case 10: scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 1024
            case 11: scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 2048
            case 12: scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 12><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 4096
            case 13: scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 13><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 8192
            default:
                break;
        }
    }
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 8192);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;
        int seq_len = softmax_elements;
        int batch_count = attn_batches * seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximimize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0:  scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 1
            case 1:  scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 2
            case 2:  scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 4
            case 3:  scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 8
            case 4:  scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 16
            case 5:  scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 32
            case 6:  scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 64
            case 7:  scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 128
            case 8:  scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 256
            case 9:  scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 512
            case 10: scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 1024
            case 11: scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 2048
            case 12: scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 12><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 4096
            case 13: scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 13><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;  // 8192
            default:
                break;
        }
    }
}
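In the causal kernels above, each blockIdx.x owns one query row, with local_seq = blockIdx.x + 1 valid key positions, and the dispatcher launches a (seq_len, attn_batches / batches_per_block) grid of (warp_size, warps_per_block) blocks. A small illustration of that host-side sizing for the benchmark's shape, assuming the Python wrapper flattens the (batch, heads, t, s) scores to (batch*heads, t, s) before calling the 3-D kernel:

    def upper_triang_launch_config(seq_len, attn_batches, c10_warp_size=32, threads_per_block=128):
        # mirrors dispatch_scaled_upper_triang_masked_softmax_forward above (host-side math only)
        log2_elements = 0
        while (1 << log2_elements) < seq_len:
            log2_elements += 1
        next_power_of_two = 1 << log2_elements
        warp_size = min(next_power_of_two, c10_warp_size)
        batches_per_warp = 2 if next_power_of_two <= 128 else 1
        warps_per_block = threads_per_block // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        assert attn_batches % batches_per_block == 0
        blocks = (seq_len, attn_batches // batches_per_block, 1)   # dim3 blocks(seq_len, blocks_per_seq, 1)
        threads = (warp_size, warps_per_block, 1)                  # dim3 threads(warp_size, warps_per_block, 1)
        return blocks, threads

    # benchmark_causal.py: batch_size=2, nheads=12, seqlen=4096 -> attn_batches = 24
    print(upper_triang_launch_config(4096, 24))  # ((4096, 6, 1), (32, 4, 1))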
csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu
0 → 100644
View file @
ed553e92
/* coding=utf-8
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor)
{
  // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = input.size(0);
  const int seq_len = input.size(1);
  TORCH_INTERNAL_ASSERT(seq_len <= 8192);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, seq_len, seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_forward",
      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          scale_factor,
          seq_len,
          seq_len,
          attn_batches);
  );
  return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor)
{
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  // output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = output_grads.size(0);
  const int seq_len = output_grads.size(1);
  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  // Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_backward",
      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          seq_len,
          seq_len,
          attn_batches);
  );

  // backward pass is completely in-place
  return output_grads;
}
}
}
}
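As a quick orientation, here is a sketch of how these bindings could be exercised directly from Python once the extension is compiled. The module name fused_softmax_lib and the two function names below match the imports used in flash_attn/fused_softmax.py further down; the exact pybind layer lives in fused_softmax.cpp. This is a sketch under those assumptions, not a verified snippet:

import torch
import fused_softmax_lib  # the extension built from csrc/fused_softmax

# fwd_cuda takes a 3D fp16/bf16 CUDA tensor [attn_batches, seq_len, seq_len] with seq_len <= 8192.
attn_batches, seq_len = 24, 2048   # attn_batches = batch_size * num_heads
scores = torch.randn(attn_batches, seq_len, seq_len, device='cuda', dtype=torch.float16)
probs = fused_softmax_lib.scaled_upper_triang_masked_softmax_forward(scores, 1.0)

# bwd_cuda overwrites the gradient buffer it is given (the backward pass is in-place),
# so pass a clone if the upstream gradient needs to be preserved.
grad_scores = fused_softmax_lib.scaled_upper_triang_masked_softmax_backward(probs.clone(), probs, 1.0)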
csrc/fused_softmax/setup.py
0 → 100644
View file @
ed553e92
# Copied from https://github.com/NVIDIA/apex/tree/master/csrc/megatron
# We add the case where seqlen = 4k and seqlen = 8k
import os
import subprocess

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]
    return raw_output, bare_metal_major, bare_metal_minor


def append_nvcc_threads(nvcc_extra_args):
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args


cc_flag = []
cc_flag.append("-gencode")
cc_flag.append("arch=compute_70,code=sm_70")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")

setup(
    name='fused_softmax_lib',
    ext_modules=[
        CUDAExtension(
            name='fused_softmax_lib',
            sources=['fused_softmax.cpp', 'scaled_masked_softmax_cuda.cu',
                     'scaled_upper_triang_masked_softmax_cuda.cu'],
            extra_compile_args={
                'cxx': ['-O3',],
                'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag)
            }
        )
    ],
    cmdclass={'build_ext': BuildExtension})
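Assuming the usual torch extension workflow, the library would be built in place from this directory before the benchmark can import it, for example:

cd csrc/fused_softmax
pip install .

The -gencode flags target sm_70 (V100) and sm_80 (A100), and --threads 4 is only passed to nvcc when the detected CUDA toolkit version is at least 11.2.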
csrc/fused_softmax/type_shim.h
0 → 100644
View file @
ed553e92
#include <ATen/ATen.h>
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...)                           \
    switch(TYPE)                                                            \
    {                                                                       \
        case at::ScalarType::Half:                                          \
        {                                                                   \
            using scalar_t = at::Half;                                      \
            __VA_ARGS__;                                                    \
            break;                                                          \
        }                                                                   \
        case at::ScalarType::BFloat16:                                      \
        {                                                                   \
            using scalar_t = at::BFloat16;                                  \
            __VA_ARGS__;                                                    \
            break;                                                          \
        }                                                                   \
        default:                                                            \
            AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }
flash_attn/fused_softmax.py
0 → 100644
View file @
ed553e92
# [2022-10-23] Copied from https://github.com/NVIDIA/apex/blob/master/apex/transformer/functional/fused_softmax.py
# for benchmarking.
# We added support for seqlen=2k and seqlen=4k
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from apex._autocast_utils import _cast_if_autocast_enabled
from apex.transformer.enums import AttnMaskType

from fused_softmax_lib import scaled_masked_softmax_forward, scaled_masked_softmax_backward
from fused_softmax_lib import scaled_masked_softmax_get_batch_per_block
from fused_softmax_lib import scaled_upper_triang_masked_softmax_forward, scaled_upper_triang_masked_softmax_backward
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply the upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        scale_t = torch.tensor([scale])
        softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_upper_triang_masked_softmax_backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None
def scaled_upper_triang_masked_softmax(inputs, _, scale):
    b, np, sq, sk = inputs.size()
    assert sq == sk, "causal mask is only for self attention"
    # Reshaping input to 3D tensor (attn_batches, sq, sk)
    inputs = inputs.view(-1, sq, sk)
    args = _cast_if_autocast_enabled(inputs, scale)
    with torch.cuda.amp.autocast(enabled=False):
        probs = ScaledUpperTriangMaskedSoftmax.apply(*args)
    return probs.view(b, np, sq, sk)
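A minimal usage sketch (the shapes below are illustrative): the second argument is an unused mask placeholder kept so the signature matches scaled_masked_softmax, and the upper-triangular (causal) masking is done inside the kernel.

import torch

scores = torch.randn(2, 12, 1024, 1024, device='cuda', dtype=torch.float16)  # (b, np, sq, sk)
probs = scaled_upper_triang_masked_softmax(scores, None, scale=1.0)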
# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`.
# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context.
# So I needed to manually write two `torch.autograd.Function` inheritances.
# Fused operation which performs following three operations in sequence
# 1. Scale the tensor.
# 2. Apply the mask.
# 3. Perform softmax.
class ScaledMaskedSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, mask, scale):
        scale_t = torch.tensor([scale])
        softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_masked_softmax_backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None, None
def scaled_masked_softmax(inputs, mask, scale):
    # input is 4D tensor (b, np, sq, sk)
    args = _cast_if_autocast_enabled(inputs, mask, scale)
    with torch.cuda.amp.autocast(enabled=False):
        return ScaledMaskedSoftmax.apply(*args)
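A sketch of the padding-mask path, under the assumption (following the Megatron/apex convention) that the mask is a boolean tensor broadcastable to the scores in which True marks key positions to suppress before the softmax; the exact shapes accepted are enforced in scaled_masked_softmax_cuda.cu, so treat this as illustrative only.

import torch

b, np, sq, sk = 2, 12, 1024, 1024
scores = torch.randn(b, np, sq, sk, device='cuda', dtype=torch.float16)
padding_mask = torch.zeros(b, 1, sq, sk, device='cuda', dtype=torch.bool)
padding_mask[..., -128:] = True   # hypothetical: the last 128 keys of every sequence are padding
probs = scaled_masked_softmax(scores, padding_mask, scale=1.0)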
class FusedScaleMaskSoftmax(torch.nn.Module):
    """
    fused operation: scaling + mask + softmax

    Arguments:
        input_in_fp16: flag to indicate if input is in fp16 data format.
        input_in_bf16: flag to indicate if input is in bf16 data format.
        attn_mask_type: attention mask type (pad or causal)
        scaled_masked_softmax_fusion: flag to indicate the user wants to use softmax fusion
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super().__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        if self.input_in_fp16 and self.input_in_bf16:
            raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.")
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale
        if not (self.scale is None or softmax_in_fp32):
            raise RuntimeError("softmax should be in fp32 when scaled")

        if self.scaled_masked_softmax_fusion:
            if self.attn_mask_type == AttnMaskType.causal:
                self.fused_softmax_func = scaled_upper_triang_masked_softmax
            elif self.attn_mask_type == AttnMaskType.padding:
                self.fused_softmax_func = scaled_masked_softmax
            else:
                raise ValueError("Invalid attn_mask_type.")

    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4
        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np
        if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16 or bf16
            and (
                self.attn_mask_type == AttnMaskType.causal
                or (self.attn_mask_type == AttnMaskType.padding and mask is not None)
            )
            and 16 < sk <= 8192  # sk must be 16 ~ 8192
            and sq % 4 == 0  # sq must be divisible by 4
            and sk % 4 == 0  # sk must be divisible by 4
            and attn_batches % 4 == 0  # np * b must be divisible by 4
        ):
            if 0 <= sk <= 8192:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)
                if self.attn_mask_type == AttnMaskType.causal:
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    if sq % batch_per_block == 0:
                        return True
        return False

    def forward_fused_softmax(self, input, mask):
        # input.shape = [b, np, sq, sk]
        scale = self.scale if self.scale is not None else 1.0
        return self.fused_softmax_func(input, mask, scale)

    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()
        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)
        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()
        return probs

    @staticmethod
    def get_batch_per_block(sq, sk, b, np):
        return scaled_masked_softmax_get_batch_per_block(sq, sk, b, np)
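Finally, a hedged example of wiring the module up for the causal case. attention_mask_func is a hypothetical helper (not part of this commit) of the kind Megatron passes in, and the shapes are illustrative; it is only used on the unfused fallback path, since the causal kernel masks internally.

import torch
from apex.transformer.enums import AttnMaskType

def attention_mask_func(attention_scores, attention_mask):
    # Hypothetical mask function: fill masked positions with a large negative value.
    return attention_scores.masked_fill(attention_mask, -10000.0)

softmax = FusedScaleMaskSoftmax(
    input_in_fp16=True,
    input_in_bf16=False,
    attn_mask_type=AttnMaskType.causal,
    scaled_masked_softmax_fusion=True,
    mask_func=attention_mask_func,
    softmax_in_fp32=True,
    scale=None,
)
scores = torch.randn(2, 12, 1024, 1024, device='cuda', dtype=torch.float16)
probs = softmax(scores, None)  # the mask can be None on the causal path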