OpenDAS / Megatron-LM / Commits

Commit 58edb19a
Merge branch 'main' into vision_transformer
Authored Jan 14, 2021 by Vijay Korthikanti
Parents: adbba962, f34cc86b

Showing 13 changed files with 471 additions and 208 deletions.
Changed files:

  megatron/arguments.py                                  +33   -18
  megatron/fused_kernels/scaled_masked_softmax.h         +43   -47
  megatron/fused_kernels/scaled_masked_softmax_cuda.cu   +13   -11
  megatron/model/bert_model.py                           +2    -0
  megatron/model/classification.py                       +2    -0
  megatron/model/enums.py                                +28   -0
  megatron/model/fused_softmax.py                        +18   -15
  megatron/model/gpt_model.py                            +2    -0
  megatron/model/language_model.py                       +154  -54
  megatron/model/multiple_choice.py                      +2    -0
  megatron/model/realm_model.py                          +2    -0
  megatron/model/transformer.py                          +171  -61
  pretrain_gpt.py                                        +1    -2
megatron/arguments.py  (view @ 58edb19a)

@@ -165,6 +165,20 @@ def parse_args(extra_args_provider=None, defaults={},
         _check_arg_is_not_none(args, req_arg)

     # Checks.
+    if args.ffn_hidden_size is None:
+        args.ffn_hidden_size = 4 * args.hidden_size
+
+    if args.kv_channels is None:
+        assert args.hidden_size % args.num_attention_heads == 0
+        args.kv_channels = args.hidden_size // args.num_attention_heads
+
+    if args.seq_length is not None:
+        assert args.encoder_seq_length is None
+        args.encoder_seq_length = args.seq_length
+    else:
+        assert args.encoder_seq_length is not None
+        args.seq_length = args.encoder_seq_length
+
     assert args.hidden_size % args.num_attention_heads == 0
     if args.seq_length is not None:
         assert args.max_position_embeddings >= args.seq_length

@@ -184,15 +198,10 @@ def parse_args(extra_args_provider=None, defaults={},
             'for distribute-checkpointed-activations to work you '\
             'need to enable checkpoint-activations'

-    if args.scaled_masked_softmax_fusion:
-        if args.scaled_upper_triang_masked_softmax_fusion:
-            fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
-        else:
-            fused_kernels.load_scaled_masked_softmax_fusion_kernel()
-    else:
-        # This argument will eventually go away, for now make sure it is off
-        # if scaled_masked_softmax_fusion is off.
-        args.scaled_upper_triang_masked_softmax_fusion = False
+    # Load scaled_masked_softmax_fusion kernels
+    if args.masked_softmax_fusion:
+        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
+        fused_kernels.load_scaled_masked_softmax_fusion_kernel()

     # Load mixed precision fused layer norm.
     if args.fp32_residual_connection:

@@ -228,8 +237,14 @@ def _add_network_size_args(parser):
                        help='Number of transformer layers.')
     group.add_argument('--hidden-size', type=int, default=None,
                        help='Tansformer hidden size.')
+    group.add_argument('--ffn-hidden-size', type=int, default=None,
+                       help='Transformer Feed-Forward Network hidden size. '
+                       'This is set to 4*hidden-size if not provided')
     group.add_argument('--num-attention-heads', type=int, default=None,
                        help='Number of transformer attention heads.')
+    group.add_argument('--kv-channels', type=int, default=None,
+                       help='Projection weights dimension in multi-head attention. '
+                       'This is set to args.hidden_size // args.num_attention_heads if not provided.')
     group.add_argument('--max-position-embeddings', type=int, default=None,
                        help='Maximum number of position embeddings to use. '
                        'This is the size of position embedding.')

@@ -333,16 +348,11 @@ def _add_training_args(parser):
                        help='Exit the program after this many minutes.')
     group.add_argument('--tensorboard-dir', type=str, default=None,
                        help='Write TensorBoard logs to this directory.')
-    group.add_argument('--no-scaled-masked-softmax-fusion',
+    group.add_argument('--no-masked-softmax-fusion',
                        action='store_false',
                        help='Disable fusion of query_key_value scaling, '
                        'masking, and softmax.',
-                       dest='scaled_masked_softmax_fusion')
-    group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
-                       type=bool,
-                       help='Use upper triangular version of fused '
-                       'scale, mask, softmax fusion kernel (default for GPT). '
-                       '- DEPRECATED')
+                       dest='masked_softmax_fusion')
     group.add_argument('--no-bias-gelu-fusion', action='store_false',
                        help='Disable bias and gelu fusion.',
                        dest='bias_gelu_fusion')

@@ -537,7 +547,12 @@ def _add_data_args(parser):
     group.add_argument('--merge-file', type=str, default=None,
                        help='Path to the BPE merge file.')
     group.add_argument('--seq-length', type=int, default=None,
-                       help="Maximum sequence length to process.")
+                       help='Maximum sequence length to process.')
+    group.add_argument('--encoder-seq-length', type=int, default=None,
+                       help='Maximum encoder sequence length to process.'
+                       'This should be exclusive of --seq-length')
+    group.add_argument('--decoder-seq-length', type=int, default=None,
+                       help="Maximum decoder sequence length to process.")
     group.add_argument('--mask-prob', type=float, default=0.15,
                        help='Probability of replacing a token with mask.')
     group.add_argument('--short-seq-prob', type=float, default=0.1,
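The training-args hunk replaces the old pair of fusion flags with a single --no-masked-softmax-fusion switch whose dest is masked_softmax_fusion. A minimal standalone sketch (plain argparse, not the Megatron parser) of how store_false plus an explicit dest makes fusion default to on and lets the flag turn it off:

    # Minimal sketch: behaviour of a --no-* flag with action='store_false' and dest.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--no-masked-softmax-fusion',
                        action='store_false',
                        help='Disable fusion of query_key_value scaling, '
                             'masking, and softmax.',
                        dest='masked_softmax_fusion')

    print(parser.parse_args([]).masked_softmax_fusion)                             # True: fusion on by default
    print(parser.parse_args(['--no-masked-softmax-fusion']).masked_softmax_fusion)  # False: fusion disabled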
megatron/fused_kernels/scaled_masked_softmax.h  (view @ 58edb19a)

@@ -81,7 +81,6 @@ __global__ void scaled_masked_softmax_warp_forward(
     const uint8_t *mask,
     const acc_t scale,
     int micro_batch_size,
-    int stride,
     int element_count,
     int pad_batches)
 {

@@ -111,9 +110,9 @@ __global__ void scaled_masked_softmax_warp_forward(
     // there might be multiple batches per warp. compute the index within the batch
     int local_idx = threadIdx.x;

-    src += first_batch * stride + local_idx;
-    dst += first_batch * stride + local_idx;
-    mask += pad_first_batch * stride + local_idx;
+    src += first_batch * element_count + local_idx;
+    dst += first_batch * element_count + local_idx;
+    mask += pad_first_batch * element_count + local_idx;

     // load data from global memory
     acc_t elements[WARP_BATCH][WARP_ITERATIONS];

@@ -185,7 +184,6 @@ __global__ void scaled_masked_softmax_warp_backward(
     const input_t *output,
     acc_t scale,
     int micro_batch_size,
-    int stride,
     int element_count)
 {
     // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and

@@ -209,7 +207,7 @@ __global__ void scaled_masked_softmax_warp_backward(
     int local_idx = threadIdx.x;

     // the first element to process by the current thread
-    int thread_offset = first_batch * stride + local_idx;
+    int thread_offset = first_batch * element_count + local_idx;
     grad += thread_offset;
     output += thread_offset;
     gradInput += thread_offset;

@@ -277,20 +275,19 @@ void dispatch_scaled_masked_softmax_forward(
     const input_t *src,
     const uint8_t *mask,
     const input_t scale,
-    int softmax_elements,
-    int softmax_elements_stride,
+    int query_seq_len,
+    int key_seq_len,
     int batches,
     int attn_heads,
     int pad_batches)
 {
-    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
-    if (softmax_elements == 0) {
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
+    if (key_seq_len == 0) {
         return;
     } else {
-        int log2_elements = log2_ceil(softmax_elements);
+        int log2_elements = log2_ceil(key_seq_len);
         const int next_power_of_two = 1 << log2_elements;
-        int seq_len = softmax_elements;
-        int batch_count = batches * attn_heads * seq_len;
+        int batch_count = batches * attn_heads * query_seq_len;

         // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
         int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

@@ -303,58 +300,58 @@ void dispatch_scaled_masked_softmax_forward(
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
-        TORCH_INTERNAL_ASSERT(seq_len % batches_per_block == 0);
-        dim3 blocks(seq_len/batches_per_block, attn_heads, batches);
+        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
+        dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
         dim3 threads(warp_size, warps_per_block, 1);
         // Launch code would be more elegant if C++ supported FOR CONSTEXPR
         switch (log2_elements) {
             case 0: // 1
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             ... cases 1 (// 2) through 11 (// 2048) apply the same change to each
                 kernel launch: the trailing arguments `softmax_elements_stride,
                 softmax_elements` become `key_seq_len` ...
             default:
                 break;

@@ -368,19 +365,18 @@ void dispatch_scaled_masked_softmax_backward(
     input_t *grad,
     const input_t *output,
     const acc_t scale,
-    int softmax_elements,
-    int softmax_elements_stride,
+    int query_seq_len,
+    int key_seq_len,
     int batches,
     int attn_heads)
 {
-    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
-    if (softmax_elements == 0) {
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
+    if (key_seq_len == 0) {
        return;
     } else {
-        int log2_elements = log2_ceil(softmax_elements);
+        int log2_elements = log2_ceil(key_seq_len);
         const int next_power_of_two = 1 << log2_elements;
-        int seq_len = softmax_elements;
-        int batch_count = batches * attn_heads * seq_len;
+        int batch_count = batches * attn_heads * query_seq_len;

         // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
         int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

@@ -399,51 +395,51 @@ void dispatch_scaled_masked_softmax_backward(
         switch (log2_elements) {
             case 0: // 1
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             ... cases 1 (// 2) through 11 (// 2048) apply the same substitution
                 to each kernel launch ...
             default:
                 break;
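The dispatch functions now take separate query and key sequence lengths: the softmax width (and hence log2_elements and the template case) comes from key_seq_len, while the grid is sized by query_seq_len. A rough Python model of that launch arithmetic; threads_per_block and batches_per_warp are illustrative assumptions, since their values come from code outside these hunks:

    # Rough model of dispatch_scaled_masked_softmax_forward's launch arithmetic.
    import math

    def launch_config(query_seq_len, key_seq_len, batches, attn_heads,
                      threads_per_block=128, batches_per_warp=2, c10_warp_size=32):
        assert 0 < key_seq_len <= 2048
        log2_elements = math.ceil(math.log2(key_seq_len))   # log2_ceil(key_seq_len)
        next_power_of_two = 1 << log2_elements
        batch_count = batches * attn_heads * query_seq_len   # number of softmax rows
        warp_size = min(next_power_of_two, c10_warp_size)
        warps_per_block = threads_per_block // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        assert query_seq_len % batches_per_block == 0
        blocks = (query_seq_len // batches_per_block, attn_heads, batches)
        threads = (warp_size, warps_per_block, 1)
        return log2_elements, batch_count, blocks, threads

    print(launch_config(query_seq_len=512, key_seq_len=1024, batches=8, attn_heads=16))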
megatron/fused_kernels/scaled_masked_softmax_cuda.cu  (view @ 58edb19a)

@@ -37,17 +37,19 @@ torch::Tensor fwd_cuda(
   const int batches = input.size(0);
   const int pad_batches = mask.size(0);
   const int attn_heads = input.size(1);
-  const int seq_len = input.size(2);
-  TORCH_INTERNAL_ASSERT(seq_len <= 2048);
+  const int query_seq_len = input.size(2);
+  const int key_seq_len = input.size(3);
+  TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
+  TORCH_INTERNAL_ASSERT(query_seq_len > 1);
   TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
   TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
-  TORCH_INTERNAL_ASSERT(mask.size(2) == seq_len);
-  TORCH_INTERNAL_ASSERT(mask.size(3) == seq_len);
+  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
+  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

   // Output
   auto act_options = input.options().requires_grad(false);
   torch::Tensor softmax_results =
-      torch::empty({batches, attn_heads, seq_len, seq_len}, act_options);
+      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

   // Softmax Intermediate Result Ptr
   void* input_ptr = static_cast<void*>(input.data_ptr());

@@ -59,8 +61,8 @@ torch::Tensor fwd_cuda(
       reinterpret_cast<const half*>(input_ptr),
       reinterpret_cast<const uint8_t*>(mask_ptr),
       scale_factor,
-      seq_len,
-      seq_len,
+      query_seq_len,
+      key_seq_len,
       batches,
       attn_heads,
       pad_batches);

@@ -78,8 +80,8 @@ torch::Tensor bwd_cuda(
   //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
   const int batches = output_grads.size(0);
   const int attn_heads = output_grads.size(1);
-  const int seq_len = output_grads.size(2);
-  TORCH_INTERNAL_ASSERT(output_grads.size(2) == output_grads.size(3));
+  const int query_seq_len = output_grads.size(2);
+  const int key_seq_len = output_grads.size(3);

   void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

@@ -89,8 +91,8 @@ torch::Tensor bwd_cuda(
       reinterpret_cast<half*>(output_grads_ptr),
       reinterpret_cast<half const*>(softmax_results.data_ptr()),
       scale_factor,
-      seq_len,
-      seq_len,
+      query_seq_len,
+      key_seq_len,
       batches,
       attn_heads);
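fwd_cuda now reads a rectangular score tensor: query length from input.size(2), key length from input.size(3), with the mask broadcasting over heads. A sketch in PyTorch terms (not the extension itself) of the shape contract the new assertions enforce:

    # Shape contract enforced by the updated fwd_cuda, expressed in Python.
    import torch

    def check_fwd_shapes(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        batches, attn_heads, query_seq_len, key_seq_len = scores.shape  # [b, np, sq, sk]
        pad_batches = mask.size(0)
        assert key_seq_len <= 2048
        assert query_seq_len > 1
        assert pad_batches in (1, batches)      # mask batch dim broadcasts or matches
        assert mask.size(1) == 1                # mask broadcasts over heads
        assert mask.size(2) == query_seq_len
        assert mask.size(3) == key_seq_len
        # Output buffer mirrors the (possibly rectangular) score shape.
        return torch.empty(batches, attn_heads, query_seq_len, key_seq_len,
                           dtype=scores.dtype, device=scores.device)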
megatron/model/bert_model.py  (view @ 58edb19a)

@@ -19,6 +19,7 @@ import torch
 from megatron import get_args
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model import import_layernorm

@@ -143,6 +144,7 @@ class BertModelBase(MegatronModule):
         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,
             add_pooler=self.add_binary_head,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method)
megatron/model/classification.py  (view @ 58edb19a)

@@ -19,6 +19,7 @@ import torch
 from megatron import get_args, print_rank_last
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer

@@ -39,6 +40,7 @@ class ClassificationBase(MegatronModule):
         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,
             add_pooler=True,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
megatron/model/enums.py  (new file, 0 → 100644, view @ 58edb19a)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum


class LayerType(enum.Enum):
    encoder = 1
    decoder = 2


class AttnType(enum.Enum):
    self_attn = 1
    cross_attn = 2


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
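The new enums.py gives the model code symbolic names for layer roles, attention flavors, and mask types. A small usage sketch (hypothetical caller, assuming the megatron package is importable) showing the dispatch style the rest of this commit adopts:

    from megatron.model.enums import AttnMaskType, AttnType, LayerType

    def describe(layer_type, attention_type, attn_mask_type):
        # Mirrors the kind of branching added in transformer.py and fused_softmax.py.
        if attn_mask_type == AttnMaskType.causal:
            mask = "upper-triangular (GPT-style) mask"
        else:
            assert attn_mask_type == AttnMaskType.padding
            mask = "padding (BERT-style) mask"
        kind = "cross" if attention_type == AttnType.cross_attn else "self"
        role = "decoder" if layer_type == LayerType.decoder else "encoder"
        return f"{role} layer, {kind}-attention, {mask}"

    print(describe(LayerType.decoder, AttnType.cross_attn, AttnMaskType.padding))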
megatron/model/fused_softmax.py  (view @ 58edb19a)

@@ -14,6 +14,7 @@
 # limitations under the License.

 import torch
+from megatron.model.enums import AttnMaskType


 class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):

@@ -85,8 +86,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
     fused operation: scaling + mask + softmax
     Arguments:
         input_in_fp16: flag to indicate if input in fp16 data format.
-        upper_triang_mask: if true, apply upper triangular masking.
-                           (used in gpt family networks)
+        attn_mask_type: attention mask type (pad or causal)
         mask_func: mask function to be applied.
         softmax_in_fp32: if true, softmax in performed at fp32 precision.
         scale: scaling factor used in input tensor scaling.

@@ -96,16 +96,16 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
     def __init__(self, input_in_fp16,
-                 upper_triang_mask_fusion, general_mask_fusion,
+                 attn_mask_type, scaled_masked_softmax_fusion,
                  mask_func, softmax_in_fp32, scale):
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
-        self.upper_triang_mask_fusion = upper_triang_mask_fusion
-        self.general_mask_fusion = general_mask_fusion
+        self.attn_mask_type = attn_mask_type
+        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
         self.mask_func = mask_func
         self.softmax_in_fp32 = softmax_in_fp32
         self.scale = scale

@@ -115,23 +115,26 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         ), "softmax should be in fp32 when scaled"

     def forward(self, input, mask):
-        # [b, np, s, s]
+        # [b, np, sq, sk]
         data_size = input.size()
+        query_seq_len = data_size[-2]
+        key_seq_len = data_size[-1]
         assert input.dim() == 4

         # invoke custom kernel
-        if (self.input_in_fp16 and data_size[-1] <= 2048 and
-            (self.upper_triang_mask_fusion or self.general_mask_fusion) and
-                input.size()[2] == input.size()[3]):
+        if self.input_in_fp16 and key_seq_len <= 2048 and \
+           query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0
-            if self.upper_triang_mask_fusion:
-                input = input.view(-1, data_size[2], data_size[3])
+            if self.attn_mask_type == AttnMaskType.causal:
+                assert query_seq_len == key_seq_len, \
+                    "causal mask is only for self attention"
+                input = input.view(-1, query_seq_len, key_seq_len)
                 probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
                 probs = probs.view(*data_size)
             else:
+                assert self.attn_mask_type == AttnMaskType.padding
                 probs = ScaledMaskedSoftmax.apply(input, mask, scale)
         else:
             if self.input_in_fp16 and self.softmax_in_fp32:
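In short, FusedScaleMaskSoftmax now keys off attn_mask_type and a single masked_softmax_fusion flag, and the rectangular case only requires the query length to be a multiple of 4. A minimal sketch (not the module itself) of the new eligibility test; when it fails, the module falls back to a regular, optionally fp32, torch softmax:

    # Sketch of the fused-kernel eligibility check in FusedScaleMaskSoftmax.forward.
    def use_fused_kernel(input_in_fp16, masked_softmax_fusion,
                         query_seq_len, key_seq_len):
        return (input_in_fp16
                and key_seq_len <= 2048
                and query_seq_len % 4 == 0
                and masked_softmax_fusion)

    print(use_fused_kernel(True, True, 1024, 1024))   # True  -> custom CUDA kernel
    print(use_fused_kernel(True, True, 1022, 1024))   # False -> torch softmax fallback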
megatron/model/gpt_model.py  (view @ 58edb19a)

@@ -21,6 +21,7 @@ from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
+from .enums import AttnMaskType
 from .language_model import parallel_lm_logits
 from .language_model import get_language_model
 from .utils import init_method_normal

@@ -69,6 +70,7 @@ class GPTModelBase(MegatronModule):
         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,
             add_pooler=False,
+            encoder_attn_mask_type=AttnMaskType.causal,
             init_method=init_method_normal(args.init_method_std),
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
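Taken together with the BERT-family changes above, every caller of get_language_model now declares its encoder mask type explicitly. A hypothetical side-by-side of the two call patterns (the argument values here are illustrative, not taken from the commit):

    from megatron.model.enums import AttnMaskType

    # GPT: decoder-only network, causal masking, no pooler.
    gpt_kwargs = dict(num_tokentypes=0, add_pooler=False,
                      encoder_attn_mask_type=AttnMaskType.causal)

    # BERT / classification / multiple-choice / REALM: bidirectional encoder,
    # padding masking, pooler enabled.
    bert_kwargs = dict(num_tokentypes=2, add_pooler=True,
                       encoder_attn_mask_type=AttnMaskType.padding)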
megatron/model/language_model.py  (view @ 58edb19a)

@@ -21,6 +21,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
+from megatron.model.enums import LayerType, AttnMaskType
 from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal, scaled_init_method_normal

@@ -43,7 +44,9 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
-def get_language_model(num_tokentypes, add_pooler,
-                       init_method=None, scaled_init_method=None):
+def get_language_model(num_tokentypes, add_pooler,
+                       encoder_attn_mask_type, init_method=None,
+                       scaled_init_method=None, add_decoder=False,
+                       decoder_attn_mask_type=AttnMaskType.causal):
     """Build language model and return along with the key to save."""
     args = get_args()

@@ -51,15 +54,18 @@ def get_language_model(num_tokentypes, add_pooler,
         init_method = init_method_normal(args.init_method_std)

     if scaled_init_method is None:
-        scaled_init_method = scaled_init_method_normal(args.init_method_std,
-                                                       args.num_layers)
+        scaled_init_method = scaled_init_method_normal(
+            args.init_method_std, args.num_layers)

     # Language model.
-    args = [init_method, scaled_init_method]
+    args = [init_method, scaled_init_method, encoder_attn_mask_type]
     kwargs = {}
     cls = None
     if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
         cls = TransformerLanguageModel
         kwargs['num_tokentypes'] = num_tokentypes
+        kwargs['add_decoder'] = add_decoder
+        kwargs['decoder_attn_mask_type'] = decoder_attn_mask_type
         kwargs['add_pooler'] = add_pooler
     elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
         cls = TransformerLanguageModelFirstStage

@@ -273,7 +279,10 @@ class TransformerLanguageModelBase(MegatronModule):
     def __init__(self,
                  init_method,
                  output_layer_init_method,
+                 encoder_attn_mask_type,
                  num_tokentypes=0,
+                 add_decoder=False,
+                 decoder_attn_mask_type=AttnMaskType.causal,
                  add_pooler=False):
         super(TransformerLanguageModelBase, self).__init__()
         args = get_args()

@@ -281,6 +290,9 @@ class TransformerLanguageModelBase(MegatronModule):
         self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
         self.init_method = init_method
+        self.encoder_attn_mask_type = encoder_attn_mask_type
+        self.add_decoder = add_decoder
+        self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler

         # Embeddings.

@@ -294,40 +306,83 @@ class TransformerLanguageModelBase(MegatronModule):
             self._embedding_key = 'embedding'

         # Transformer.
-        self.transformer = ParallelTransformer(
-            self.init_method, output_layer_init_method)
-        self._transformer_key = 'transformer'
+        self.encoder = ParallelTransformer(
+            self.init_method, output_layer_init_method,
+            self_attn_mask_type=self.encoder_attn_mask_type)
+        self._encoder_key = 'encoder'
+
+        # Decoder
+        if self.add_decoder:
+            assert args.pipeline_model_parallel_size == 1, \
+                'pipeline parallelism is not supported in the presence of decoder'
+            self.decoder = ParallelTransformer(
+                self.init_method, output_layer_init_method,
+                layer_type=LayerType.decoder,
+                self_attn_mask_type=self.decoder_attn_mask_type)
+            self._decoder_key = 'decoder'

-        if mpu.is_pipeline_last_stage():
-            # Pooler.
-            if self.add_pooler:
-                self.pooler = Pooler(self.hidden_size, self.init_method)
-                self._pooler_key = 'pooler'
+        if mpu.is_pipeline_last_stage() and self.add_pooler:
+            self.pooler = Pooler(self.hidden_size, self.init_method)
+            self._pooler_key = 'pooler'

-    def forward(self, language_model_input, attention_mask,
-                tokentype_ids=None, layer_past=None,
-                get_key_value=False, pooling_sequence_index=0):
+    def forward(self, enc_language_model_input, enc_attn_mask,
+                dec_language_model_input=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None, tokentype_ids=None,
+                layer_past=None, get_key_value=False,
+                pooling_sequence_index=0, enc_hidden_states=None,
+                output_enc_hidden=False):

         # Embeddings.
         if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = language_model_input
+            (input_ids, position_ids) = enc_language_model_input
             embedding_output = self.embedding(input_ids, position_ids,
                                               tokentype_ids=tokentype_ids)
-            transformer_input = embedding_output
+            encoder_input = embedding_output
         else:
-            transformer_input = language_model_input
+            encoder_input = enc_language_model_input

-        # Transformer.
-        transformer_output = self.transformer(transformer_input,
-                                              attention_mask,
-                                              layer_past=layer_past,
-                                              get_key_value=get_key_value)
+        # encoder.
+        if enc_hidden_states is None:
+            encoder_output = self.encoder(encoder_input,
+                                          enc_attn_mask,
+                                          layer_past=layer_past,
+                                          get_key_value=get_key_value)
+        else:
+            encoder_output = enc_hidden_states.to(encoder_input.dtype)

-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            pooled_output = self.pooler(transformer_output,
-                                        pooling_sequence_index)
-            return transformer_output, pooled_output
-        return transformer_output
+        if mpu.is_pipeline_last_stage():
+            if self.add_pooler:
+                pooled_output = self.pooler(encoder_output,
+                                            pooling_sequence_index)
+
+        # output_enc_hidden refers to when we just need the encoder's
+        # output. For example, it is helpful to compute
+        # similarity between two sequences by average pooling
+        if not self.add_decoder or output_enc_hidden:
+            if self.add_pooler and mpu.is_pipeline_last_stage():
+                return encoder_output, pooled_output
+            else:
+                return encoder_output
+
+        # Decoder Embedding
+        (dec_input_ids, dec_position_ids) = dec_language_model_input
+        dec_embedding_output = self.embedding(dec_input_ids,
+                                              dec_position_ids)
+        # decoder
+        decoder_output = self.decoder(dec_embedding_output,
+                                      dec_attn_mask,
+                                      layer_past=layer_past,
+                                      get_key_value=get_key_value,
+                                      encoder_output=encoder_output,
+                                      enc_dec_attn_mask=enc_dec_attn_mask)
+
+        if self.add_pooler and mpu.is_pipeline_last_stage():
+            return decoder_output, encoder_output, pooled_output
+        else:
+            return decoder_output, encoder_output

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):

@@ -338,13 +393,18 @@ class TransformerLanguageModelBase(MegatronModule):
         state_dict_[self._embedding_key] \
             = self.embedding.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        state_dict_[self._transformer_key] \
-            = self.transformer.state_dict_for_save_checkpoint(
+        state_dict_[self._encoder_key] \
+            = self.encoder.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            state_dict_[self._pooler_key] \
-                = self.pooler.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+        if mpu.is_pipeline_last_stage():
+            if self.add_pooler:
+                state_dict_[self._pooler_key] \
+                    = self.pooler.state_dict_for_save_checkpoint(
+                        destination, prefix, keep_vars)
+        if self.add_decoder:
+            state_dict_[self._decoder_key] \
+                = self.decoder.state_dict_for_save_checkpoint(
+                    destination, prefix, keep_vars)

         return state_dict_

@@ -363,23 +423,44 @@ class TransformerLanguageModelBase(MegatronModule):
                 state_dict_[key] = state_dict[key]
         self.embedding.load_state_dict(state_dict_, strict=strict)

-        # Transformer.
-        if self._transformer_key in state_dict:
-            state_dict_ = state_dict[self._transformer_key]
+        # Encoder.
+        if self._encoder_key in state_dict:
+            state_dict_ = state_dict[self._encoder_key]
+        # for backward compatibility.
+        elif 'transformer' in state_dict:
+            state_dict_ = state_dict['transformer']
         else:
             # for backward compatibility.
             state_dict_ = {}
             for key in state_dict.keys():
                 if 'transformer.' in key:
                     state_dict_[key.split('transformer.')[1]] = state_dict[key]
-        self.transformer.load_state_dict(state_dict_, strict=strict)

-        # Pooler.
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            assert 'pooler' in state_dict, \
-                'could not find data for pooler in the checkpoint'
-            self.pooler.load_state_dict(state_dict[self._pooler_key],
-                                        strict=strict)
+        # for backward compatibility.
+        state_dict_self_attention = {}
+        for key in state_dict_.keys():
+            if '.attention.' in key:
+                state_dict_self_attention[key.replace(
+                    ".attention.", ".self_attention.")] = state_dict_[key]
+            else:
+                state_dict_self_attention[key] = state_dict_[key]
+        state_dict_ = state_dict_self_attention
+
+        self.encoder.load_state_dict(state_dict_, strict=strict)
+
+        if mpu.is_pipeline_last_stage():
+            # pooler
+            if self.add_pooler:
+                assert 'pooler' in state_dict, \
+                    'could not find data for pooler in the checkpoint'
+                self.pooler.load_state_dict(state_dict[self._pooler_key],
+                                            strict=strict)
+        # decoder
+        if self.add_decoder:
+            assert 'decoder' in state_dict, \
+                'could not find data for pooler in the checkpoint'
+            self.decoder.load_state_dict(state_dict[self._decoder_key],
+                                         strict=strict)


 class TransformerLanguageModel(TransformerLanguageModelBase):

@@ -390,24 +471,37 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
     def __init__(self,
                  init_method,
                  output_layer_init_method,
+                 encoder_attn_mask_type,
                  num_tokentypes=0,
+                 decoder_attn_mask_type=AttnMaskType.causal,
+                 add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModel, self).__init__(
             init_method,
             output_layer_init_method,
+            encoder_attn_mask_type,
             num_tokentypes=num_tokentypes,
+            add_decoder=add_decoder,
+            decoder_attn_mask_type=decoder_attn_mask_type,
             add_pooler=add_pooler)

-    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
+    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
+                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
+                get_key_value=False, pooling_sequence_index=0,
+                enc_hidden_states=None, output_enc_hidden=False):
         return super(TransformerLanguageModel, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
+            (enc_input_ids, enc_position_ids),
+            enc_attn_mask,
+            dec_language_model_input=(dec_input_ids, dec_position_ids),
+            dec_attn_mask=dec_attn_mask,
+            enc_dec_attn_mask=enc_dec_attn_mask,
             tokentype_ids=tokentype_ids,
             layer_past=layer_past,
             get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index
+            pooling_sequence_index=pooling_sequence_index,
+            enc_hidden_states=enc_hidden_states,
+            output_enc_hidden=output_enc_hidden
         )

@@ -419,10 +513,12 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
     def __init__(self,
                  init_method,
                  output_layer_init_method,
+                 encoder_attn_mask_type,
                  num_tokentypes=0):
         super(TransformerLanguageModelFirstStage, self).__init__(
             init_method,
             output_layer_init_method,
+            encoder_attn_mask_type,
             num_tokentypes=num_tokentypes)

     def forward(self, input_ids, position_ids, attention_mask,

@@ -443,10 +539,12 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
     def __init__(self,
                  init_method,
-                 output_layer_init_method):
+                 output_layer_init_method,
+                 encoder_attn_mask_type):
         super(TransformerLanguageModelIntermediateStage, self).__init__(
             init_method,
-            output_layer_init_method)
+            output_layer_init_method,
+            encoder_attn_mask_type)

     def forward(self, hidden_states, attention_mask,
                 layer_past=None, get_key_value=False):

@@ -466,10 +564,12 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
     def __init__(self,
                  init_method,
                  output_layer_init_method,
+                 encoder_attn_mask_type,
                  add_pooler=False):
         super(TransformerLanguageModelLastStage, self).__init__(
             init_method,
             output_layer_init_method,
+            encoder_attn_mask_type,
             add_pooler=add_pooler)

     def forward(self, hidden_states, attention_mask,

@@ -480,5 +580,5 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
             attention_mask,
             layer_past=layer_past,
             get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index
+            pooling_sequence_index=pooling_sequence_index,
         )
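The language model forward pass is now encoder-first, with an optional decoder and pooler. A compressed, standalone sketch of that control flow; encode, decode, and pool are stand-in callables (not Megatron code) and pipeline-stage checks are ignored for brevity:

    # Sketch of the reworked TransformerLanguageModelBase.forward control flow.
    def language_model_forward(encode, decode, pool,
                               enc_input, enc_attn_mask,
                               dec_input=None, dec_attn_mask=None,
                               enc_dec_attn_mask=None,
                               add_decoder=False, add_pooler=False,
                               enc_hidden_states=None, output_enc_hidden=False):
        # Encoder pass, or reuse of externally supplied hidden states.
        encoder_output = (encode(enc_input, enc_attn_mask)
                          if enc_hidden_states is None else enc_hidden_states)
        pooled = pool(encoder_output) if add_pooler else None

        # Encoder-only models, or callers that explicitly want encoder output.
        if not add_decoder or output_enc_hidden:
            return (encoder_output, pooled) if add_pooler else encoder_output

        # Decoder pass with cross attention over the encoder output.
        decoder_output = decode(dec_input, dec_attn_mask,
                                encoder_output=encoder_output,
                                enc_dec_attn_mask=enc_dec_attn_mask)
        if add_pooler:
            return decoder_output, encoder_output, pooled
        return decoder_output, encoder_output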
megatron/model/multiple_choice.py  (view @ 58edb19a)

@@ -19,6 +19,7 @@ import torch
 from megatron import get_args, print_rank_last
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer

@@ -38,6 +39,7 @@ class MultipleChoiceBase(MegatronModule):
         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,
             add_pooler=True,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
megatron/model/realm_model.py  (view @ 58edb19a)

@@ -6,6 +6,7 @@ from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoi
 from megatron.model import BertModel
 from .module import MegatronModule
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.language_model import get_language_model

@@ -158,6 +159,7 @@ class IREncoderBertModel(MegatronModule):
         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,
             add_pooler=True,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method)
megatron/model/transformer.py  (view @ 58edb19a)

@@ -14,7 +14,6 @@
 # limitations under the License.

 """Transformer."""
-
 import math
 import torch
 import torch.nn.functional as F

@@ -23,6 +22,7 @@ from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
 from megatron.checkpointing import get_checkpoint_version
+from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model import import_layernorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl

@@ -65,7 +65,7 @@ class ParallelMLP(MegatronModule):
         # Project to 4h.
         self.dense_h_to_4h = mpu.ColumnParallelLinear(
             args.hidden_size,
-            4 * args.hidden_size,
+            args.ffn_hidden_size,
             gather_output=False,
             init_method=init_method,
             skip_bias_add=True)

@@ -79,7 +79,7 @@ class ParallelMLP(MegatronModule):
         # Project back to h.
         self.dense_4h_to_h = mpu.RowParallelLinear(
-            4 * args.hidden_size,
+            args.ffn_hidden_size,
             args.hidden_size,
             input_is_parallel=True,
             init_method=output_layer_init_method,

@@ -103,15 +103,18 @@ class ParallelMLP(MegatronModule):
         return output, output_bias


-class ParallelSelfAttention(MegatronModule):
+class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.

     Self-attention layer takes input with size [b, s, h]
     and returns output of the same size.
     """

-    def __init__(self, init_method,
-                 output_layer_init_method, layer_number):
-        super(ParallelSelfAttention, self).__init__()
+    def __init__(self, init_method,
+                 output_layer_init_method, layer_number,
+                 attention_type=AttnType.self_attn,
+                 attn_mask_type=AttnMaskType.padding):
+        super(ParallelAttention, self).__init__()
         args = get_args()
         self.fp16 = args.fp16

@@ -120,20 +123,38 @@ class ParallelSelfAttention(MegatronModule):
         if self.apply_query_key_layer_scaling:
             self.attention_softmax_in_fp32 = True
         self.layer_number = max(1, layer_number)
+        self.attention_type = attention_type
+        self.attn_mask_type = attn_mask_type
+
+        projection_size = args.kv_channels * args.num_attention_heads

         # Per attention head and per partition values.
         world_size = mpu.get_tensor_model_parallel_world_size()
-        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
+        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                     world_size)
         self.hidden_size_per_attention_head = mpu.divide(
-            args.hidden_size, args.num_attention_heads)
+            projection_size, args.num_attention_heads)
         self.num_attention_heads_per_partition = mpu.divide(
             args.num_attention_heads, world_size)

         # Strided linear layer.
-        self.query_key_value = mpu.ColumnParallelLinear(
-            args.hidden_size,
-            3 * args.hidden_size,
-            gather_output=False,
-            init_method=init_method)
+        if attention_type == AttnType.self_attn:
+            self.query_key_value = mpu.ColumnParallelLinear(
+                args.hidden_size,
+                3 * projection_size,
+                gather_output=False,
+                init_method=init_method)
+        else:
+            assert attention_type == AttnType.cross_attn
+            self.query = mpu.ColumnParallelLinear(
+                args.hidden_size,
+                projection_size,
+                gather_output=False,
+                init_method=init_method)
+
+            self.key_value = mpu.ColumnParallelLinear(
+                args.hidden_size,
+                2 * projection_size,
+                gather_output=False,
+                init_method=init_method)

@@ -145,8 +166,8 @@ class ParallelSelfAttention(MegatronModule):
         self.scale_mask_softmax = FusedScaleMaskSoftmax(
             self.fp16,
-            args.scaled_upper_triang_masked_softmax_fusion,
-            args.scaled_masked_softmax_fusion,
+            self.attn_mask_type,
+            args.masked_softmax_fusion,
             attention_mask_func,
             self.attention_softmax_in_fp32,
             coeff)

@@ -158,7 +179,7 @@ class ParallelSelfAttention(MegatronModule):
         # Output.
         self.dense = mpu.RowParallelLinear(
-            args.hidden_size,
+            projection_size,
             args.hidden_size,
             input_is_parallel=True,
             init_method=output_layer_init_method,

@@ -195,13 +216,14 @@ class ParallelSelfAttention(MegatronModule):
         return mixed_layer

     def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False):
+                get_key_value=False, encoder_output=None):
         # hidden_states: [sq, b, h]

         # =====================
         # Query, Key, and Value
         # =====================

+        if self.attention_type == AttnType.self_attn:
             # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
             mixed_x_layer, _ = self.query_key_value(hidden_states)

@@ -224,6 +246,36 @@ class ParallelSelfAttention(MegatronModule):
             (query_layer,
              key_layer,
              value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
+        else:
+            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
+            mixed_kv_layer, _ = self.key_value(encoder_output)
+
+            checkpoint_version = get_checkpoint_version()
+            if checkpoint_version is not None:
+                if checkpoint_version == 0:
+                    # [s, b, (2 * np * hn)] --> [s, b, (np * 2 * hn)]
+                    mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
+                elif checkpoint_version == 1.0:
+                    # [s, b, (np * hn * 2)] --> [s, b, (np * 2 * hn)]
+                    mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, False)
+
+            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
+            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
+                (self.num_attention_heads_per_partition,
+                 2 * self.hidden_size_per_attention_head)
+            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
+
+            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
+            (key_layer,
+             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
+
+            # Attention head [sq, b, h] --> [sq, b, hp]
+            query_layer, _ = self.query(hidden_states)
+            # [sq, b, hp] --> [sq, b, np, hn]
+            new_tensor_shape = query_layer.size()[:-1] + \
+                (self.num_attention_heads_per_partition,
+                 self.hidden_size_per_attention_head)
+            query_layer = query_layer.view(*new_tensor_shape)

         # ==================================
         # Adjust key and value for inference

@@ -251,6 +303,7 @@ class ParallelSelfAttention(MegatronModule):
         # [sq, b, np, hn] -> [sq, b * np, hn]
         query_layer = query_layer.view(output_size[2],
                                        output_size[0] * output_size[1], -1)
+
         # [sk, b, np, hn] -> [sk, b * np, hn]
         key_layer = key_layer.view(output_size[3],
                                    output_size[0] * output_size[1], -1)

@@ -377,15 +430,18 @@ def bias_dropout_add_fused_inference(x, bias, residual, prob):
 class ParallelTransformerLayer(MegatronModule):
     """A single transformer layer.

-    Transformore layer takes input with size [b, s, h] and returns an
+    Transformer layer takes input with size [b, s, h] and returns an
     output of the same size.
     """

-    def __init__(self, init_method, output_layer_init_method, layer_number):
+    def __init__(self, init_method, output_layer_init_method,
+                 layer_number, layer_type=LayerType.encoder,
+                 self_attn_mask_type=AttnMaskType.padding):
         args = get_args()

         super(ParallelTransformerLayer, self).__init__()
         self.layer_number = layer_number
+        self.layer_type = layer_type

         self.apply_residual_connection_post_layernorm \
             = args.apply_residual_connection_post_layernorm

@@ -397,30 +453,45 @@ class ParallelTransformerLayer(MegatronModule):
             eps=args.layernorm_epsilon)

         # Self attention.
-        self.attention = ParallelSelfAttention(
+        self.self_attention = ParallelAttention(
             init_method,
             output_layer_init_method,
-            layer_number)
+            layer_number,
+            attention_type=AttnType.self_attn,
+            attn_mask_type=self_attn_mask_type)
         self.hidden_dropout = args.hidden_dropout
         self.bias_dropout_fusion = args.bias_dropout_fusion

-        # Layernorm on the input data.
+        # Layernorm on the attention output
         self.post_attention_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)

+        if self.layer_type == LayerType.decoder:
+            self.inter_attention = ParallelAttention(
+                init_method,
+                output_layer_init_method,
+                layer_number,
+                attention_type=AttnType.cross_attn)
+            # Layernorm on the attention output.
+            self.post_inter_attention_layernorm = LayerNorm(
+                args.hidden_size,
+                eps=args.layernorm_epsilon)
+
         # MLP
         self.mlp = ParallelMLP(init_method,
                                output_layer_init_method)

-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False):
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None, enc_dec_attn_mask=None,
+                layer_past=None, get_key_value=False):
         # hidden_states: [b, s, h]

-        # Layer norm at the begining of the transformer layer.
+        # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
         # Self attention.
         attention_output, attention_bias = \
-            self.attention(layernorm_output,
+            self.self_attention(layernorm_output,
                            attention_mask,
                            layer_past=layer_past,
                            get_key_value=get_key_value)

@@ -457,6 +528,28 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)

+        if self.layer_type == LayerType.decoder:
+            attention_output, attention_bias = \
+                self.inter_attention(layernorm_output,
+                                     enc_dec_attn_mask,
+                                     encoder_output=encoder_output)
+            # residual connection
+            if self.apply_residual_connection_post_layernorm:
+                residual = layernorm_output
+            else:
+                residual = layernorm_input
+
+            # re-enable torch grad to enable fused optimization.
+            with torch.enable_grad():
+                layernorm_input = bias_dropout_add_func(
+                    attention_output,
+                    attention_bias.expand_as(residual),
+                    residual,
+                    self.hidden_dropout)
+
+            # Layer norm post the decoder attention
+            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
+
         # MLP.
         mlp_output, mlp_bias = self.mlp(layernorm_output)

@@ -483,7 +576,9 @@ class ParallelTransformer(MegatronModule):
     """Transformer class."""

-    def __init__(self, init_method, output_layer_init_method):
+    def __init__(self, init_method, output_layer_init_method,
+                 layer_type=LayerType.encoder,
+                 self_attn_mask_type=AttnMaskType.padding):
         super(ParallelTransformer, self).__init__()
         args = get_args()

@@ -501,7 +596,11 @@ class ParallelTransformer(MegatronModule):
         # Transformer layers.
         def build_layer(layer_number):
             return ParallelTransformerLayer(
-                init_method,
-                output_layer_init_method, layer_number)
+                init_method,
+                output_layer_init_method,
+                layer_number,
+                layer_type=layer_type,
+                self_attn_mask_type=self_attn_mask_type)

         offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])

@@ -516,14 +615,18 @@ class ParallelTransformer(MegatronModule):
     def _get_layer(self, layer_number):
         return self.layers[layer_number]

-    def _checkpointed_forward(self, hidden_states, attention_mask):
+    def _checkpointed_forward(self, hidden_states, attention_mask,
+                              encoder_output, enc_dec_attn_mask):
         """Forward method with activation checkpointing."""
         def custom(start, end):
             def custom_forward(*inputs):
                 x_ = inputs[0]
+                attention_mask = inputs[1]
+                encoder_output = inputs[2]
+                enc_dec_attn_mask = inputs[3]
                 for index in range(start, end):
                     layer = self._get_layer(index)
-                    x_ = layer(x_, inputs[1])
+                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                 return x_
             return custom_forward

@@ -533,13 +636,13 @@ class ParallelTransformer(MegatronModule):
         while l < self.num_layers:
             hidden_states = mpu.checkpoint(
                 custom(l, l + self.checkpoint_num_layers),
-                hidden_states, attention_mask)
+                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
             l += self.checkpoint_num_layers

         return hidden_states

     def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False):
+                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):

         # Checks.
         if layer_past is not None:

@@ -560,9 +663,14 @@ class ParallelTransformer(MegatronModule):
         else:
             hidden_states = hidden_states.transpose(0, 1).contiguous()

+        if encoder_output is not None:
+            encoder_output = encoder_output.transpose(0, 1).contiguous()
+
         if self.checkpoint_activations:
             hidden_states = self._checkpointed_forward(hidden_states,
-                                                       attention_mask)
+                                                       attention_mask,
+                                                       encoder_output,
+                                                       enc_dec_attn_mask)
         else:
             if get_key_value:
                 presents = []

@@ -573,6 +681,8 @@ class ParallelTransformer(MegatronModule):
                     past = layer_past[index]
                 hidden_states = layer(hidden_states,
                                       attention_mask,
+                                      encoder_output=encoder_output,
+                                      enc_dec_attn_mask=enc_dec_attn_mask,
                                       layer_past=past,
                                       get_key_value=get_key_value)
                 if get_key_value:
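ParallelAttention now sizes its projections from kv_channels * num_attention_heads rather than hidden_size, and gains a cross-attention variant with separate query and key/value projections. A standalone sketch of the resulting projection shapes (illustrative helper, not Megatron code):

    # Projection shapes implied by the ParallelAttention changes.
    def attention_projection_shapes(hidden_size, num_attention_heads, kv_channels,
                                    world_size=1, cross_attn=False):
        projection_size = kv_channels * num_attention_heads
        per_partition = projection_size // world_size
        per_head = projection_size // num_attention_heads   # == kv_channels
        if not cross_attn:
            # Self attention: one fused [h -> 3 * projection_size] linear.
            return {"query_key_value": (hidden_size, 3 * projection_size),
                    "dense": (projection_size, hidden_size),
                    "per_partition": per_partition, "per_head": per_head}
        # Cross attention: separate query and fused key/value projections.
        return {"query": (hidden_size, projection_size),
                "key_value": (hidden_size, 2 * projection_size),
                "dense": (projection_size, hidden_size),
                "per_partition": per_partition, "per_head": per_head}

    print(attention_projection_shapes(1024, 16, 64))                     # projection == hidden size
    print(attention_projection_shapes(1024, 16, 128, cross_attn=True))   # attention wider than hidden size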
pretrain_gpt.py  (view @ 58edb19a)

@@ -144,5 +144,4 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 if __name__ == "__main__":

     pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
-             args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
-                            'scaled_upper_triang_masked_softmax_fusion': True})
+             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})