OpenDAS / Megatron-LM / Commits / 4b506832

Commit 4b506832, authored Jan 08, 2021 by Vijay Korthikanti

    decoder support in transformers

Parent: f5eac3d1
Showing 9 changed files with 520 additions and 261 deletions (+520 −261):

  megatron/arguments.py                                  +31  −17
  megatron/fused_kernels/scaled_masked_softmax.h         +43  −47
  megatron/fused_kernels/scaled_masked_softmax_cuda.cu   +13  −11
  megatron/model/enums.py                                +28  −0
  megatron/model/fused_softmax.py                        +31  −21
  megatron/model/gpt2_model.py                           +2   −0
  megatron/model/language_model.py                       +172 −66
  megatron/model/transformer.py                          +199 −97
  pretrain_gpt2.py                                       +1   −2
megatron/arguments.py

@@ -164,6 +164,20 @@ def parse_args(extra_args_provider=None, defaults={},
         _check_arg_is_not_none(args, req_arg)

     # Checks.
+    if args.ffn_hidden_size is None:
+        args.ffn_hidden_size = 4 * args.hidden_size
+
+    if args.kv_channels is None:
+        assert args.hidden_size % args.num_attention_heads == 0
+        args.kv_channels = args.hidden_size // args.num_attention_heads
+
+    if args.seq_length is not None:
+        assert args.encoder_seq_length is None
+        args.encoder_seq_length = args.seq_length
+    else:
+        assert args.encoder_seq_length is not None
+        args.seq_length = args.encoder_seq_length
+
     assert args.hidden_size % args.num_attention_heads == 0
     if args.seq_length is not None:
         assert args.max_position_embeddings >= args.seq_length

@@ -183,15 +197,10 @@ def parse_args(extra_args_provider=None, defaults={},
         'for distribute-checkpointed-activations to work you '\
         'need to enable checkpoint-activations'

-    if args.scaled_masked_softmax_fusion:
-        if args.scaled_upper_triang_masked_softmax_fusion:
-            fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
-        else:
-            fused_kernels.load_scaled_masked_softmax_fusion_kernel()
-    else:
-        # This argument will eventually go away, for now make sure it is off
-        # if scaled_masked_softmax_fusion is off.
-        args.scaled_upper_triang_masked_softmax_fusion = False
+    # Load scaled_masked_softmax_fusion kernels
+    if args.masked_softmax_fusion:
+        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
+        fused_kernels.load_scaled_masked_softmax_fusion_kernel()

     # Load mixed precision fused layer norm.
     if args.fp32_residual_connection:

@@ -227,8 +236,14 @@ def _add_network_size_args(parser):
                        help='Number of transformer layers.')
     group.add_argument('--hidden-size', type=int, default=None,
                        help='Tansformer hidden size.')
+    group.add_argument('--ffn-hidden-size', type=int, default=None,
+                       help='Transformer Feed-Forward Network hidden size. '
+                       'This is set to 4*hidden-size if not provided')
     group.add_argument('--num-attention-heads', type=int, default=None,
                        help='Number of transformer attention heads.')
+    group.add_argument('--kv-channels', type=int, default=None,
+                       help='Projection weights dimension in multi-head attention. '
+                       'This is set to args.hidden_size // args.num_attention_heads '
+                       'if not provided.')
     group.add_argument('--max-position-embeddings', type=int, default=None,
                        help='Maximum number of position embeddings to use. '
                        'This is the size of position embedding.')

@@ -330,16 +345,11 @@ def _add_training_args(parser):
                        help='Exit the program after this many minutes.')
     group.add_argument('--tensorboard-dir', type=str, default=None,
                        help='Write TensorBoard logs to this directory.')
-    group.add_argument('--no-scaled-masked-softmax-fusion',
+    group.add_argument('--no-masked-softmax-fusion',
                        action='store_false',
                        help='Disable fusion of query_key_value scaling, '
                        'masking, and softmax.',
-                       dest='scaled_masked_softmax_fusion')
-    group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
-                       type=bool,
-                       help='Use upper triangular version of fused '
-                       'scale, mask, softmax fusion kernel (default for GPT). '
-                       '- DEPRECATED')
+                       dest='masked_softmax_fusion')
     group.add_argument('--no-bias-gelu-fusion', action='store_false',
                        help='Disable bias and gelu fusion.',
                        dest='bias_gelu_fusion')

@@ -530,6 +540,10 @@ def _add_data_args(parser):
                        help='Path to the BPE merge file.')
     group.add_argument('--seq-length', type=int, default=None,
                        help="Maximum sequence length to process.")
+    group.add_argument('--encoder-seq-length', type=int, default=None,
+                       help="Maximum encoder sequence length to process.")
+    group.add_argument('--decoder-seq-length', type=int, default=None,
+                       help="Maximum decoder sequence length to process.")
     group.add_argument('--mask-prob', type=float, default=0.15,
                        help='Probability of replacing a token with mask.')
     group.add_argument('--short-seq-prob', type=float, default=0.1,
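The net effect of the arguments.py changes is easier to see outside the diff: --ffn-hidden-size, --kv-channels and --encoder-seq-length become optional and are derived from --hidden-size, --num-attention-heads and --seq-length when omitted. The sketch below is illustrative only (SimpleNamespace stands in for the parsed argparse namespace) and simply restates the defaulting rules added above.

# Minimal sketch (not Megatron-LM code) of how the new defaults interact,
# assuming a parsed args namespace with the fields added in this commit.
from types import SimpleNamespace

def apply_defaults(args):
    # --ffn-hidden-size defaults to 4x the hidden size, as in the diff above.
    if args.ffn_hidden_size is None:
        args.ffn_hidden_size = 4 * args.hidden_size
    # --kv-channels defaults to the per-head projection dimension.
    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
        args.kv_channels = args.hidden_size // args.num_attention_heads
    # --seq-length and --encoder-seq-length are aliases; exactly one is given.
    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length
    return args

args = apply_defaults(SimpleNamespace(
    hidden_size=1024, num_attention_heads=16, ffn_hidden_size=None,
    kv_channels=None, seq_length=512, encoder_seq_length=None,
    decoder_seq_length=128))
print(args.ffn_hidden_size, args.kv_channels, args.encoder_seq_length)
# 4096 64 512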
megatron/fused_kernels/scaled_masked_softmax.h

@@ -81,7 +81,6 @@ __global__ void scaled_masked_softmax_warp_forward(
     const uint8_t *mask,
     const acc_t scale,
     int micro_batch_size,
-    int stride,
     int element_count,
     int pad_batches)
 {

@@ -111,9 +110,9 @@ __global__ void scaled_masked_softmax_warp_forward(
     // there might be multiple batches per warp. compute the index within the batch
     int local_idx = threadIdx.x;

-    src += first_batch * stride + local_idx;
-    dst += first_batch * stride + local_idx;
-    mask += pad_first_batch * stride + local_idx;
+    src += first_batch * element_count + local_idx;
+    dst += first_batch * element_count + local_idx;
+    mask += pad_first_batch * element_count + local_idx;

     // load data from global memory
     acc_t elements[WARP_BATCH][WARP_ITERATIONS];

@@ -185,7 +184,6 @@ __global__ void scaled_masked_softmax_warp_backward(
     const input_t *output,
     acc_t scale,
     int micro_batch_size,
-    int stride,
     int element_count)
 {
     // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and

@@ -209,7 +207,7 @@ __global__ void scaled_masked_softmax_warp_backward(
     int local_idx = threadIdx.x;

     // the first element to process by the current thread
-    int thread_offset = first_batch * stride + local_idx;
+    int thread_offset = first_batch * element_count + local_idx;
     grad += thread_offset;
     output += thread_offset;
     gradInput += thread_offset;

@@ -277,20 +275,19 @@ void dispatch_scaled_masked_softmax_forward(
     const input_t *src,
     const uint8_t *mask,
     const input_t scale,
-    int softmax_elements,
-    int softmax_elements_stride,
+    int query_seq_len,
+    int key_seq_len,
     int batches,
     int attn_heads,
     int pad_batches)
 {
-    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
-    if (softmax_elements == 0) {
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
+    if (key_seq_len == 0) {
         return;
     } else {
-        int log2_elements = log2_ceil(softmax_elements);
+        int log2_elements = log2_ceil(key_seq_len);
         const int next_power_of_two = 1 << log2_elements;
-        int seq_len = softmax_elements;
-        int batch_count = batches * attn_heads * seq_len;
+        int batch_count = batches * attn_heads * query_seq_len;

         // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
         int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

@@ -303,58 +300,58 @@ void dispatch_scaled_masked_softmax_forward(
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
-        TORCH_INTERNAL_ASSERT(seq_len % batches_per_block == 0);
-        dim3 blocks(seq_len/batches_per_block, attn_heads, batches);
+        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
+        dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
         dim3 threads(warp_size, warps_per_block, 1);
         // Launch code would be more elegant if C++ supported FOR CONSTEXPR
         switch (log2_elements) {
             case 0: // 1
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             [cases 1 ("// 2") through 11 ("// 2048") receive the same one-line
             change: the launch arguments "softmax_elements_stride, softmax_elements"
            are replaced by "key_seq_len"]
             default:
                 break;

@@ -368,19 +365,18 @@ void dispatch_scaled_masked_softmax_backward(
     input_t *grad,
     const input_t *output,
     const acc_t scale,
-    int softmax_elements,
-    int softmax_elements_stride,
+    int query_seq_len,
+    int key_seq_len,
     int batches,
     int attn_heads)
 {
-    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
-    if (softmax_elements == 0) {
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
+    if (key_seq_len == 0) {
         return;
     } else {
-        int log2_elements = log2_ceil(softmax_elements);
+        int log2_elements = log2_ceil(key_seq_len);
         const int next_power_of_two = 1 << log2_elements;
-        int seq_len = softmax_elements;
-        int batch_count = batches * attn_heads * seq_len;
+        int batch_count = batches * attn_heads * query_seq_len;

         // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
         int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

@@ -399,51 +395,51 @@ void dispatch_scaled_masked_softmax_backward(
         switch (log2_elements) {
             case 0: // 1
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             [cases 1 ("// 2") through 11 ("// 2048") receive the same one-line
             change: "softmax_elements_stride, softmax_elements" becomes "key_seq_len"]
             default:
                 break;
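The renamed parameters make the row decomposition explicit: the softmax runs over batches * attn_heads * query_seq_len independent rows of length key_seq_len, so the score matrix no longer has to be square. The Python sketch below restates the launch geometry; batch_count, the query_seq_len divisibility assert and the grid shape mirror the diff, while threads_per_block=128 and the batches_per_warp rule are assumptions about how such warp softmax kernels are usually sized, not values read from this file.

# Illustrative sketch (plain Python, not the CUDA kernel) of the forward
# launch geometry implied by query_seq_len / key_seq_len.
import math

def forward_launch_shape(batches, attn_heads, query_seq_len, key_seq_len,
                         warp_size_max=32, threads_per_block=128):
    log2_elements = math.ceil(math.log2(key_seq_len))
    next_power_of_two = 1 << log2_elements
    batch_count = batches * attn_heads * query_seq_len   # rows to softmax
    warp_size = min(next_power_of_two, warp_size_max)
    # assumed rule: a warp handles 2 rows when a row fits in <= 128 elements
    batches_per_warp = 2 if next_power_of_two <= 128 else 1
    warps_per_block = threads_per_block // warp_size
    batches_per_block = warps_per_block * batches_per_warp
    assert query_seq_len % batches_per_block == 0
    blocks = (query_seq_len // batches_per_block, attn_heads, batches)
    threads = (warp_size, warps_per_block, 1)
    return batch_count, blocks, threads

print(forward_launch_shape(batches=4, attn_heads=16,
                           query_seq_len=512, key_seq_len=1024))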
megatron/fused_kernels/scaled_masked_softmax_cuda.cu

@@ -37,17 +37,19 @@ torch::Tensor fwd_cuda(
     const int batches = input.size(0);
     const int pad_batches = mask.size(0);
     const int attn_heads = input.size(1);
-    const int seq_len = input.size(2);
-    TORCH_INTERNAL_ASSERT(seq_len <= 2048);
+    const int query_seq_len = input.size(2);
+    const int key_seq_len = input.size(3);
+    TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
+    TORCH_INTERNAL_ASSERT(query_seq_len > 1);
     TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
     TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
-    TORCH_INTERNAL_ASSERT(mask.size(2) == seq_len);
-    TORCH_INTERNAL_ASSERT(mask.size(3) == seq_len);
+    TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
+    TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

     // Output
     auto act_options = input.options().requires_grad(false);
     torch::Tensor softmax_results =
-        torch::empty({batches, attn_heads, seq_len, seq_len}, act_options);
+        torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

     // Softmax Intermediate Result Ptr
     void* input_ptr = static_cast<void*>(input.data_ptr());

@@ -59,8 +61,8 @@ torch::Tensor fwd_cuda(
         reinterpret_cast<const half*>(input_ptr),
         reinterpret_cast<const uint8_t*>(mask_ptr),
         scale_factor,
-        seq_len,
-        seq_len,
+        query_seq_len,
+        key_seq_len,
         batches,
         attn_heads,
         pad_batches);

@@ -78,8 +80,8 @@ torch::Tensor bwd_cuda(
     //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
     const int batches = output_grads.size(0);
     const int attn_heads = output_grads.size(1);
-    const int seq_len = output_grads.size(2);
-    TORCH_INTERNAL_ASSERT(output_grads.size(2) == output_grads.size(3));
+    const int query_seq_len = output_grads.size(2);
+    const int key_seq_len = output_grads.size(3);

     void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

@@ -89,8 +91,8 @@ torch::Tensor bwd_cuda(
         reinterpret_cast<half*>(output_grads_ptr),
         reinterpret_cast<half const*>(softmax_results.data_ptr()),
         scale_factor,
-        seq_len,
-        seq_len,
+        query_seq_len,
+        key_seq_len,
         batches,
         attn_heads);
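For reference, this is the computation the fused forward pass performs, restated in plain PyTorch so the newly rectangular shapes are visible. It is a sketch, not the fused kernel: the -10000.0 fill value and the boolean mask convention are assumptions about the surrounding mask function, not read from this file.

# A PyTorch reference of scale + mask + softmax over [b, np, sq, sk] scores
# with sq != sk allowed, and a padding mask broadcast over heads (and
# optionally over the batch, matching the pad_batches == 1 case).
import torch

def scaled_masked_softmax_reference(scores, mask, scale):
    # scores: [b, np, sq, sk] fp16, mask: [1 or b, 1, sq, sk] uint8/bool
    masked = scores.float() * scale
    masked = masked.masked_fill(mask.bool(), -10000.0)   # assumed fill value
    return torch.softmax(masked, dim=-1).to(scores.dtype)

b, np_, sq, sk = 2, 4, 8, 16          # rectangular: query len != key len
scores = torch.randn(b, np_, sq, sk, dtype=torch.float16)
mask = torch.zeros(1, 1, sq, sk, dtype=torch.uint8)
probs = scaled_masked_softmax_reference(scores, mask, scale=1.0)
print(probs.shape)   # torch.Size([2, 4, 8, 16])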
megatron/model/enums.py  (new file, 0 → 100644)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum


class LayerType(enum.Enum):
    encoder = 1
    decoder = 2


class AttnType(enum.Enum):
    self_attn = 1
    cross_attn = 2


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
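A small usage sketch for the new enums, assuming Megatron-LM is on PYTHONPATH; the branch mirrors how the rest of this commit switches behaviour on layer type, attention type and mask type (illustrative only, not library code).

from megatron.model.enums import LayerType, AttnType, AttnMaskType

def describe(layer_type, attn_type, mask_type):
    if layer_type == LayerType.decoder and attn_type == AttnType.cross_attn:
        return 'decoder cross-attention over encoder output'
    if mask_type == AttnMaskType.causal:
        return 'self-attention with a causal (upper-triangular) mask'
    return 'self-attention with a padding mask'

print(describe(LayerType.decoder, AttnType.cross_attn, AttnMaskType.padding))
print(describe(LayerType.encoder, AttnType.self_attn, AttnMaskType.causal))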
megatron/model/fused_softmax.py

@@ -14,8 +14,10 @@
 # limitations under the License.

 import torch

+from megatron.model.enums import AttnMaskType
+

-class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
+class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     """
     Fused operation which performs following three operations in sequence
     1. Scale the tensor.

@@ -43,7 +45,8 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
                                                    scale_t[0])
         return input_grads, None

-class ScaledMaskedSoftmax(torch.autograd.Function) :
+
+class ScaledMaskedSoftmax(torch.autograd.Function):
     """
     Fused operation which performs following three operations in sequence
     1. Scale the tensor.

@@ -71,24 +74,25 @@ class ScaledMaskedSoftmax(torch.autograd.Function) :
                                                    scale_t[0])
         return input_grads, None, None

+
 class FusedScaleMaskSoftmax(torch.nn.Module):
     """
     fused operation: scaling + mask + softmax
     Arguments:
         input_in_fp16: flag to indicate if input in fp16 data format.
-        upper_triang_mask: if true, apply upper triangular masking.
-                           (used in gpt family networks)
+        attn_mask_type: attention mask type (pad or causal)
         mask_func: mask function to be applied.
         softmax_in_fp32: if true, softmax in performed at fp32 precision.
         scale: scaling factor used in input tensor scaling.
     """
-    def __init__(self, input_in_fp16, upper_triang_mask_fusion,
-                 general_mask_fusion, mask_func, softmax_in_fp32, scale):
+    def __init__(self, input_in_fp16, attn_mask_type,
+                 scaled_masked_softmax_fusion, mask_func,
+                 softmax_in_fp32, scale):
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
-        self.upper_triang_mask_fusion = upper_triang_mask_fusion
-        self.general_mask_fusion = general_mask_fusion
+        self.attn_mask_type = attn_mask_type
+        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
         self.mask_func = mask_func
         self.softmax_in_fp32 = softmax_in_fp32
         self.scale = scale

@@ -97,20 +101,26 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
             'softmax should be in fp32 when scaled'

     def forward(self, input, mask):
-        # [b, np, s, s]
+        # [b, np, sq, sk]
         data_size = input.size()
+        query_seq_len = data_size[-2]
+        key_seq_len = data_size[-1]
         assert input.dim() == 4

         # invoke custom kernel
-        if self.input_in_fp16 and data_size[-1] <= 2048 and \
-            (self.upper_triang_mask_fusion or self.general_mask_fusion) and \
-            input.size()[2] == input.size()[3]:
+        if self.input_in_fp16 and key_seq_len <= 2048 and \
+            query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0
-            if self.upper_triang_mask_fusion:
-                input = input.view(-1, data_size[2], data_size[3])
+            if self.attn_mask_type == AttnMaskType.causal:
+                assert query_seq_len == key_seq_len, \
+                    "causal mask is only for self attention"
+                input = input.view(-1, query_seq_len, key_seq_len)
                 probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
                 probs = probs.view(*data_size)
             else:
+                assert self.attn_mask_type == AttnMaskType.padding
                 probs = ScaledMaskedSoftmax.apply(input, mask, scale)
         else:
             if self.input_in_fp16 and self.softmax_in_fp32:
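The dispatch decision in the new FusedScaleMaskSoftmax.forward is worth reading in isolation: the fused path requires fp16 input, fusion enabled, a key length of at most 2048, and a query length divisible by 4, and the causal path additionally requires square scores. A standalone restatement of that decision (a sketch, with the enum defined inline so it runs on its own):

import enum

class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2

def pick_softmax_path(input_in_fp16, fusion_enabled, attn_mask_type,
                      query_seq_len, key_seq_len):
    if not (input_in_fp16 and fusion_enabled):
        return None                              # fall back to torch softmax
    if key_seq_len > 2048 or query_seq_len % 4 != 0:
        return None                              # kernel limits from the diff
    if attn_mask_type == AttnMaskType.causal:
        # the causal path reuses the upper-triangular kernel, which only
        # supports square scores (self attention)
        assert query_seq_len == key_seq_len, \
            "causal mask is only for self attention"
        return 'ScaledUpperTriangMaskedSoftmax'
    return 'ScaledMaskedSoftmax'

print(pick_softmax_path(True, True, AttnMaskType.causal, 1024, 1024))
print(pick_softmax_path(True, True, AttnMaskType.padding, 512, 1024))
print(pick_softmax_path(False, True, AttnMaskType.padding, 512, 1024))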
megatron/model/gpt2_model.py

@@ -21,6 +21,7 @@ from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
+from .enums import AttnMaskType
 from .language_model import parallel_lm_logits
 from .language_model import get_language_model
 from .utils import init_method_normal

@@ -75,6 +76,7 @@ class GPT2ModelBase(MegatronModule):
             attention_mask_func=gpt2_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
+            self_attn_mask_type=AttnMaskType.causal,
             init_method=init_method_normal(args.init_method_std),
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
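GPT-2 now requests AttnMaskType.causal explicitly instead of relying on an upper-triangular fusion flag. The short PyTorch sketch below (illustrative, not Megatron code) shows the two mask shapes that distinction refers to: a causal mask hides every key position after the query position, while a padding mask only hides key positions past each sample's valid length.

import torch

sq = sk = 5
causal_mask = torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=1)
print(causal_mask.int())   # 1 marks positions that get masked out

lengths = torch.tensor([3, 5])                        # per-sample valid lengths
padding_mask = torch.arange(sk)[None, :] >= lengths[:, None]
print(padding_mask.int())  # hides keys past each sequence's length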
megatron/model/language_model.py

@@ -21,6 +21,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
+from megatron.model.enums import LayerType, AttnMaskType
 from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal, scaled_init_method_normal

@@ -43,7 +44,9 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,

 def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
-                       init_method=None, scaled_init_method=None):
+                       add_decoder=False,
+                       init_method=None, scaled_init_method=None,
+                       self_attn_mask_type=AttnMaskType.padding):
     """Build language model and return along with the key to save."""
     args = get_args()

@@ -51,7 +54,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
         init_method = init_method_normal(args.init_method_std)

     if scaled_init_method is None:
         scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                        args.num_layers)

     # Language model.
     args = [attention_mask_func, init_method, scaled_init_method]

@@ -60,6 +64,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
         cls = TransformerLanguageModel
         kwargs['num_tokentypes'] = num_tokentypes
+        kwargs['self_attn_mask_type'] = self_attn_mask_type
+        kwargs['add_decoder'] = add_decoder
         kwargs['add_pooler'] = add_pooler
     elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
         cls = TransformerLanguageModelFirstStage

@@ -186,8 +192,6 @@ class Embedding(MegatronModule):
         if tokentype_ids is not None:
             assert self.tokentype_embeddings is not None
             embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
-        else:
-            assert self.tokentype_embeddings is None

         # Dropout.
         embeddings = self.embedding_dropout(embeddings)

@@ -281,6 +285,8 @@ class TransformerLanguageModelBase(MegatronModule):
                  init_method,
                  output_layer_init_method,
                  num_tokentypes=0,
+                 self_attn_mask_type=AttnMaskType.padding,
+                 add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModelBase, self).__init__()
         args = get_args()

@@ -288,6 +294,8 @@ class TransformerLanguageModelBase(MegatronModule):
         self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
         self.init_method = init_method
+        self.self_attn_mask_type = self_attn_mask_type
+        self.add_decoder = add_decoder
         self.add_pooler = add_pooler

         # Embeddings.

@@ -301,41 +309,87 @@ class TransformerLanguageModelBase(MegatronModule):
             self._embedding_key = 'embedding'

         # Transformer.
-        self.transformer = ParallelTransformer(
-            attention_mask_func, self.init_method,
-            output_layer_init_method)
-        self._transformer_key = 'transformer'
+        self.encoder = ParallelTransformer(
+            attention_mask_func,
+            self.init_method,
+            output_layer_init_method,
+            self_attn_mask_type=self_attn_mask_type)
+        self._encoder_key = 'encoder'

-        # Pooler.
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            self.pooler = Pooler(self.hidden_size, self.init_method)
-            self._pooler_key = 'pooler'
+        # assuming pooler and decoder are in the last stage
+        # of the pipeline(to be revised)
+        if mpu.is_pipeline_last_stage():
+            # decoder
+            if self.add_decoder:
+                self.decoder = ParallelTransformer(
+                    attention_mask_func,
+                    self.init_method,
+                    output_layer_init_method,
+                    layer_type=LayerType.decoder,
+                    self_attn_mask_type=AttnMaskType.causal)
+                self._decoder_key = 'decoder'

-    def forward(self, language_model_input, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
+            # Pooler.
+            if self.add_pooler:
+                self.pooler = Pooler(self.hidden_size, self.init_method)
+                self._pooler_key = 'pooler'
+
+    def forward(self, enc_language_model_input, enc_attention_mask,
+                dec_language_model_input=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None, tokentype_ids=None,
+                layer_past=None, get_key_value=False,
+                pooling_sequence_index=0,
+                enc_hidden_states=None, output_enc_hidden=False):

         # Embeddings.
         if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = language_model_input
+            (input_ids, position_ids) = enc_language_model_input
             embedding_output = self.embedding(input_ids, position_ids,
                                               tokentype_ids=tokentype_ids)
-            transformer_input = embedding_output
+            encoder_input = embedding_output
         else:
-            transformer_input = language_model_input
+            encoder_input = enc_language_model_input

-        # Transformer.
-        transformer_output = self.transformer(transformer_input,
-                                              attention_mask,
-                                              layer_past=layer_past,
-                                              get_key_value=get_key_value)
+        # encoder.
+        if enc_hidden_states is None:
+            encoder_output = self.encoder(encoder_input,
+                                          enc_attention_mask,
+                                          layer_past=layer_past,
+                                          get_key_value=get_key_value)
+        else:
+            encoder_output = enc_hidden_states.to(encoder_input.dtype)

-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            pooled_output = self.pooler(transformer_output,
-                                        pooling_sequence_index)
-            return transformer_output, pooled_output
+        if mpu.is_pipeline_last_stage():
+            if self.add_pooler:
+                pooled_output = self.pooler(encoder_output,
+                                            pooling_sequence_index)

-        return transformer_output
+            # output_enc_hidden refers to when we just need the encoder's
+            # output. For example, it is helpful to compute
+            # similarity between two sequences by average pooling
+            if not self.add_decoder or output_enc_hidden:
+                if self.add_pooler:
+                    return encoder_output, pooled_output
+                else:
+                    return encoder_output
+
+            # Decoder Embedding
+            (dec_input_ids, dec_position_ids) = dec_language_model_input
+            dec_embedding_output = self.embedding(dec_input_ids,
+                                                  dec_position_ids)
+            # decoder
+            decoder_output = self.decoder(dec_embedding_output,
+                                          dec_attn_mask,
+                                          layer_past=layer_past,
+                                          get_key_value=get_key_value,
+                                          encoder_output=encoder_output,
+                                          enc_dec_attn_mask=enc_dec_attn_mask)
+
+            if self.add_pooler:
+                return decoder_output, encoder_output, pooled_output
+            else:
+                return decoder_output, encoder_output
+
+        return encoder_output

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):

@@ -346,13 +400,18 @@ class TransformerLanguageModelBase(MegatronModule):
             state_dict_[self._embedding_key] \
                 = self.embedding.state_dict_for_save_checkpoint(
                     destination, prefix, keep_vars)
-        state_dict_[self._transformer_key] \
-            = self.transformer.state_dict_for_save_checkpoint(
+        state_dict_[self._encoder_key] \
+            = self.encoder.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            state_dict_[self._pooler_key] \
-                = self.pooler.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+        if mpu.is_pipeline_last_stage():
+            if self.add_pooler:
+                state_dict_[self._pooler_key] \
+                    = self.pooler.state_dict_for_save_checkpoint(
+                        destination, prefix, keep_vars)
+            if self.add_decoder:
+                state_dict_[self._decoder_key] \
+                    = self.decoder.state_dict_for_save_checkpoint(
+                        destination, prefix, keep_vars)

         return state_dict_

@@ -371,23 +430,44 @@ class TransformerLanguageModelBase(MegatronModule):
                     state_dict_[key] = state_dict[key]
             self.embedding.load_state_dict(state_dict_, strict=strict)

-        # Transformer.
-        if self._transformer_key in state_dict:
-            state_dict_ = state_dict[self._transformer_key]
+        # Encoder.
+        if self._encoder_key in state_dict:
+            state_dict_ = state_dict[self._encoder_key]
+        # for backward compatibility.
+        elif 'transformer' in state_dict:
+            state_dict_ = state_dict['transformer']
         else:
             # for backward compatibility.
             state_dict_ = {}
             for key in state_dict.keys():
-                if 'transformer.' in key:
-                    state_dict_[key.split('transformer.')[1]] = state_dict[key]
-        self.transformer.load_state_dict(state_dict_, strict=strict)
+                if 'encoder.' in key:
+                    state_dict_[key.split('encoder.')[1]] = state_dict[key]

-        # Pooler.
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            assert 'pooler' in state_dict, \
-                'could not find data for pooler in the checkpoint'
-            self.pooler.load_state_dict(state_dict[self._pooler_key],
-                                        strict=strict)
+        # for backward compatibility.
+        state_dict_self_attention = {}
+        for key in state_dict_.keys():
+            if '.attention.' in key:
+                state_dict_self_attention[
+                    key.replace(".attention.", ".self_attention.")] \
+                    = state_dict_[key]
+            else:
+                state_dict_self_attention[key] = state_dict_[key]
+        state_dict_ = state_dict_self_attention
+
+        self.encoder.load_state_dict(state_dict_, strict=strict)
+
+        if mpu.is_pipeline_last_stage():
+            # pooler
+            if self.add_pooler:
+                assert 'pooler' in state_dict, \
+                    'could not find data for pooler in the checkpoint'
+                self.pooler.load_state_dict(state_dict[self._pooler_key],
+                                            strict=strict)
+            # decoder
+            if self.add_decoder:
+                assert 'decoder' in state_dict, \
+                    'could not find data for pooler in the checkpoint'
+                self.decoder.load_state_dict(state_dict[self._decoder_key],
+                                             strict=strict)


 class TransformerLanguageModel(TransformerLanguageModelBase):

@@ -400,24 +480,35 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
                  init_method,
                  output_layer_init_method,
                  num_tokentypes=0,
+                 self_attn_mask_type=AttnMaskType.padding,
+                 add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModel, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
             num_tokentypes=num_tokentypes,
+            self_attn_mask_type=self_attn_mask_type,
+            add_decoder=add_decoder,
             add_pooler=add_pooler)

-    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
+    def forward(self, enc_input_ids, enc_position_ids, enc_attention_mask,
+                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None, tokentype_ids=None,
+                layer_past=None, get_key_value=False,
+                pooling_sequence_index=0,
+                enc_hidden_states=None, output_enc_hidden=False):
         return super(TransformerLanguageModel, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
+            (enc_input_ids, enc_position_ids),
+            enc_attention_mask,
+            dec_language_model_input=(dec_input_ids, dec_position_ids),
+            dec_attn_mask=dec_attn_mask,
+            enc_dec_attn_mask=enc_dec_attn_mask,
             tokentype_ids=tokentype_ids,
             layer_past=layer_past,
             get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index
+            pooling_sequence_index=pooling_sequence_index,
+            enc_hidden_states=enc_hidden_states,
+            output_enc_hidden=output_enc_hidden
         )

@@ -430,12 +521,14 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
-                 num_tokentypes=0):
+                 num_tokentypes=0,
+                 self_attn_mask_type=AttnMaskType.padding):
         super(TransformerLanguageModelFirstStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
-            num_tokentypes=num_tokentypes)
+            num_tokentypes=num_tokentypes,
+            self_attn_mask_type=self_attn_mask_type)

     def forward(self, input_ids, position_ids, attention_mask,
                 tokentype_ids=None, layer_past=None, get_key_value=False):

@@ -456,11 +549,13 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
     def __init__(self,
                  attention_mask_func,
                  init_method,
-                 output_layer_init_method):
+                 output_layer_init_method,
+                 self_attn_mask_type=AttnMaskType.padding):
         super(TransformerLanguageModelIntermediateStage, self).__init__(
             attention_mask_func,
             init_method,
-            output_layer_init_method)
+            output_layer_init_method,
+            self_attn_mask_type=self_attn_mask_type)

     def forward(self, hidden_states, attention_mask,
                 layer_past=None, get_key_value=False):

@@ -481,20 +576,31 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
+                 self_attn_mask_type=AttnMaskType.padding,
+                 add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModelLastStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
+            self_attn_mask_type=AttnMaskType.padding,
+            add_decoder=add_decoder,
             add_pooler=add_pooler)

-    def forward(self, hidden_states, attention_mask,
-                layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
+    def forward(self, hidden_states, enc_attention_mask,
+                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None,
+                layer_past=None, get_key_value=False,
+                pooling_sequence_index=0,
+                enc_hidden_states=None, output_enc_hidden=False):
         return super(TransformerLanguageModelLastStage, self).forward(
             hidden_states,
-            attention_mask,
+            enc_attention_mask,
+            dec_language_input=(dec_input_ids, dec_position_ids),
+            dec_attn_mask=dec_attn_mask,
+            enc_dec_attn_mask=enc_dec_attn_mask,
             layer_past=layer_past,
             get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index
+            pooling_sequence_index=pooling_sequence_index,
+            enc_hidden_states=enc_hidden_states,
+            ouput_enc_hidden=output_enc_hidden
         )
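A condensed sketch (illustrative toy callables, not the Megatron classes) of the control flow the new TransformerLanguageModelBase.forward implements on the last pipeline stage: run the encoder (or reuse enc_hidden_states), optionally stop at the encoder output when output_enc_hidden is set or no decoder was built, otherwise embed the decoder tokens and run the decoder with cross-attention over the encoder output.

def language_model_forward(encoder, decoder, embed,
                           enc_input, enc_mask,
                           dec_input=None, dec_mask=None, enc_dec_mask=None,
                           add_decoder=False, enc_hidden_states=None,
                           output_enc_hidden=False):
    encoder_output = (encoder(embed(enc_input), enc_mask)
                      if enc_hidden_states is None else enc_hidden_states)
    if not add_decoder or output_enc_hidden:
        return encoder_output
    decoder_output = decoder(embed(dec_input), dec_mask,
                             encoder_output=encoder_output,
                             enc_dec_attn_mask=enc_dec_mask)
    return decoder_output, encoder_output

# e.g. with toy callables standing in for the embedding and transformers:
out = language_model_forward(lambda x, m: x + 1, lambda x, m, **kw: x + 2,
                             lambda ids: ids, enc_input=0, enc_mask=None,
                             dec_input=10, dec_mask=None, add_decoder=True)
print(out)   # (12, 1)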
megatron/model/transformer.py
View file @
4b506832
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
# limitations under the License.
# limitations under the License.
"""Transformer."""
"""Transformer."""
import
enum
import
math
import
math
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -23,6 +23,7 @@ from megatron import get_args
...
@@ -23,6 +23,7 @@ from megatron import get_args
from
megatron
import
mpu
from
megatron
import
mpu
from
.module
import
MegatronModule
from
.module
import
MegatronModule
from
megatron.checkpointing
import
get_checkpoint_version
from
megatron.checkpointing
import
get_checkpoint_version
from
megatron.model.enums
import
AttnMaskType
,
LayerType
,
AttnType
from
megatron.model
import
import_layernorm
from
megatron.model
import
import_layernorm
from
megatron.model.fused_softmax
import
FusedScaleMaskSoftmax
from
megatron.model.fused_softmax
import
FusedScaleMaskSoftmax
from
megatron.model.fused_bias_gelu
import
bias_gelu_impl
from
megatron.model.fused_bias_gelu
import
bias_gelu_impl
...
@@ -71,7 +72,7 @@ class ParallelMLP(MegatronModule):
...
@@ -71,7 +72,7 @@ class ParallelMLP(MegatronModule):
# Project to 4h.
# Project to 4h.
self
.
dense_h_to_4h
=
mpu
.
ColumnParallelLinear
(
self
.
dense_h_to_4h
=
mpu
.
ColumnParallelLinear
(
args
.
hidden_size
,
args
.
hidden_size
,
4
*
args
.
hidden_size
,
args
.
ffn_
hidden_size
,
gather_output
=
False
,
gather_output
=
False
,
init_method
=
init_method
,
init_method
=
init_method
,
skip_bias_add
=
True
)
skip_bias_add
=
True
)
...
@@ -85,13 +86,12 @@ class ParallelMLP(MegatronModule):
...
@@ -85,13 +86,12 @@ class ParallelMLP(MegatronModule):
# Project back to h.
# Project back to h.
self
.
dense_4h_to_h
=
mpu
.
RowParallelLinear
(
self
.
dense_4h_to_h
=
mpu
.
RowParallelLinear
(
4
*
args
.
hidden_size
,
args
.
ffn_
hidden_size
,
args
.
hidden_size
,
args
.
hidden_size
,
input_is_parallel
=
True
,
input_is_parallel
=
True
,
init_method
=
output_layer_init_method
,
init_method
=
output_layer_init_method
,
skip_bias_add
=
True
)
skip_bias_add
=
True
)
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
):
# [s, b, 4hp]
# [s, b, 4hp]
...
@@ -109,7 +109,7 @@ class ParallelMLP(MegatronModule):
         return output, output_bias


-class ParallelSelfAttention(MegatronModule):
+class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.

     Self-attention layer takes input with size [b, s, h]
...
@@ -117,8 +117,10 @@ class ParallelSelfAttention(MegatronModule):
     """

     def __init__(self, attention_mask_func, init_method,
-                 output_layer_init_method, layer_number):
-        super(ParallelSelfAttention, self).__init__()
+                 output_layer_init_method, layer_number,
+                 attention_type=AttnType.self_attn,
+                 attn_mask_type=AttnMaskType.padding):
+        super(ParallelAttention, self).__init__()
         args = get_args()
         self.fp16 = args.fp16
...
@@ -128,20 +130,38 @@ class ParallelSelfAttention(MegatronModule):
         if self.apply_query_key_layer_scaling:
             self.attention_softmax_in_fp32 = True
         self.layer_number = max(1, layer_number)
+        self.attention_type = attention_type
+        self.attn_mask_type = attn_mask_type
+
+        projection_size = args.kv_channels * args.num_attention_heads

         # Per attention head and per partition values.
         world_size = mpu.get_tensor_model_parallel_world_size()
-        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
+        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                     world_size)
         self.hidden_size_per_attention_head = mpu.divide(
-            args.hidden_size, args.num_attention_heads)
+            projection_size, args.num_attention_heads)
         self.num_attention_heads_per_partition = mpu.divide(
             args.num_attention_heads, world_size)

         # Strided linear layer.
-        self.query_key_value = mpu.ColumnParallelLinear(
-            args.hidden_size,
-            3 * args.hidden_size,
-            gather_output=False,
-            init_method=init_method)
+        if attention_type == AttnType.self_attn:
+            self.query_key_value = mpu.ColumnParallelLinear(
+                args.hidden_size,
+                3 * projection_size,
+                gather_output=False,
+                init_method=init_method)
+        else:
+            assert attention_type == AttnType.cross_attn
+            self.query = mpu.ColumnParallelLinear(
+                args.hidden_size,
+                projection_size,
+                gather_output=False,
+                init_method=init_method)
+            self.key_value = mpu.ColumnParallelLinear(
+                args.hidden_size,
+                2 * projection_size,
+                gather_output=False,
+                init_method=init_method)
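To make the new projection split concrete, here is a rough non-parallel sketch (plain nn.Linear instead of mpu.ColumnParallelLinear, toy sizes): self-attention keeps one fused Q/K/V projection of width 3 * projection_size, while cross-attention projects queries from the decoder states and keys/values from the encoder output.

import torch

hidden_size, kv_channels, num_heads = 1024, 64, 16
projection_size = kv_channels * num_heads        # decoupled from hidden_size

# Self-attention: one fused projection over the same sequence.
query_key_value = torch.nn.Linear(hidden_size, 3 * projection_size)
# Cross-attention: Q from the decoder states, K/V from the encoder output.
query = torch.nn.Linear(hidden_size, projection_size)
key_value = torch.nn.Linear(hidden_size, 2 * projection_size)

sq, sk, b = 8, 12, 2
decoder_states = torch.randn(sq, b, hidden_size)
encoder_output = torch.randn(sk, b, hidden_size)

q = query(decoder_states)                              # [sq, b, projection_size]
k, v = key_value(encoder_output).chunk(2, dim=-1)       # each [sk, b, projection_size]
print(q.shape, k.shape, v.shape)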
...
@@ -153,8 +173,8 @@ class ParallelSelfAttention(MegatronModule):
         self.scale_mask_softmax = FusedScaleMaskSoftmax(
             self.fp16,
-            args.scaled_upper_triang_masked_softmax_fusion,
-            args.scaled_masked_softmax_fusion,
+            self.attn_mask_type,
+            args.masked_softmax_fusion,
             self.attention_mask_func,
             self.attention_softmax_in_fp32,
             coeff)
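For intuition only, an unfused sketch of what the two mask regimes mean (this is not the fused CUDA kernel, and the AttnMaskType stand-in below only assumes padding/causal members, as the hunk suggests): a causal type applies an upper-triangular future mask, a padding type applies whatever mask the batch supplies, and both fill masked scores with -inf before the softmax.

import enum
import torch

class AttnMaskType(enum.Enum):   # stand-in for megatron.model.enums.AttnMaskType
    padding = 1
    causal = 2

def unfused_scale_mask_softmax(scores, mask_type, padding_mask=None, scale=1.0):
    # scores: [b, np, sq, sk]; True in a mask means "do not attend".
    scores = scores * scale
    if mask_type == AttnMaskType.causal:
        sq, sk = scores.shape[-2:]
        causal = torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(causal, float("-inf"))
    elif padding_mask is not None:
        scores = scores.masked_fill(padding_mask, float("-inf"))
    return torch.softmax(scores, dim=-1)

probs = unfused_scale_mask_softmax(torch.randn(2, 4, 5, 5), AttnMaskType.causal)
print(probs.shape)   # torch.Size([2, 4, 5, 5])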
...
@@ -166,14 +186,14 @@ class ParallelSelfAttention(MegatronModule):
         # Output.
         self.dense = mpu.RowParallelLinear(
-            args.hidden_size,
+            projection_size,
             args.hidden_size,
             input_is_parallel=True,
             init_method=output_layer_init_method,
             skip_bias_add=True)

     def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
-        input_shape = mixed_layer.size();
+        input_shape = mixed_layer.size()
         if num_splits_first:
             """[s, b, num_splits * np * hn]
             -->(view) [s, b, num_splits, np, hn]
...
@@ -203,13 +223,14 @@ class ParallelSelfAttention(MegatronModule):
         return mixed_layer

     def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False):
+                get_key_value=False, encoder_output=None):
         # hidden_states: [sq, b, h]

         # =====================
         # Query, Key, and Value
         # =====================

+        if self.attention_type == AttnType.self_attn:
             # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
             mixed_x_layer, _ = self.query_key_value(hidden_states)
...
@@ -232,6 +253,36 @@ class ParallelSelfAttention(MegatronModule):
             (query_layer,
              key_layer,
              value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
+        else:
+            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
+            mixed_kv_layer, _ = self.key_value(encoder_output)
+
+            checkpoint_version = get_checkpoint_version()
+            if checkpoint_version is not None:
+                if checkpoint_version == 0:
+                    # [s, b, (2 * np * hn)] --> [s, b, (np * 2 * hn)]
+                    mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
+                elif checkpoint_version == 1.0:
+                    # [s, b, (np * hn * 2)] --> [s, b, (np * 2 * hn)]
+                    mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, False)
+
+            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
+            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
+                (self.num_attention_heads_per_partition,
+                 2 * self.hidden_size_per_attention_head)
+            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
+
+            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
+            (key_layer,
+             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
+
+            # Attention head [sq, b, h] --> [sq, b, hp]
+            query_layer, _ = self.query(hidden_states)
+            # [sq, b, hp] --> [sq, b, np, hn]
+            new_tensor_shape = query_layer.size()[:-1] + \
+                (self.num_attention_heads_per_partition,
+                 self.hidden_size_per_attention_head)
+            query_layer = query_layer.view(*new_tensor_shape)

         # ==================================
         # Adjust key and value for inference
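A small self-contained sketch of the added shape handling, with toy sizes and torch.chunk standing in for mpu.split_tensor_along_last_dim: the fused key/value projection of the encoder output is viewed per head and then split into separate key and value tensors.

import torch

sk, b, np_, hn = 12, 2, 16, 64                  # key length, batch, heads per partition, head dim
mixed_kv = torch.randn(sk, b, np_ * 2 * hn)      # output of the fused key_value projection

# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
mixed_kv = mixed_kv.view(sk, b, np_, 2 * hn)

# [sk, b, np, 2 * hn] --> 2 x [sk, b, np, hn]
key_layer, value_layer = mixed_kv.chunk(2, dim=-1)
print(key_layer.shape, value_layer.shape)        # [12, 2, 16, 64] each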
...
@@ -246,7 +297,6 @@ class ParallelSelfAttention(MegatronModule):
         if get_key_value:
             present = (key_layer, value_layer)
-
         # ===================================
         # Raw attention scores. [b, np, s, s]
         # ===================================
...
@@ -260,6 +310,7 @@ class ParallelSelfAttention(MegatronModule):
         # [sq, b, np, hn] -> [sq, b * np, hn]
         query_layer = query_layer.view(output_size[2],
                                        output_size[0] * output_size[1], -1)
+        # [sk, b, np, hn] -> [sk, b * np, hn]
         key_layer = key_layer.view(output_size[3],
                                    output_size[0] * output_size[1], -1)
...
@@ -272,15 +323,15 @@ class ParallelSelfAttention(MegatronModule):
             device=torch.cuda.current_device())

         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
             matmul_result,
             query_layer.transpose(0, 1),   # [b * np, sq, hn]
-            key_layer.transpose(0, 1).transpose(1, 2),  #[b * np, hn, sk]
+            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
             beta=0.0, alpha=(1.0 / self.norm_factor))

         # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)

         # ==================================================
         # Update attention mask for inference. [b, np, sq, sk]
         # ==================================================
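For reference, a minimal sketch of this score computation outside the model: batched Q·Kᵀ with torch.baddbmm into a preallocated buffer, scaled by 1/sqrt(head_dim) (the role norm_factor plays above); toy shapes, no tensor parallelism.

import math
import torch

b_np, sq, sk, hn = 8, 5, 7, 64            # b * np, query len, key len, head dim
query_layer = torch.randn(sq, b_np, hn)    # [sq, b*np, hn]
key_layer = torch.randn(sk, b_np, hn)      # [sk, b*np, hn]

matmul_result = torch.empty(b_np, sq, sk)  # preallocated output buffer
norm_factor = math.sqrt(hn)

# Raw attention scores: [b*np, sq, sk] = (1/norm) * Q @ K^T
matmul_result = torch.baddbmm(
    matmul_result,
    query_layer.transpose(0, 1),                    # [b*np, sq, hn]
    key_layer.transpose(0, 1).transpose(1, 2),      # [b*np, hn, sk]
    beta=0.0, alpha=(1.0 / norm_factor))
print(matmul_result.shape)                          # torch.Size([8, 5, 7])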
...
@@ -298,7 +349,6 @@ class ParallelSelfAttention(MegatronModule):
                     :attention_scores.size(3),
                     :attention_scores.size(3)]
-
         # ===========================
         # Attention probs and dropout
         # ===========================
...
@@ -312,7 +362,6 @@ class ParallelSelfAttention(MegatronModule):
         with mpu.get_cuda_rng_tracker().fork():
             attention_probs = self.attention_dropout(attention_probs)
-
         # =========================
         # Context layer. [sq, b, hp]
         # =========================
...
@@ -335,7 +384,7 @@ class ParallelSelfAttention(MegatronModule):
                                        output_size[2], -1)

         # matmul: [b * np, sq, hn]
         context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

         # change view [b, np, sq, hn]
         context_layer = context_layer.view(*output_size)
...
@@ -348,7 +397,6 @@ class ParallelSelfAttention(MegatronModule):
             (self.hidden_size_per_partition,)
         context_layer = context_layer.view(*new_context_layer_shape)
-
         # =================
         # Output. [sq, b, h]
         # =================
...
@@ -389,16 +437,19 @@ def bias_dropout_add_fused_inference(x, bias, residual, prob):

 class ParallelTransformerLayer(MegatronModule):
     """A single transformer layer.

-    Transformore layer takes input with size [b, s, h] and returns an
+    Transformer layer takes input with size [b, s, h] and returns an
     output of the same size.
     """

     def __init__(self, attention_mask_func, init_method,
-                 output_layer_init_method, layer_number):
+                 output_layer_init_method, layer_number,
+                 layer_type=LayerType.encoder,
+                 self_attn_mask_type=AttnMaskType.padding):
         args = get_args()

         super(ParallelTransformerLayer, self).__init__()
         self.layer_number = layer_number
+        self.layer_type = layer_type

         self.apply_residual_connection_post_layernorm \
             = args.apply_residual_connection_post_layernorm
...
@@ -410,30 +461,47 @@ class ParallelTransformerLayer(MegatronModule):
             eps=args.layernorm_epsilon)

         # Self attention.
-        self.attention = ParallelSelfAttention(attention_mask_func, init_method,
-                                               output_layer_init_method,
-                                               layer_number)
+        self.self_attention = ParallelAttention(attention_mask_func, init_method,
+                                                output_layer_init_method,
+                                                layer_number,
+                                                attention_type=AttnType.self_attn,
+                                                attn_mask_type=self_attn_mask_type)
         self.hidden_dropout = args.hidden_dropout
         self.bias_dropout_fusion = args.bias_dropout_fusion

-        # Layernorm on the input data.
+        # Layernorm on the attention output
         self.post_attention_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)

+        if self.layer_type == LayerType.decoder:
+            self.inter_attention = ParallelAttention(attention_mask_func, init_method,
+                                                     output_layer_init_method,
+                                                     layer_number,
+                                                     attention_type=AttnType.cross_attn)
+            # Layernorm on the attention output.
+            self.post_inter_attention_layernorm = LayerNorm(
+                args.hidden_size,
+                eps=args.layernorm_epsilon)
+
         # MLP
         self.mlp = ParallelMLP(init_method,
                                output_layer_init_method)

-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False):
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None, enc_dec_attn_mask=None,
+                layer_past=None, get_key_value=False):
         # hidden_states: [b, s, h]

-        # Layer norm at the begining of the transformer layer.
+        # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
         # Self attention.
         attention_output, attention_bias = \
-            self.attention(layernorm_output,
-                           attention_mask,
-                           layer_past=layer_past,
-                           get_key_value=get_key_value)
+            self.self_attention(layernorm_output,
+                                attention_mask,
+                                layer_past=layer_past,
+                                get_key_value=get_key_value)
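Putting the pieces together, a decoder-type layer built this way carries a self-attention block, a cross-attention (inter_attention) block over the encoder output, and an MLP, each preceded by its own layernorm. A hedged toy version with torch.nn.MultiheadAttention standing in for ParallelAttention (masks and dropout omitted, shapes [s, b, h]); this is a sketch of the wiring, not the actual Megatron module:

import torch
import torch.nn as nn

class ToyDecoderLayer(nn.Module):
    # Pre-layernorm ordering as in the diff:
    # LN -> self-attn -> add, LN -> cross-attn over encoder_output -> add, LN -> MLP -> add.
    def __init__(self, hidden_size=256, num_heads=8):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.self_attention = nn.MultiheadAttention(hidden_size, num_heads)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.inter_attention = nn.MultiheadAttention(hidden_size, num_heads)
        self.post_inter_attention_layernorm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size),
                                 nn.GELU(),
                                 nn.Linear(4 * hidden_size, hidden_size))

    def forward(self, hidden_states, encoder_output):   # [s, b, h], [sk, b, h]
        x = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + self.self_attention(x, x, x, need_weights=False)[0]
        x = self.post_attention_layernorm(hidden_states)
        hidden_states = hidden_states + self.inter_attention(
            x, encoder_output, encoder_output, need_weights=False)[0]
        x = self.post_inter_attention_layernorm(hidden_states)
        return hidden_states + self.mlp(x)

dec = torch.randn(6, 2, 256)
enc = torch.randn(10, 2, 256)
print(ToyDecoderLayer()(dec, enc).shape)   # torch.Size([6, 2, 256])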
...
@@ -459,7 +527,7 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             bias_dropout_add_func = get_bias_dropout_add(self.training)

-        #re-enable torch grad to enable fused optimization.
+        # re-enable torch grad to enable fused optimization.
         with torch.enable_grad():
             layernorm_input = bias_dropout_add_func(
                 attention_output,
...
@@ -470,6 +538,28 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)

+        if self.layer_type == LayerType.decoder:
+            attention_output, attention_bias = \
+                self.inter_attention(layernorm_output,
+                                     enc_dec_attn_mask,
+                                     encoder_output=encoder_output)
+            # residual connection
+            if self.apply_residual_connection_post_layernorm:
+                residual = layernorm_output
+            else:
+                residual = layernorm_input
+
+            # re-enable torch grad to enable fused optimization.
+            with torch.enable_grad():
+                layernorm_input = bias_dropout_add_func(
+                    attention_output,
+                    attention_bias.expand_as(residual),
+                    residual,
+                    self.hidden_dropout)
+
+            # Layer norm post the decoder attention
+            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
+
         # MLP.
         mlp_output, mlp_bias = self.mlp(layernorm_output)
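The added branch reuses the same residual pattern as the self-attention path: add the attention bias, apply dropout, then add the residual. An unfused equivalent of that bias_dropout_add step, for reference (assuming the fused/JIT variants compute the same expression):

import torch

def bias_dropout_add(x, bias, residual, prob, training):
    # Unfused form: dropout(x + bias) + residual
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + out

x = torch.randn(5, 2, 8)
bias = torch.randn(8)
residual = torch.randn(5, 2, 8)
print(bias_dropout_add(x, bias, residual, prob=0.1, training=True).shape)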
...
@@ -479,7 +569,7 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = layernorm_input

-        #re-enable torch grad to enable fused optimization.
+        # re-enable torch grad to enable fused optimization.
         with torch.enable_grad():
             output = bias_dropout_add_func(
                 mlp_output,
...
@@ -497,7 +587,9 @@ class ParallelTransformer(MegatronModule):
     """Transformer class."""

     def __init__(self, attention_mask_func,
-                 init_method, output_layer_init_method):
+                 init_method, output_layer_init_method,
+                 layer_type=LayerType.encoder,
+                 self_attn_mask_type=AttnMaskType.padding):
         super(ParallelTransformer, self).__init__()
         args = get_args()
...
@@ -516,7 +608,9 @@ class ParallelTransformer(MegatronModule):
         def build_layer(layer_number):
             return ParallelTransformerLayer(
                 attention_mask_func, init_method,
-                output_layer_init_method, layer_number)
+                output_layer_init_method, layer_number,
+                layer_type=layer_type,
+                self_attn_mask_type=self_attn_mask_type)
         offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
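As a side note on the unchanged offset line: with pipeline parallelism each stage builds only its own slice of layers but numbers them globally. A tiny illustration with hypothetical counts:

# Global 1-based layer numbers owned by each pipeline stage,
# mirroring: offset = pipeline_rank * num_layers_per_stage; build_layer(i + 1 + offset)
num_layers_per_stage = 4
for pipeline_rank in range(3):
    offset = pipeline_rank * num_layers_per_stage
    print(pipeline_rank, [i + 1 + offset for i in range(num_layers_per_stage)])
# 0 [1, 2, 3, 4]
# 1 [5, 6, 7, 8]
# 2 [9, 10, 11, 12]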
...
@@ -531,14 +625,18 @@ class ParallelTransformer(MegatronModule):
     def _get_layer(self, layer_number):
         return self.layers[layer_number]

-    def _checkpointed_forward(self, hidden_states, attention_mask):
+    def _checkpointed_forward(self, hidden_states, attention_mask,
+                              encoder_output, enc_dec_attn_mask):
         """Forward method with activation checkpointing."""
         def custom(start, end):
             def custom_forward(*inputs):
                 x_ = inputs[0]
+                attention_mask = inputs[1]
+                encoder_output = inputs[2]
+                enc_dec_attn_mask = inputs[3]
                 for index in range(start, end):
                     layer = self._get_layer(index)
-                    x_ = layer(x_, inputs[1])
+                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                 return x_
             return custom_forward
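A minimal sketch of the pattern being extended here, using torch.utils.checkpoint in place of mpu.checkpoint: every tensor the checkpointed segment needs (a stand-in scale below plays the role of the extra mask/encoder tensors) must be threaded through *inputs so it is available when the segment is recomputed during the backward pass.

import torch
from torch.utils.checkpoint import checkpoint

layers = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(4)])

def custom(start, end):
    def custom_forward(*inputs):
        x_, scale = inputs[0], inputs[1]   # extra inputs ride along explicitly
        for index in range(start, end):
            x_ = torch.relu(layers[index](x_)) * scale
        return x_
    return custom_forward

hidden = torch.randn(8, 64, requires_grad=True)
scale = torch.tensor(0.5)
out = checkpoint(custom(0, 2), hidden, scale)
out = checkpoint(custom(2, 4), out, scale)
out.sum().backward()
print(hidden.grad.shape)   # torch.Size([8, 64])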
...
@@ -548,13 +646,13 @@ class ParallelTransformer(MegatronModule):
         while l < self.num_layers:
             hidden_states = mpu.checkpoint(
                 custom(l, l + self.checkpoint_num_layers),
-                hidden_states, attention_mask)
+                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
             l += self.checkpoint_num_layers

         return hidden_states

     def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False):
+                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):

         # Checks.
         if layer_past is not None:
...
@@ -577,7 +675,9 @@ class ParallelTransformer(MegatronModule):
         if self.checkpoint_activations:
             hidden_states = self._checkpointed_forward(hidden_states,
-                                                       attention_mask)
+                                                       attention_mask,
+                                                       encoder_output,
+                                                       enc_dec_attn_mask)
         else:
             if get_key_value:
                 presents = []
...
@@ -588,6 +688,8 @@ class ParallelTransformer(MegatronModule):
                 past = layer_past[index]
                 hidden_states = layer(hidden_states,
                                       attention_mask,
+                                      encoder_output=encoder_output,
+                                      enc_dec_attn_mask=enc_dec_attn_mask,
                                       layer_past=past,
                                       get_key_value=get_key_value)
                 if get_key_value:
...
pretrain_gpt2.py View file @ 4b506832
...
@@ -141,5 +141,4 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):

 if __name__ == "__main__":

     pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
-             args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
-                            'scaled_upper_triang_masked_softmax_fusion': True})
+             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})