OpenDAS / apex · Commits

Commit 058addbe (unverified), authored Jun 10, 2020 by Cliff Woolley, committed by GitHub on Jun 10, 2020.
Parents: 3e474e85, 1574c03d

Merge pull request #880 from seryilmaz/seryilmaz/stream

add streaming support for softmax kernels
Showing 1 changed file with 206 additions and 1 deletion.

apex/contrib/csrc/multihead_attn/softmax.h (+206 −1) @ 058addbe
@@ -471,6 +471,38 @@ bool dispatch_additive_masked_softmax(output_t *dst, const input_t *src, const i
    return false;
}

template <typename input_t, typename output_t, typename acc_t>
bool dispatch_additive_masked_softmax_stream(output_t *dst, const input_t *src, const input_t *pad_mask,
                                             int softmax_elements, int softmax_elements_stride,
                                             int batch_count, int pad_batch_stride, cudaStream_t streamid)
{
    if (softmax_elements == 0) {
        return true;
    } else if (softmax_elements <= 1024) {
        // compute function index. there's a function for each power of two size up to 1024.
        int log2_elements = 0;
        while ((1 << log2_elements) < softmax_elements) ++log2_elements;

        additive_masked_softmax_forward_func<input_t, output_t> kernel;
        int warp_size, batches_per_warp;
        if (!warp_additive_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
            return false;
        }

        // use 128 threads per block to maximimize gpu utilization
        constexpr int threads_per_block = 128;
        // compute warps per block.
        int warps_per_block = (threads_per_block / warp_size);
        // compute launch size
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);
        // launch
        kernel<<<blocks, threads, 0, streamid>>>(dst, src, pad_mask, batch_count,
                                                 softmax_elements_stride, softmax_elements, pad_batch_stride);
        return true;
    }
    return false;
}

// WARP_BATCH number of batches.
// WARP_ITERATOINS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
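For context, here is a minimal host-side sketch of how the new stream-aware forward dispatcher might be invoked. This usage is not part of the commit: the wrapper function name, the tensor layout (half-precision scores of shape [attn_batches * seq_len, seq_len] with a per-batch additive pad mask), and the use of at::cuda::getCurrentCUDAStream() are illustrative assumptions; only the dispatcher signature comes from the diff above.

// Hypothetical usage sketch (not part of this commit); assumes softmax.h is included.
#include <cuda_fp16.h>
#include <ATen/cuda/CUDAContext.h>   // at::cuda::getCurrentCUDAStream()

void softmax_forward_on_stream(__half* softmax_out, const __half* scores, const __half* pad_mask,
                               int seq_len, int attn_batches, int pad_batch_stride) {
    // Launch on whatever stream PyTorch is currently using instead of the default stream.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    bool ok = dispatch_additive_masked_softmax_stream<__half, __half, float>(
        softmax_out, scores, pad_mask,
        /*softmax_elements=*/seq_len,
        /*softmax_elements_stride=*/seq_len,
        /*batch_count=*/attn_batches * seq_len,
        pad_batch_stride, stream);
    (void)ok;  // false when seq_len > 1024; a caller would need a fallback path in that case.
}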
@@ -1110,7 +1142,80 @@ void dispatch_masked_scale_softmax_backward_masked_out(output_t *grad_input, con
        }
    }
}

template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_masked_out_stream(output_t *grad_input, const input_t *grad,
                                                              const input_t *output, const uint8_t *mask,
                                                              const uint8_t *pad_mask, acc_t scale,
                                                              int softmax_elements, int softmax_elements_stride,
                                                              int batch_count, int heads, cudaStream_t streamid)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil_native(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximimize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);

        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 0, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
                break;
            case 1: // 2
                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 1, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
                break;
            case 2: // 4
                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 2, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
                break;
            case 3: // 8
                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 3, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
                break;
            case 4: // 16
                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 4, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
                break;
            case 5: // 32
                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 5, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
                break;
            case 6: // 64
                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 6, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
                break;
            case 7: // 128
                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 7, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
                break;
            case 8: // 256
                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 8, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
                break;
            case 9: // 512
                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 9, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
                break;
            case 10: // 1024
                masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 10, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
                break;
            default:
                break;
        }
    }
}

template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output,
                                                   const uint8_t *mask, acc_t scale, int batch_size, int stride,
                                                   int element_count)
{
    ...
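The launch geometry above follows the same power-of-two scheme as the existing non-stream dispatchers. The short standalone sketch below only traces that arithmetic for one hypothetical problem size; the chosen numbers (seq_len 384, 16 heads, batch 16) and the warp width of 32 are illustrative assumptions, not values taken from the commit.

// Standalone trace of the launch-geometry math above (illustrative values only).
#include <cstdio>

int main() {
    const int warp_width = 32;               // assumed C10_WARP_SIZE on current NVIDIA GPUs
    int softmax_elements = 384;              // hypothetical sequence length
    int batch_count      = 16 * 16 * 384;    // hypothetical heads * batches * queries

    int log2_elements = 0;                   // mirrors log2_ceil_native(softmax_elements)
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;
    const int next_power_of_two = 1 << log2_elements;                      // 512

    int warp_size        = (next_power_of_two < warp_width) ? next_power_of_two : warp_width;  // 32
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;             // 1
    const int threads_per_block = 128;
    int warps_per_block   = threads_per_block / warp_size;                 // 4
    int batches_per_block = warps_per_block * batches_per_warp;            // 4
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block; // 24576

    printf("grid=%d, block=(%d,%d,1)\n", blocks, warp_size, warps_per_block);
    return 0;
}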
@@ -1266,6 +1371,77 @@ void dispatch_masked_scale_softmax_backward(output_t *grad_input, const input_t
    }
}

template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_stream(output_t *grad_input, const input_t *grad,
                                                   const input_t *output, const uint8_t *mask, acc_t scale,
                                                   int softmax_elements, int softmax_elements_stride,
                                                   int batch_count, cudaStream_t streamid)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil_native(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximimize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);

        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 0, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 1: // 2
                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 1, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 2: // 4
                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 2, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 3: // 8
                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 3, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 4: // 16
                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 4, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 5: // 32
                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 5, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 6: // 64
                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 6, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 7: // 128
                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 7, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 8: // 256
                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 8, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 9: // 512
                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 9, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 10: // 1024
                masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 10, is_log_softmax>
                    <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            default:
                break;
        }
    }
}

// elementwise multiplication called in at::softmax_backward_data is fused inside softmax dgrad kernel
// as a result of fusion, intermediate multiplication result is stored in fp32 in registers, instead of fp16
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>

...
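A hedged sketch of how this backward dispatcher might be called from a host-side wrapper. It is not part of the commit: the wrapper name, the byte-mask dropout layout, the 1/(1-p) rescaling folded into scale, and the stream source are assumptions for illustration; only the dispatcher's signature and template parameters come from the diff above.

// Hypothetical usage sketch (not part of this commit); assumes softmax.h is included.
#include <cuda_fp16.h>
#include <ATen/cuda/CUDAContext.h>

void softmax_dropout_backward_on_stream(__half* grad_scores, const __half* grad_out, const __half* softmax_out,
                                        const uint8_t* dropout_mask, float dropout_prob,
                                        int seq_len, int attn_batches) {
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    // Fold the dropout rescaling 1/(1-p) into the fused dgrad kernel via `scale`.
    float scale = 1.0f / (1.0f - dropout_prob);

    dispatch_masked_scale_softmax_backward_stream<__half, __half, float, /*is_log_softmax=*/false>(
        grad_scores, grad_out, softmax_out, dropout_mask, scale,
        /*softmax_elements=*/seq_len,
        /*softmax_elements_stride=*/seq_len,
        /*batch_count=*/attn_batches * seq_len,
        stream);
}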
@@ -1608,6 +1784,35 @@ bool dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const
    return false;
}

template <typename input_t, typename output_t, typename acc_t>
bool dispatch_softmax_backward_stream(output_t *grad_input, const input_t *grad, const input_t *output,
                                      int softmax_elements, int softmax_elements_stride, int batch_count,
                                      cudaStream_t streamid)
{
    if (softmax_elements == 0) {
        return true;
    } else if (softmax_elements <= 1024) {
        // compute function index. there's a function for each power of two size up to 1024.
        int log2_elements = 0;
        while ((1 << log2_elements) < softmax_elements) ++log2_elements;

        softmax_backward_func<input_t, output_t> kernel;
        int warp_size, batches_per_warp;
        if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
            return false;
        }

        // use 128 threads per block to maximimize gpu utilization
        constexpr int threads_per_block = 128;
        // compute warps per block.
        int warps_per_block = (threads_per_block / warp_size);
        // compute launch size
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);
        // launch
        kernel<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, batch_count,
                                                 softmax_elements_stride, softmax_elements);
        return true;
    }
    return false;
}

template <typename input_t, typename output_t, typename acc_t,
          int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void masked_softmax_warp_backward(__half *gradInput, const __half *grad, const __half *output,
                                             const uint8_t *pad_mask, int batch_size, int stride,
                                             int element_count, int pad_batch_stride)
{
    ...
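Finally, a hedged sketch of calling the plain (unmasked) backward dispatcher on an explicit stream, which is the point of the *_stream variants added by this commit: independent layers can launch on different streams instead of serializing on the default stream. The wrapper name, shapes, and the use of at::cuda::getCurrentCUDAStream() are illustrative assumptions.

// Hypothetical usage sketch (not part of this commit); assumes softmax.h is included.
#include <cuda_fp16.h>
#include <ATen/cuda/CUDAContext.h>

void softmax_backward_on_stream(__half* grad_scores, const __half* grad_out, const __half* softmax_out,
                                int seq_len, int attn_batches) {
    // Launch on PyTorch's current stream rather than the default stream, so this
    // dgrad can overlap with work queued on other streams.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    bool ok = dispatch_softmax_backward_stream<__half, __half, float>(
        grad_scores, grad_out, softmax_out,
        /*softmax_elements=*/seq_len,
        /*softmax_elements_stride=*/seq_len,
        /*batch_count=*/attn_batches * seq_len,
        stream);
    (void)ok;  // false means seq_len exceeded the 1024-element fast path
}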