Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Uni-Core
Commits
6c18b2ed
Unverified
Commit
6c18b2ed
authored
Aug 04, 2022
by
Guolin Ke
Committed by
GitHub
Aug 04, 2022
Browse files
support softmax for large columns (#6)
* support softmax for large columns * more tests
parent
31fe887e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
257 additions
and
81 deletions
+257
-81
csrc/softmax_dropout/softmax_fast.h
csrc/softmax_dropout/softmax_fast.h
+245
-74
tests/test_softmax.py
tests/test_softmax.py
+3
-3
unicore/modules/softmax_dropout.py
unicore/modules/softmax_dropout.py
+9
-4
No files found.
csrc/softmax_dropout/softmax_fast.h
View file @
6c18b2ed
...
...
@@ -6,6 +6,7 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <curand_kernel.h>
#include <cub/cub.cuh>
#include "util.h"
template
<
int
N
>
...
...
@@ -68,6 +69,141 @@ inline int softmax_rng_delta_offset(int elements)
return
warp_iterations
*
warp_batch
;
}
inline
cudaError_t
GetNumBlocks
(
int64_t
block_size
,
int64_t
max_blocks
,
int64_t
waves
,
int
*
num_blocks
)
{
int
dev
;
{
cudaError_t
err
=
cudaGetDevice
(
&
dev
);
if
(
err
!=
cudaSuccess
)
{
return
err
;
}
}
int
sm_count
;
{
cudaError_t
err
=
cudaDeviceGetAttribute
(
&
sm_count
,
cudaDevAttrMultiProcessorCount
,
dev
);
if
(
err
!=
cudaSuccess
)
{
return
err
;
}
}
int
tpm
;
{
cudaError_t
err
=
cudaDeviceGetAttribute
(
&
tpm
,
cudaDevAttrMaxThreadsPerMultiProcessor
,
dev
);
if
(
err
!=
cudaSuccess
)
{
return
err
;
}
}
*
num_blocks
=
std
::
max
<
int
>
(
1
,
std
::
min
<
int64_t
>
(
max_blocks
,
sm_count
*
tpm
/
block_size
*
waves
));
return
cudaSuccess
;
}
template
<
typename
T
>
struct
SumOp
{
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
a
+
b
;
}
};
template
<
typename
T
>
struct
MaxOp
{
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
max
(
a
,
b
);
}
};
template
<
template
<
typename
>
class
ReductionOp
,
typename
T
,
int
block_size
>
__inline__
__device__
T
BlockAllReduce
(
T
val
)
{
typedef
cub
::
BlockReduce
<
T
,
block_size
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
__shared__
T
result_broadcast
;
T
result
=
BlockReduce
(
temp_storage
).
Reduce
(
val
,
ReductionOp
<
T
>
());
if
(
threadIdx
.
x
==
0
)
{
result_broadcast
=
result
;
}
__syncthreads
();
return
result_broadcast
;
}
// modified from https://github.com/Oneflow-Inc/oneflow/blob/5d74efa4d07adfd0acbc8e0074778687f1006b86/oneflow/core/cuda/softmax.cuh#L480-L529
// Copyright 2020 The OneFlow Authors. All rights reserved.
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
,
int
block_size
,
bool
NeedBias
,
bool
NeedAttnMask
>
__global__
void
softmax_block_forward
(
const
input_t
*
input
,
output_t
*
output
,
const
input_t
*
attn_mask
,
const
input_t
*
bias
,
int64_t
rows
,
int
cols
,
int64_t
attn_inner_skip_batch
,
int64_t
bias_batch_size
)
{
extern
__shared__
__align__
(
sizeof
(
double
))
unsigned
char
shared_buf
[];
auto
*
buf
=
reinterpret_cast
<
acc_t
*>
(
shared_buf
);
const
int
tid
=
threadIdx
.
x
;
auto
element_count
=
cols
;
int64_t
bias_mod_size
=
bias_batch_size
*
cols
;
int64_t
attn_mask_div_size
=
element_count
;
if
IF_CONSTEXPR
(
NeedAttnMask
)
{
attn_mask_div_size
=
attn_inner_skip_batch
*
element_count
;
}
for
(
int64_t
row
=
blockIdx
.
x
;
row
<
rows
;
row
+=
gridDim
.
x
)
{
acc_t
thread_max
=
-
std
::
numeric_limits
<
acc_t
>::
infinity
();
int64_t
idx_offset
=
row
*
cols
;
const
input_t
*
input_ptr
=
input
+
idx_offset
;
output_t
*
output_ptr
=
output
+
idx_offset
;
const
input_t
*
attn_mask_ptr
=
nullptr
;
if
IF_CONSTEXPR
(
NeedAttnMask
){
attn_mask_ptr
=
attn_mask
+
static_cast
<
int64_t
>
(
idx_offset
/
attn_mask_div_size
)
*
element_count
;
}
const
input_t
*
bias_ptr
=
nullptr
;
if
IF_CONSTEXPR
(
NeedBias
)
{
bias_ptr
=
bias
+
idx_offset
%
bias_mod_size
;
}
// TODO: enable pack as oneflow
for
(
int
col
=
tid
;
col
<
cols
;
col
+=
block_size
)
{
buf
[
col
]
=
static_cast
<
acc_t
>
(
input_ptr
[
col
]);
if
IF_CONSTEXPR
(
NeedAttnMask
)
{
buf
[
col
]
+=
attn_mask_ptr
[
col
];
}
if
IF_CONSTEXPR
(
NeedBias
)
{
buf
[
col
]
+=
bias_ptr
[
col
];
}
thread_max
=
max
(
thread_max
,
buf
[
col
]);
}
const
acc_t
row_max
=
BlockAllReduce
<
MaxOp
,
acc_t
,
block_size
>
(
thread_max
);
acc_t
thread_sum
=
0
;
for
(
int
col
=
tid
;
col
<
cols
;
col
+=
block_size
)
{
buf
[
col
]
=
std
::
exp
(
buf
[
col
]
-
row_max
);
thread_sum
+=
buf
[
col
];
}
const
acc_t
row_sum
=
BlockAllReduce
<
SumOp
,
acc_t
,
block_size
>
(
thread_sum
);
for
(
int
col
=
tid
;
col
<
cols
;
col
+=
block_size
)
{
output_ptr
[
col
]
=
static_cast
<
output_t
>
(
buf
[
col
]
/
row_sum
);
}
}
}
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
,
int
block_size
>
__global__
void
softmax_block_backward
(
output_t
*
store
,
const
input_t
*
dy
,
const
input_t
*
y
,
const
int64_t
rows
,
const
int64_t
cols
)
{
extern
__shared__
__align__
(
sizeof
(
double
))
unsigned
char
grad_shared_buf
[];
auto
*
dy_buf
=
reinterpret_cast
<
acc_t
*>
(
grad_shared_buf
);
auto
*
y_buf
=
reinterpret_cast
<
input_t
*>
(
dy_buf
+
cols
);
const
int
tid
=
threadIdx
.
x
;
for
(
int64_t
row
=
blockIdx
.
x
;
row
<
rows
;
row
+=
gridDim
.
x
)
{
acc_t
thread_sum
=
0
;
auto
dy_ptr
=
dy
+
row
*
cols
;
auto
y_ptr
=
y
+
row
*
cols
;
auto
store_ptr
=
store
+
row
*
cols
;
for
(
int
col
=
tid
;
col
<
cols
;
col
+=
block_size
)
{
y_buf
[
col
]
=
y_ptr
[
col
];
dy_buf
[
col
]
=
dy_ptr
[
col
]
*
(
acc_t
)
y_ptr
[
col
];
}
for
(
int
col
=
tid
;
col
<
cols
;
col
+=
block_size
)
{
thread_sum
+=
dy_buf
[
col
];
}
const
acc_t
row_sum
=
BlockAllReduce
<
SumOp
,
acc_t
,
block_size
>
(
thread_sum
);
for
(
int
col
=
tid
;
col
<
cols
;
col
+=
block_size
)
{
store_ptr
[
col
]
=
static_cast
<
output_t
>
(
dy_buf
[
col
]
-
y_buf
[
col
]
*
row_sum
);
}
}
}
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
,
typename
Parameters
,
bool
NeedMask
,
bool
NeedBias
,
bool
NeedAttnMask
>
...
...
@@ -113,6 +249,7 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
for
(
int
i
=
0
;
i
<
Parameters
::
WarpBatch
;
++
i
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
element_count
;
auto
src_ptr
=
src
+
i
*
element_count
;
#pragma unroll
for
(
int
it
=
0
;
it
<
Parameters
::
WarpIterations
;
++
it
)
{
...
...
@@ -121,7 +258,7 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
if
(
element_index
<
batch_element_count
)
{
elements_input
[
i
][
it
]
=
src
[
i
*
element_count
+
it
*
Parameters
::
WarpSize
];
elements_input
[
i
][
it
]
=
src
_ptr
[
it
*
Parameters
::
WarpSize
];
}
}
}
...
...
@@ -132,6 +269,15 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
for
(
int
i
=
0
;
i
<
Parameters
::
WarpBatch
;
++
i
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
element_count
;
int64_t
idx_offset
=
(
first_batch
+
i
)
*
element_count
;
const
input_t
*
attn_mask_ptr
=
nullptr
;
if
IF_CONSTEXPR
(
NeedAttnMask
){
attn_mask_ptr
=
attn_mask
+
static_cast
<
int64_t
>
(
idx_offset
/
attn_mask_div_size
)
*
element_count
+
local_idx
;
}
const
input_t
*
bias_ptr
=
nullptr
;
if
IF_CONSTEXPR
(
NeedBias
){
bias_ptr
=
bias
+
idx_offset
%
bias_mod_size
+
local_idx
;
}
#pragma unroll
for
(
int
it
=
0
;
it
<
Parameters
::
WarpIterations
;
++
it
)
{
...
...
@@ -139,15 +285,13 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
int
element_index
=
local_idx
+
it
*
Parameters
::
WarpSize
;
if
(
element_index
<
batch_element_count
)
{
int64_t
global_idx
=
thread_offset
+
i
*
element_count
+
it
*
Parameters
::
WarpSize
;
if
IF_CONSTEXPR
(
NeedAttnMask
)
{
auto
attn_mask_idx
=
static_cast
<
int64_t
>
(
global_idx
/
attn_mask_div_size
)
*
element_count
+
(
global_idx
%
element_count
);
elements
[
i
][
it
]
+=
attn_mask
[
attn_mask_idx
];
elements
[
i
][
it
]
+=
attn_mask_ptr
[
it
*
Parameters
::
WarpSize
];
}
if
IF_CONSTEXPR
(
NeedBias
)
{
elements
[
i
][
it
]
+=
bias
[
global_idx
%
bias_mod_s
ize
];
elements
[
i
][
it
]
+=
bias
_ptr
[
it
*
Parameters
::
WarpS
ize
];
}
}
}
...
...
@@ -245,6 +389,8 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
}
}
mask
[
i
*
Parameters
::
MaskStride
+
local_idx
]
=
m
;
auto
dst_ptr
=
dst
+
i
*
element_count
;
auto
dst_orig_ptr
=
dst_orig
+
i
*
element_count
;
#pragma unroll
for
(
int
it
=
0
;
it
<
Parameters
::
WarpIterations
;
++
it
)
{
...
...
@@ -252,8 +398,8 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
if
(
element_index
<
element_count
)
{
const
output_t
d
=
elements
[
i
][
it
]
/
sum
[
i
];
dst
[
i
*
element_count
+
it
*
Parameters
::
WarpSize
]
=
(
acc_t
)
d
*
((
acc_t
)((
m
>>
it
)
&
1
)
*
pinv
);
dst_orig
[
i
*
element_count
+
it
*
Parameters
::
WarpSize
]
=
d
;
dst
_ptr
[
it
*
Parameters
::
WarpSize
]
=
(
acc_t
)
d
*
((
acc_t
)((
m
>>
it
)
&
1
)
*
pinv
);
dst_orig
_ptr
[
it
*
Parameters
::
WarpSize
]
=
d
;
}
else
{
...
...
@@ -267,6 +413,7 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
#pragma unroll
for
(
int
i
=
0
;
i
<
Parameters
::
WarpBatch
;
++
i
)
{
auto
dst_ptr
=
dst
+
i
*
element_count
;
if
(
i
>=
local_batches
)
break
;
#pragma unroll
...
...
@@ -275,7 +422,7 @@ __global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const outp
int
element_index
=
local_idx
+
it
*
Parameters
::
WarpSize
;
if
(
element_index
<
element_count
)
{
dst
[
i
*
element_count
+
it
*
Parameters
::
WarpSize
]
=
elements
[
i
][
it
]
/
sum
[
i
];
dst
_ptr
[
it
*
Parameters
::
WarpSize
]
=
elements
[
i
][
it
]
/
sum
[
i
];
}
else
{
...
...
@@ -323,32 +470,42 @@ bool dispatch_softmax_forward(output_t *dst, output_t *dst_orig, const input_t *
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch
(
log2_elements
)
{
case
0
:
LAUNCH_FORWARD_KERNEL
(
0
)
case
1
:
LAUNCH_FORWARD_KERNEL
(
1
)
case
2
:
LAUNCH_FORWARD_KERNEL
(
2
)
case
3
:
LAUNCH_FORWARD_KERNEL
(
3
)
case
4
:
LAUNCH_FORWARD_KERNEL
(
4
)
case
5
:
LAUNCH_FORWARD_KERNEL
(
5
)
case
6
:
LAUNCH_FORWARD_KERNEL
(
6
)
case
7
:
LAUNCH_FORWARD_KERNEL
(
7
)
case
8
:
LAUNCH_FORWARD_KERNEL
(
8
)
case
9
:
LAUNCH_FORWARD_KERNEL
(
9
)
case
10
:
LAUNCH_FORWARD_KERNEL
(
10
)
case
11
:
LAUNCH_FORWARD_KERNEL
(
11
)
default:
return
false
;
case
0
:
LAUNCH_FORWARD_KERNEL
(
0
)
case
1
:
LAUNCH_FORWARD_KERNEL
(
1
)
case
2
:
LAUNCH_FORWARD_KERNEL
(
2
)
case
3
:
LAUNCH_FORWARD_KERNEL
(
3
)
case
4
:
LAUNCH_FORWARD_KERNEL
(
4
)
case
5
:
LAUNCH_FORWARD_KERNEL
(
5
)
case
6
:
LAUNCH_FORWARD_KERNEL
(
6
)
case
7
:
LAUNCH_FORWARD_KERNEL
(
7
)
case
8
:
LAUNCH_FORWARD_KERNEL
(
8
)
case
9
:
LAUNCH_FORWARD_KERNEL
(
9
)
case
10
:
LAUNCH_FORWARD_KERNEL
(
10
)
default:
{
int
grid_dim
;
constexpr
int
block_size
=
128
;
constexpr
int
waves
=
32
;
auto
cols
=
softmax_elements
;
auto
rows
=
batch_count
;
GetNumBlocks
(
block_size
,
rows
,
waves
,
&
grid_dim
);
dim3
block
(
block_size
);
const
size_t
smem
=
cols
*
sizeof
(
acc_t
);
softmax_block_forward
<
input_t
,
output_t
,
acc_t
,
block_size
,
NeedAttnMask
,
NeedBias
><<<
grid_dim
,
block
,
smem
>>>
(
src
,
dst
,
attn_mask
,
bias
,
rows
,
cols
,
attn_inner_skip_batch
,
bias_batch_count
);
return
true
;
}
}
}
return
false
;
...
...
@@ -389,7 +546,7 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
// load data from global memory
acc_t
grad_reg
[
Parameters
::
WarpBatch
][
Parameters
::
WarpIterations
];
acc
_t
output_reg
[
Parameters
::
WarpBatch
][
Parameters
::
WarpIterations
];
input
_t
output_reg
[
Parameters
::
WarpBatch
][
Parameters
::
WarpIterations
];
if
IF_CONSTEXPR
(
NeedMask
)
{
MaskType
mask_reg
[
Parameters
::
WarpBatch
];
...
...
@@ -408,6 +565,8 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
element_count
;
MaskType
m
=
mask_reg
[
i
];
auto
output_ptr
=
output
+
i
*
element_count
;
auto
grad_ptr
=
grad
+
i
*
element_count
;
#pragma unroll
for
(
int
it
=
0
;
it
<
Parameters
::
WarpIterations
;
++
it
)
{
...
...
@@ -415,16 +574,16 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
if
(
element_index
<
batch_element_count
)
{
grad_reg
[
i
][
it
]
=
(
input_t
)(
(
acc_t
)((
m
>>
it
)
&
1
)
*
(
acc_t
)
grad
[
i
*
element_count
+
it
*
Parameters
::
WarpSize
]
*
pinv
)
*
output
[
i
*
element_count
+
it
*
Parameters
::
WarpSize
];
output_reg
[
i
][
it
]
=
output
[
i
*
element_count
+
it
*
Parameters
::
WarpSize
];
(
acc_t
)((
m
>>
it
)
&
1
)
*
(
acc_t
)
grad
_ptr
[
it
*
Parameters
::
WarpSize
]
*
pinv
*
output
_ptr
[
it
*
Parameters
::
WarpSize
];
output_reg
[
i
][
it
]
=
output
_ptr
[
it
*
Parameters
::
WarpSize
];
}
else
{
grad_reg
[
i
][
it
]
=
acc_t
(
0
);
output_reg
[
i
][
it
]
=
acc
_t
(
0
);
output_reg
[
i
][
it
]
=
input
_t
(
0
);
}
}
}
...
...
@@ -435,20 +594,22 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
for
(
int
i
=
0
;
i
<
Parameters
::
WarpBatch
;
++
i
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
element_count
;
auto
output_ptr
=
output
+
i
*
element_count
;
auto
grad_ptr
=
grad
+
i
*
element_count
;
#pragma unroll
for
(
int
it
=
0
;
it
<
Parameters
::
WarpIterations
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
Parameters
::
WarpSize
;
if
(
element_index
<
batch_element_count
)
{
grad
_reg
[
i
][
it
]
=
grad
[
i
*
element_count
+
it
*
Parameters
::
WarpSize
]
*
output
[
i
*
element_count
+
it
*
Parameters
::
WarpSize
]
;
output_reg
[
i
][
it
]
=
output
[
i
*
element_count
+
it
*
Parameters
::
WarpSize
];
output
_reg
[
i
][
it
]
=
output_ptr
[
it
*
Parameters
::
WarpSize
]
;
grad_reg
[
i
][
it
]
=
grad_ptr
[
it
*
Parameters
::
WarpSize
]
*
(
acc_t
)
output_ptr
[
it
*
Parameters
::
WarpSize
];
}
else
{
grad_reg
[
i
][
it
]
=
acc_t
(
0
);
output_reg
[
i
][
it
]
=
acc
_t
(
0
);
output_reg
[
i
][
it
]
=
output
_t
(
0
);
}
}
}
...
...
@@ -482,6 +643,7 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
{
if
(
i
>=
local_batches
)
break
;
auto
gradInput_ptr
=
gradInput
+
i
*
element_count
;
#pragma unroll
for
(
int
it
=
0
;
it
<
Parameters
::
WarpIterations
;
++
it
)
{
...
...
@@ -491,12 +653,12 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
// compute gradients
if
IF_CONSTEXPR
(
IsLogSoftmax
)
{
gradInput
[
i
*
element_count
+
it
*
Parameters
::
WarpSize
]
=
(
grad_reg
[
i
][
it
]
-
std
::
exp
(
output_reg
[
i
][
it
])
*
sum
[
i
]);
gradInput
_ptr
[
it
*
Parameters
::
WarpSize
]
=
(
grad_reg
[
i
][
it
]
-
std
::
exp
(
(
acc_t
)
output_reg
[
i
][
it
])
*
sum
[
i
]);
}
else
{
gradInput
[
i
*
element_count
+
it
*
Parameters
::
WarpSize
]
=
gradInput
_ptr
[
it
*
Parameters
::
WarpSize
]
=
(
grad_reg
[
i
][
it
]
-
output_reg
[
i
][
it
]
*
sum
[
i
]);
}
}
...
...
@@ -541,32 +703,41 @@ void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch
(
log2_elements
)
{
case
0
:
LAUNCH_BACKWARD_KERNEL
(
0
)
case
1
:
LAUNCH_BACKWARD_KERNEL
(
1
)
case
2
:
LAUNCH_BACKWARD_KERNEL
(
2
)
case
3
:
LAUNCH_BACKWARD_KERNEL
(
3
)
case
4
:
LAUNCH_BACKWARD_KERNEL
(
4
)
case
5
:
LAUNCH_BACKWARD_KERNEL
(
5
)
case
6
:
LAUNCH_BACKWARD_KERNEL
(
6
)
case
7
:
LAUNCH_BACKWARD_KERNEL
(
7
)
case
8
:
LAUNCH_BACKWARD_KERNEL
(
8
)
case
9
:
LAUNCH_BACKWARD_KERNEL
(
9
)
case
10
:
LAUNCH_BACKWARD_KERNEL
(
10
)
case
11
:
LAUNCH_BACKWARD_KERNEL
(
11
)
default:
break
;
case
0
:
LAUNCH_BACKWARD_KERNEL
(
0
)
case
1
:
LAUNCH_BACKWARD_KERNEL
(
1
)
case
2
:
LAUNCH_BACKWARD_KERNEL
(
2
)
case
3
:
LAUNCH_BACKWARD_KERNEL
(
3
)
case
4
:
LAUNCH_BACKWARD_KERNEL
(
4
)
case
5
:
LAUNCH_BACKWARD_KERNEL
(
5
)
case
6
:
LAUNCH_BACKWARD_KERNEL
(
6
)
case
7
:
LAUNCH_BACKWARD_KERNEL
(
7
)
case
8
:
LAUNCH_BACKWARD_KERNEL
(
8
)
case
9
:
LAUNCH_BACKWARD_KERNEL
(
9
)
case
10
:
LAUNCH_BACKWARD_KERNEL
(
10
)
default:
{
int
grid_dim
;
constexpr
int
block_size
=
128
;
constexpr
int
waves
=
32
;
auto
cols
=
softmax_elements
;
auto
rows
=
batch_count
;
GetNumBlocks
(
block_size
,
rows
,
waves
,
&
grid_dim
);
dim3
block
(
block_size
);
const
size_t
smem
=
cols
*
sizeof
(
acc_t
)
+
cols
*
sizeof
(
input_t
)
;
softmax_block_backward
<
input_t
,
output_t
,
acc_t
,
block_size
><<<
grid_dim
,
block
,
smem
>>>
(
grad_input
,
grad
,
output
,
rows
,
cols
);
}
}
}
}
tests/test_softmax.py
View file @
6c18b2ed
...
...
@@ -39,7 +39,7 @@ def test_softmax():
n_batch
=
4
n_heads
=
8
n_query
=
128
test_dims
=
[
64
,
128
,
256
,
512
,
1024
]
test_dims
=
[
64
,
128
,
256
,
512
,
1024
,
1536
,
2048
]
test_dtype
=
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
test_device
=
torch
.
device
(
"cuda"
)
for
last_dim
in
test_dims
:
...
...
@@ -83,7 +83,7 @@ def test_tri_softmax1():
n_groups
=
32
n_heads
=
8
n_query
=
128
test_dims
=
[
64
,
128
,
256
,
512
,
1024
]
test_dims
=
[
64
,
128
,
256
,
512
,
1024
,
1536
,
2048
]
test_dtype
=
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
test_device
=
torch
.
device
(
"cuda"
)
for
last_dim
in
test_dims
:
...
...
@@ -129,7 +129,7 @@ def test_tri_softmax2():
n_groups
=
32
n_heads
=
8
n_query
=
128
test_dims
=
[
64
,
128
,
256
,
512
,
1024
]
test_dims
=
[
64
,
128
,
256
,
512
,
1024
,
1536
,
2048
]
test_dtype
=
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
test_device
=
torch
.
device
(
"cuda"
)
for
last_dim
in
test_dims
:
...
...
unicore/modules/softmax_dropout.py
View file @
6c18b2ed
...
...
@@ -94,7 +94,7 @@ def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None)
torch.Tensor: the result after softmax
"""
input
=
input
.
contiguous
()
if
input
.
is_cuda
and
input
.
shape
[
-
1
]
<=
2048
:
if
input
.
is_cuda
:
input_size
=
input
.
size
()
if
mask
is
not
None
:
_check_mask
(
mask
,
input
)
...
...
@@ -103,9 +103,14 @@ def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None)
_check_bias
(
bias
,
input
)
bias
=
bias
.
contiguous
().
view
(
-
1
,
input_size
[
-
2
],
input_size
[
-
1
])
input
=
input
.
view
(
-
1
,
input_size
[
-
2
],
input_size
[
-
1
])
return
SoftmaxDropoutFast
.
apply
(
is_training
,
input
,
mask
,
bias
,
dropout_prob
).
view
(
*
input_size
)
if
dropout_prob
<=
0.0
or
input_size
[
-
1
]
<=
1024
:
return
SoftmaxDropoutFast
.
apply
(
is_training
,
input
,
mask
,
bias
,
dropout_prob
).
view
(
*
input_size
)
else
:
return
F
.
dropout
(
SoftmaxDropoutFast
.
apply
(
is_training
,
input
,
mask
,
bias
,
0.0
).
view
(
*
input_size
),
p
=
dropout_prob
,
training
=
is_training
)
else
:
if
mask
is
not
None
:
input
+=
mask
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment