OpenDAS / Megatron-LM · Commits

Commit c6b77639: "chagne PR by reviews"
Authored Aug 18, 2021 by hyunwoongko; committed by mshoeybi, Aug 22, 2021.
Parent: 0d350c8d
Showing 3 changed files with 11 additions and 2 deletions (+11, −2):
megatron/fused_kernels/scaled_masked_softmax.h   +3 −0
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h   +6 −0
megatron/model/fused_softmax.py   +2 −2
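In summary: the two CUDA headers gain TORCH_INTERNAL_ASSERT guards that bound the softmax size to 2048 elements and verify that the work divides evenly across the launch grid, and fused_softmax.py corrects its kernel-availability check to bound the key sequence length (sk) and test the query sequence length (sq) for divisibility by 4.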
megatron/fused_kernels/scaled_masked_softmax.h

@@ -339,6 +339,7 @@ void dispatch_scaled_masked_softmax_forward(
     int attn_heads,
     int pad_batches)
 {
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
     if (key_seq_len == 0) {
         return;
     } else {
@@ -357,6 +358,7 @@ void dispatch_scaled_masked_softmax_forward(
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
         dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
         dim3 threads(warp_size, warps_per_block, 1);
         // Launch code would be more elegant if C++ supported FOR CONSTEXPR
@@ -426,6 +428,7 @@ void dispatch_scaled_masked_softmax_backward(
     int batches,
     int attn_heads)
 {
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
     if (key_seq_len == 0) {
         return;
     } else {
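For reference, the divisibility that the new assert in this file enforces falls out of the launch geometry visible in the context lines above. A minimal Python sketch of that arithmetic; warp_size, threads_per_block, and batches_per_warp are illustrative assumptions here (the kernels pick them per element count), not values from this commit:

    # Sketch of the launch geometry guarded by the new
    # TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0).
    # The three constants below are illustrative assumptions.
    warp_size = 32
    threads_per_block = 128
    batches_per_warp = 2

    warps_per_block = threads_per_block // warp_size        # 4
    batches_per_block = warps_per_block * batches_per_warp  # 8

    query_seq_len, attn_heads, batches = 1024, 16, 8

    # The kernel launches a (query_seq_len / batches_per_block,
    # attn_heads, batches) grid, so the division must be exact;
    # otherwise some softmax rows would never be covered by a block.
    assert query_seq_len % batches_per_block == 0
    blocks = (query_seq_len // batches_per_block, attn_heads, batches)
    threads = (warp_size, warps_per_block, 1)
    print(blocks, threads)  # (128, 16, 8) (32, 4, 1)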
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h

@@ -340,6 +340,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
     int softmax_elements_stride,
     int attn_batches)
 {
+    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
     if (softmax_elements == 0) {
         return;
     } else {
@@ -359,6 +360,8 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+
         int blocks_per_seq = attn_batches / batches_per_block;
         dim3 blocks(seq_len, blocks_per_seq, 1);
         dim3 threads(warp_size, warps_per_block, 1);
@@ -428,6 +431,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
     int softmax_elements_stride,
     int attn_batches)
 {
+    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
     if (softmax_elements == 0) {
         return;
     } else {
@@ -447,6 +451,8 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+
         int blocks_per_seq = attn_batches / batches_per_block;
         dim3 blocks(seq_len, blocks_per_seq, 1);
         dim3 threads(warp_size, warps_per_block, 1);
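The upper-triangular variant partitions work differently: the grid covers (seq_len, blocks_per_seq, 1) rather than splitting the query length across blocks, so here it is attn_batches that must divide evenly. A sketch under the same illustrative constants as above:

    # Sketch of the grid guarded by the new
    # TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0).
    # Constants are illustrative assumptions, as before.
    warp_size, threads_per_block, batches_per_warp = 32, 128, 2
    warps_per_block = threads_per_block // warp_size        # 4
    batches_per_block = warps_per_block * batches_per_warp  # 8

    # attn_batches is b * np (batch size times attention heads),
    # per the comment in fused_softmax.py below.
    seq_len, attn_batches = 2048, 96

    assert attn_batches % batches_per_block == 0
    blocks_per_seq = attn_batches // batches_per_block      # 12
    blocks = (seq_len, blocks_per_seq, 1)
    threads = (warp_size, warps_per_block, 1)
    print(blocks, threads)  # (2048, 12, 1) (32, 4, 1)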
megatron/model/fused_softmax.py

@@ -138,8 +138,8 @@ class FusedScaleMaskSoftmax(nn.Module):
             self.scaled_masked_softmax_fusion  # user want to fuse
             and self.input_in_float16  # input must be fp16
             and mask is not None  # mask tensor must not be None
-            and 16 < sq <= 2048  # sq must be 16 ~ 2048
-            and sk % 4 == 0  # sk must be divisor of 4
+            and 16 < sk <= 2048  # sq must be 16 ~ 2048
+            and sq % 4 == 0  # sk must be divisor of 4
             and attn_batches % 4 == 0  # np * b must be divisor of 4
         ):
             if 0 <= sk <= 2048:
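Taken together, the Python-side change gates the fused path on the key sequence length (sk) for the 2048-element limit and on the query sequence length (sq) for the divisor-of-4 requirement, matching what the CUDA asserts above enforce. A standalone sketch of the predicate; the function name and signature are hypothetical (the real check lives inside FusedScaleMaskSoftmax), with variable names mirroring the diff:

    def fused_kernel_applicable(fusion_enabled, input_in_float16, mask,
                                sq, sk, b, np):
        """Sketch of the availability check after this commit; mirrors
        the condition in FusedScaleMaskSoftmax, not the actual method."""
        attn_batches = b * np
        return (
            fusion_enabled              # user wants to fuse
            and input_in_float16        # input must be fp16
            and mask is not None        # mask tensor must not be None
            and 16 < sk <= 2048         # sk must be 16 ~ 2048
            and sq % 4 == 0             # sq must be divisible by 4
            and attn_batches % 4 == 0   # b * np must be divisible by 4
        )

    # Example: 1024-token attention with 16 heads and batch 8 qualifies;
    # a key length of 3000 exceeds the limit and falls back to the
    # unfused path.
    print(fused_kernel_applicable(True, True, object(), 1024, 1024, 8, 16))  # True
    print(fused_kernel_applicable(True, True, object(), 1024, 3000, 8, 16))  # False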