Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
f63b27e8
Unverified
Commit
f63b27e8
authored
Jan 27, 2023
by
Przemyslaw Tredak
Committed by
GitHub
Jan 27, 2023
Browse files
Fix the integer overflow in fused softmax (#60)
Signed-off-by:
Przemek Tredak
<
ptredak@nvidia.com
>
parent
b67fe451
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
17 deletions
+29
-17
transformer_engine/common/fused_softmax/scaled_masked_softmax.cu
...rmer_engine/common/fused_softmax/scaled_masked_softmax.cu
+17
-12
transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu
...ommon/fused_softmax/scaled_upper_triang_masked_softmax.cu
+8
-5
transformer_engine/pytorch/csrc/extensions.cu
transformer_engine/pytorch/csrc/extensions.cu
+4
-0
No files found.
transformer_engine/common/fused_softmax/scaled_masked_softmax.cu
View file @
f63b27e8
...
@@ -121,7 +121,8 @@ __global__ void scaled_softmax_warp_forward(
...
@@ -121,7 +121,8 @@ __global__ void scaled_softmax_warp_forward(
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int
first_batch
=
(
blockDim
.
y
*
(
blockIdx
.
x
+
gridDim
.
x
*
(
blockIdx
.
y
+
gridDim
.
y
*
blockIdx
.
z
))
size_t
first_batch
=
(
blockDim
.
y
*
(
blockIdx
.
x
+
gridDim
.
x
*
(
blockIdx
.
y
+
gridDim
.
y
*
blockIdx
.
z
))
+
threadIdx
.
y
)
*
WARP_BATCH
;
+
threadIdx
.
y
)
*
WARP_BATCH
;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
...
@@ -133,8 +134,9 @@ __global__ void scaled_softmax_warp_forward(
...
@@ -133,8 +134,9 @@ __global__ void scaled_softmax_warp_forward(
// there might be multiple batches per warp. compute the index within the batch
// there might be multiple batches per warp. compute the index within the batch
int
local_idx
=
threadIdx
.
x
;
int
local_idx
=
threadIdx
.
x
;
src
+=
first_batch
*
element_count
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
size_t
thread_offset
=
first_batch
*
element_count
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
dst
+=
first_batch
*
element_count
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
src
+=
thread_offset
;
dst
+=
thread_offset
;
// load data from global memory
// load data from global memory
acc_t
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
acc_t
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
...
@@ -236,9 +238,10 @@ __global__ void scaled_masked_softmax_warp_forward(
...
@@ -236,9 +238,10 @@ __global__ void scaled_masked_softmax_warp_forward(
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int
first_batch
=
(
blockDim
.
y
*
(
blockIdx
.
x
+
gridDim
.
x
*
(
blockIdx
.
y
+
gridDim
.
y
*
blockIdx
.
z
))
size_t
first_batch
=
(
blockDim
.
y
*
(
blockIdx
.
x
+
gridDim
.
x
*
(
blockIdx
.
y
+
gridDim
.
y
*
blockIdx
.
z
))
+
threadIdx
.
y
)
*
WARP_BATCH
;
+
threadIdx
.
y
)
*
WARP_BATCH
;
in
t
pad_first_batch
=
0
;
size_
t
pad_first_batch
=
0
;
if
(
pad_batches
!=
1
)
{
// bert style
if
(
pad_batches
!=
1
)
{
// bert style
pad_first_batch
=
(
blockDim
.
y
*
(
blockIdx
.
x
+
gridDim
.
x
*
blockIdx
.
z
)
+
threadIdx
.
y
)
pad_first_batch
=
(
blockDim
.
y
*
(
blockIdx
.
x
+
gridDim
.
x
*
blockIdx
.
z
)
+
threadIdx
.
y
)
*
WARP_BATCH
;
*
WARP_BATCH
;
...
@@ -255,9 +258,11 @@ __global__ void scaled_masked_softmax_warp_forward(
...
@@ -255,9 +258,11 @@ __global__ void scaled_masked_softmax_warp_forward(
// there might be multiple batches per warp. compute the index within the batch
// there might be multiple batches per warp. compute the index within the batch
int
local_idx
=
threadIdx
.
x
;
int
local_idx
=
threadIdx
.
x
;
src
+=
first_batch
*
element_count
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
size_t
thread_offset_src_dst
=
first_batch
*
element_count
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
dst
+=
first_batch
*
element_count
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
size_t
thread_offset_mask
=
pad_first_batch
*
element_count
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
mask
+=
pad_first_batch
*
element_count
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
src
+=
thread_offset_src_dst
;
dst
+=
thread_offset_src_dst
;
mask
+=
thread_offset_mask
;
// load data from global memory
// load data from global memory
acc_t
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
acc_t
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
...
@@ -365,7 +370,7 @@ __global__ void scaled_masked_softmax_warp_backward(
...
@@ -365,7 +370,7 @@ __global__ void scaled_masked_softmax_warp_backward(
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
// gridDim/blockIdx = (seq_len, attn_heads, batches)
in
t
first_batch
=
(
blockDim
.
y
*
blockIdx
.
x
+
threadIdx
.
y
)
*
WARP_BATCH
;
size_
t
first_batch
=
(
blockDim
.
y
*
blockIdx
.
x
+
threadIdx
.
y
)
*
WARP_BATCH
;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
// many batches have to computed within this WARP.
...
@@ -377,7 +382,7 @@ __global__ void scaled_masked_softmax_warp_backward(
...
@@ -377,7 +382,7 @@ __global__ void scaled_masked_softmax_warp_backward(
int
local_idx
=
threadIdx
.
x
;
int
local_idx
=
threadIdx
.
x
;
// the first element to process by the current thread
// the first element to process by the current thread
in
t
thread_offset
=
first_batch
*
element_count
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
size_
t
thread_offset
=
first_batch
*
element_count
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
grad
+=
thread_offset
;
grad
+=
thread_offset
;
output
+=
thread_offset
;
output
+=
thread_offset
;
gradInput
+=
thread_offset
;
gradInput
+=
thread_offset
;
...
...
transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu
View file @
f63b27e8
...
@@ -139,7 +139,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
...
@@ -139,7 +139,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
constexpr
int
ELEMENTS_PER_LDG_STG
=
(
WARP_ITERATIONS
<
4
)
?
1
:
4
;
constexpr
int
ELEMENTS_PER_LDG_STG
=
(
WARP_ITERATIONS
<
4
)
?
1
:
4
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
)
*
gridDim
.
x
*
WARP_BATCH
+
blockIdx
.
x
;
size_t
first_batch
=
(
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
)
*
gridDim
.
x
*
WARP_BATCH
+
blockIdx
.
x
;
int
local_seq
=
blockIdx
.
x
+
1
;
int
local_seq
=
blockIdx
.
x
+
1
;
int
warp_iteration_limit
=
(
local_seq
+
ELEMENTS_PER_LDG_STG
*
WARP_SIZE
-
1
)
/
WARP_SIZE
;
int
warp_iteration_limit
=
(
local_seq
+
ELEMENTS_PER_LDG_STG
*
WARP_SIZE
-
1
)
/
WARP_SIZE
;
...
@@ -152,8 +153,9 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
...
@@ -152,8 +153,9 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
// there might be multiple batches per warp. compute the index within the batch
// there might be multiple batches per warp. compute the index within the batch
int
local_idx
=
threadIdx
.
x
;
int
local_idx
=
threadIdx
.
x
;
src
+=
first_batch
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
size_t
thread_offset
=
first_batch
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
dst
+=
first_batch
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
src
+=
thread_offset
;
dst
+=
thread_offset
;
// load data from global memory
// load data from global memory
acc_t
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
acc_t
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
...
@@ -263,7 +265,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
...
@@ -263,7 +265,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
constexpr
int
ELEMENTS_PER_LDG_STG
=
(
WARP_ITERATIONS
<
4
)
?
1
:
4
;
constexpr
int
ELEMENTS_PER_LDG_STG
=
(
WARP_ITERATIONS
<
4
)
?
1
:
4
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
)
*
gridDim
.
x
*
WARP_BATCH
+
blockIdx
.
x
;
size_t
first_batch
=
(
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
)
*
gridDim
.
x
*
WARP_BATCH
+
blockIdx
.
x
;
int
local_seq
=
blockIdx
.
x
+
1
;
int
local_seq
=
blockIdx
.
x
+
1
;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
...
@@ -276,7 +279,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
...
@@ -276,7 +279,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
int
local_idx
=
threadIdx
.
x
;
int
local_idx
=
threadIdx
.
x
;
// the first element to process by the current thread
// the first element to process by the current thread
in
t
thread_offset
=
first_batch
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
size_
t
thread_offset
=
first_batch
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
grad
+=
thread_offset
;
grad
+=
thread_offset
;
output
+=
thread_offset
;
output
+=
thread_offset
;
gradInput
+=
thread_offset
;
gradInput
+=
thread_offset
;
...
...
transformer_engine/pytorch/csrc/extensions.cu
View file @
f63b27e8
...
@@ -687,6 +687,10 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input,
...
@@ -687,6 +687,10 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input,
(
input
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
),
(
input
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
),
"Only fp16 and bf16 are supported"
);
"Only fp16 and bf16 are supported"
);
AT_ASSERTM
(
mask
.
dim
()
==
4
,
"expected 4D tensor"
);
AT_ASSERTM
(
mask
.
dim
()
==
4
,
"expected 4D tensor"
);
if
(
!
input
.
is_contiguous
())
input
=
input
.
contiguous
();
if
(
!
mask
.
is_contiguous
())
mask
=
mask
.
contiguous
();
const
int
batches
=
input
.
size
(
0
);
const
int
batches
=
input
.
size
(
0
);
const
int
pad_batches
=
mask
.
size
(
0
);
const
int
pad_batches
=
mask
.
size
(
0
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment