OpenDAS / Megatron-LM · Commits

Commit 4b506832, authored Jan 08, 2021 by Vijay Korthikanti

    decoder support in transformers

Parent: f5eac3d1

Showing 9 changed files with 520 additions and 261 deletions (+520 -261)
megatron/arguments.py                                  +31   -17
megatron/fused_kernels/scaled_masked_softmax.h         +43   -47
megatron/fused_kernels/scaled_masked_softmax_cuda.cu   +13   -11
megatron/model/enums.py                                +28    -0
megatron/model/fused_softmax.py                        +31   -21
megatron/model/gpt2_model.py                            +2    -0
megatron/model/language_model.py                      +172   -66
megatron/model/transformer.py                         +199   -97
pretrain_gpt2.py                                        +1    -2
megatron/arguments.py  (View file @ 4b506832)

@@ -164,6 +164,20 @@ def parse_args(extra_args_provider=None, defaults={},
         _check_arg_is_not_none(args, req_arg)

     # Checks.
+    if args.ffn_hidden_size is None:
+        args.ffn_hidden_size = 4 * args.hidden_size
+
+    if args.kv_channels is None:
+        assert args.hidden_size % args.num_attention_heads == 0
+        args.kv_channels = args.hidden_size // args.num_attention_heads
+
+    if args.seq_length is not None:
+        assert args.encoder_seq_length is None
+        args.encoder_seq_length = args.seq_length
+    else:
+        assert args.encoder_seq_length is not None
+        args.seq_length = args.encoder_seq_length
+
     assert args.hidden_size % args.num_attention_heads == 0
     if args.seq_length is not None:
         assert args.max_position_embeddings >= args.seq_length
@@ -182,16 +196,11 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.checkpoint_activations, \
             'for distribute-checkpointed-activations to work you '\
             'need to enable checkpoint-activations'

-    if args.scaled_masked_softmax_fusion:
-        if args.scaled_upper_triang_masked_softmax_fusion:
-            fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
-        else:
-            fused_kernels.load_scaled_masked_softmax_fusion_kernel()
-    else:
-        # This argument will eventually go away, for now make sure it is off
-        # if scaled_masked_softmax_fusion is off.
-        args.scaled_upper_triang_masked_softmax_fusion = False
+    # Load scaled_masked_softmax_fusion_kernels
+    if args.masked_softmax_fusion:
+        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
+        fused_kernels.load_scaled_masked_softmax_fusion_kernel()

     # Load mixed precision fused layer norm.
     if args.fp32_residual_connection:
@@ -227,8 +236,14 @@ def _add_network_size_args(parser):
                        help='Number of transformer layers.')
     group.add_argument('--hidden-size', type=int, default=None,
                        help='Tansformer hidden size.')
+    group.add_argument('--ffn-hidden-size', type=int, default=None,
+                       help='Transformer Feed-Forward Network hidden size. '
+                       'This is set to 4*hidden-size if not provided')
     group.add_argument('--num-attention-heads', type=int, default=None,
                        help='Number of transformer attention heads.')
+    group.add_argument('--kv-channels', type=int, default=None,
+                       help='Projection weights dimension in multi-head attention. '
+                       'This is set to args.hidden_size // args.num_attention_heads '
+                       'if not provided.')
     group.add_argument('--max-position-embeddings', type=int, default=None,
                        help='Maximum number of position embeddings to use. '
                        'This is the size of position embedding.')
@@ -330,16 +345,11 @@ def _add_training_args(parser):
                        help='Exit the program after this many minutes.')
     group.add_argument('--tensorboard-dir', type=str, default=None,
                        help='Write TensorBoard logs to this directory.')
-    group.add_argument('--no-scaled-masked-softmax-fusion',
+    group.add_argument('--no-masked-softmax-fusion',
                        action='store_false',
                        help='Disable fusion of query_key_value scaling, '
                        'masking, and softmax.',
-                       dest='scaled_masked_softmax_fusion')
-    group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
-                       type=bool,
-                       help='Use upper triangular version of fused '
-                       'scale, mask, softmax fusion kernel (default for GPT). '
-                       '- DEPRECATED')
+                       dest='masked_softmax_fusion')
     group.add_argument('--no-bias-gelu-fusion', action='store_false',
                        help='Disable bias and gelu fusion.',
                        dest='bias_gelu_fusion')
@@ -530,6 +540,10 @@ def _add_data_args(parser):
                        help='Path to the BPE merge file.')
     group.add_argument('--seq-length', type=int, default=None,
                        help="Maximum sequence length to process.")
+    group.add_argument('--encoder-seq-length', type=int, default=None,
+                       help="Maximum encoder sequence length to process.")
+    group.add_argument('--decoder-seq-length', type=int, default=None,
+                       help="Maximum decoder sequence length to process.")
     group.add_argument('--mask-prob', type=float, default=0.15,
                        help='Probability of replacing a token with mask.')
     group.add_argument('--short-seq-prob', type=float, default=0.1,
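The interaction between --seq-length, the new --encoder-seq-length, and the renamed fusion flag is easiest to see in isolation. Below is a minimal, self-contained sketch of the defaulting rules the hunks above add; only the argument names and the rules come from the diff, everything else (the helper name, the example values) is illustrative.

    import argparse

    def parse_seq_args(argv=None):
        parser = argparse.ArgumentParser()
        parser.add_argument('--seq-length', type=int, default=None)
        parser.add_argument('--encoder-seq-length', type=int, default=None)
        parser.add_argument('--decoder-seq-length', type=int, default=None)
        # Renamed flag: --no-masked-softmax-fusion stores False into
        # args.masked_softmax_fusion, which defaults to True.
        parser.add_argument('--no-masked-softmax-fusion', action='store_false',
                            dest='masked_softmax_fusion')
        args = parser.parse_args(argv)

        # Same rule as the diff: --seq-length and --encoder-seq-length are
        # aliases; exactly one must be given and the other is filled in.
        if args.seq_length is not None:
            assert args.encoder_seq_length is None
            args.encoder_seq_length = args.seq_length
        else:
            assert args.encoder_seq_length is not None
            args.seq_length = args.encoder_seq_length
        return args

    args = parse_seq_args(['--encoder-seq-length', '512',
                           '--decoder-seq-length', '128'])
    print(args.seq_length, args.encoder_seq_length,
          args.masked_softmax_fusion)   # 512 512 True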
megatron/fused_kernels/scaled_masked_softmax.h  (View file @ 4b506832)

@@ -81,7 +81,6 @@ __global__ void scaled_masked_softmax_warp_forward(
     const uint8_t *mask,
     const acc_t scale,
     int micro_batch_size,
-    int stride,
     int element_count,
     int pad_batches)
 {
@@ -111,9 +110,9 @@ __global__ void scaled_masked_softmax_warp_forward(
     // there might be multiple batches per warp. compute the index within the batch
     int local_idx = threadIdx.x;

-    src += first_batch * stride + local_idx;
-    dst += first_batch * stride + local_idx;
-    mask += pad_first_batch * stride + local_idx;
+    src += first_batch * element_count + local_idx;
+    dst += first_batch * element_count + local_idx;
+    mask += pad_first_batch * element_count + local_idx;

     // load data from global memory
     acc_t elements[WARP_BATCH][WARP_ITERATIONS];
@@ -185,7 +184,6 @@ __global__ void scaled_masked_softmax_warp_backward(
     const input_t *output,
     acc_t scale,
     int micro_batch_size,
-    int stride,
     int element_count)
 {
     // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
@@ -209,7 +207,7 @@ __global__ void scaled_masked_softmax_warp_backward(
     int local_idx = threadIdx.x;

     // the first element to process by the current thread
-    int thread_offset = first_batch * stride + local_idx;
+    int thread_offset = first_batch * element_count + local_idx;
     grad += thread_offset;
     output += thread_offset;
     gradInput += thread_offset;
@@ -277,20 +275,19 @@ void dispatch_scaled_masked_softmax_forward(
     const input_t *src,
     const uint8_t *mask,
     const input_t scale,
-    int softmax_elements,
-    int softmax_elements_stride,
+    int query_seq_len,
+    int key_seq_len,
     int batches,
     int attn_heads,
     int pad_batches)
 {
-    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
-    if (softmax_elements == 0) {
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
+    if (key_seq_len == 0) {
         return;
     } else {
-        int log2_elements = log2_ceil(softmax_elements);
+        int log2_elements = log2_ceil(key_seq_len);
         const int next_power_of_two = 1 << log2_elements;
-        int seq_len = softmax_elements;
-        int batch_count = batches * attn_heads * seq_len;
+        int batch_count = batches * attn_heads * query_seq_len;

         // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
         int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
@@ -302,59 +299,59 @@ void dispatch_scaled_masked_softmax_forward(
         constexpr int threads_per_block = 128;
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
-        TORCH_INTERNAL_ASSERT(seq_len % batches_per_block == 0);
-        dim3 blocks(seq_len / batches_per_block, attn_heads, batches);
+        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
+        dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
         dim3 threads(warp_size, warps_per_block, 1);
         // Launch code would be more elegant if C++ supported FOR CONSTEXPR
         switch (log2_elements) {
             case 0: // 1
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             ... cases 1 (// 2) through 11 (// 2048) receive the identical change to the kernel launch arguments ...
             default:
                 break;
@@ -368,19 +365,18 @@ void dispatch_scaled_masked_softmax_backward(
     input_t *grad,
     const input_t *output,
     const acc_t scale,
-    int softmax_elements,
-    int softmax_elements_stride,
+    int query_seq_len,
+    int key_seq_len,
     int batches,
     int attn_heads)
 {
-    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
-    if (softmax_elements == 0) {
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
+    if (key_seq_len == 0) {
         return;
     } else {
-        int log2_elements = log2_ceil(softmax_elements);
+        int log2_elements = log2_ceil(key_seq_len);
         const int next_power_of_two = 1 << log2_elements;
-        int seq_len = softmax_elements;
-        int batch_count = batches * attn_heads * seq_len;
+        int batch_count = batches * attn_heads * query_seq_len;

         // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
         int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
@@ -399,51 +395,51 @@ void dispatch_scaled_masked_softmax_backward(
         switch (log2_elements) {
             case 0: // 1
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             ... cases 1 (// 2) through 11 (// 2048) receive the identical change to the kernel launch arguments ...
             default:
                 break;
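The kernel-side change is that the softmax row length (key_seq_len) and the number of rows per attention matrix (query_seq_len) are now tracked separately instead of a single seq_len with an equal stride, so the attention scores no longer need to be square. A plain PyTorch reference of what the fused kernel computes, written only to illustrate the [b, np, sq, sk] shapes; the function name and the exact mask fill value here are illustrative, not the kernel's API.

    import torch

    def scaled_masked_softmax_reference(scores, mask, scale):
        # scores: [batches, attn_heads, query_seq_len, key_seq_len]
        # mask:   [pad_batches, 1, query_seq_len, key_seq_len], 1 where masked
        # The softmax runs over the last (key) dimension, so there are
        # batches * attn_heads * query_seq_len rows of length key_seq_len,
        # matching the batch_count computed by the dispatcher above.
        scores = scores * scale
        scores = scores.masked_fill(mask.bool(), -10000.0)
        return torch.softmax(scores, dim=-1)

    b, np_, sq, sk = 2, 4, 8, 16           # sq != sk is now allowed
    scores = torch.randn(b, np_, sq, sk)
    mask = torch.zeros(1, 1, sq, sk)        # pad_batches == 1 broadcasts over b
    probs = scaled_masked_softmax_reference(scores, mask, scale=1.0)
    print(probs.shape)                      # torch.Size([2, 4, 8, 16])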
megatron/fused_kernels/scaled_masked_softmax_cuda.cu  (View file @ 4b506832)

@@ -37,17 +37,19 @@ torch::Tensor fwd_cuda(
     const int batches = input.size(0);
     const int pad_batches = mask.size(0);
     const int attn_heads = input.size(1);
-    const int seq_len = input.size(2);
-    TORCH_INTERNAL_ASSERT(seq_len <= 2048);
+    const int query_seq_len = input.size(2);
+    const int key_seq_len = input.size(3);
+    TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
+    TORCH_INTERNAL_ASSERT(query_seq_len > 1);
     TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
     TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
-    TORCH_INTERNAL_ASSERT(mask.size(2) == seq_len);
-    TORCH_INTERNAL_ASSERT(mask.size(3) == seq_len);
+    TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
+    TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

     // Output
     auto act_options = input.options().requires_grad(false);
     torch::Tensor softmax_results =
-        torch::empty({batches, attn_heads, seq_len, seq_len}, act_options);
+        torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

     // Softmax Intermediate Result Ptr
     void* input_ptr = static_cast<void*>(input.data_ptr());
@@ -59,8 +61,8 @@ torch::Tensor fwd_cuda(
         reinterpret_cast<const half*>(input_ptr),
         reinterpret_cast<const uint8_t*>(mask_ptr),
         scale_factor,
-        seq_len,
-        seq_len,
+        query_seq_len,
+        key_seq_len,
         batches,
         attn_heads,
         pad_batches);
@@ -78,8 +80,8 @@ torch::Tensor bwd_cuda(
     //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
     const int batches = output_grads.size(0);
     const int attn_heads = output_grads.size(1);
-    const int seq_len = output_grads.size(2);
-    TORCH_INTERNAL_ASSERT(output_grads.size(2) == output_grads.size(3));
+    const int query_seq_len = output_grads.size(2);
+    const int key_seq_len = output_grads.size(3);

     void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
@@ -89,8 +91,8 @@ torch::Tensor bwd_cuda(
         reinterpret_cast<half*>(output_grads_ptr),
         reinterpret_cast<half const*>(softmax_results.data_ptr()),
         scale_factor,
-        seq_len,
-        seq_len,
+        query_seq_len,
+        key_seq_len,
         batches,
         attn_heads);
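Stated in Python, the shape contract fwd_cuda now enforces looks roughly like the following sketch of the checks (not the extension's actual binding; the helper name is made up):

    import torch

    def check_fused_softmax_inputs(scores, mask):
        # scores: [batches, attn_heads, query_seq_len, key_seq_len] attention scores
        # mask:   [pad_batches, 1, query_seq_len, key_seq_len]
        batches, attn_heads, query_seq_len, key_seq_len = scores.shape
        pad_batches = mask.size(0)
        assert key_seq_len <= 2048           # kernel limit on the softmax row
        assert query_seq_len > 1
        assert pad_batches in (1, batches)   # mask may broadcast over the batch
        assert mask.size(1) == 1
        assert mask.size(2) == query_seq_len
        assert mask.size(3) == key_seq_len
        return batches, attn_heads, query_seq_len, key_seq_len

    print(check_fused_softmax_inputs(torch.empty(2, 4, 8, 16),
                                     torch.empty(1, 1, 8, 16)))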
megatron/model/enums.py  (new file, 0 → 100644)  (View file @ 4b506832)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum


class LayerType(enum.Enum):
    encoder = 1
    decoder = 2


class AttnType(enum.Enum):
    self_attn = 1
    cross_attn = 2


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
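A quick illustration of how the three enums are meant to combine. The enum names and values come from the file above; the helper function is hypothetical and only spells out the pairing the rest of the commit wires up (decoder layers add cross attention with a padding mask on top of causal self attention).

    from enum import Enum

    class LayerType(Enum):
        encoder = 1
        decoder = 2

    class AttnType(Enum):
        self_attn = 1
        cross_attn = 2

    class AttnMaskType(Enum):
        padding = 1
        causal = 2

    def attention_kinds_for(layer_type, self_attn_mask_type):
        """Hypothetical helper: which attention blocks a layer needs."""
        kinds = [(AttnType.self_attn, self_attn_mask_type)]
        if layer_type == LayerType.decoder:
            # decoder layers add cross attention over the encoder output,
            # which uses a padding mask rather than a causal one
            kinds.append((AttnType.cross_attn, AttnMaskType.padding))
        return kinds

    print(attention_kinds_for(LayerType.decoder, AttnMaskType.causal))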
megatron/model/fused_softmax.py  (View file @ 4b506832)

@@ -14,11 +14,13 @@
 # limitations under the License.

 import torch

+from megatron.model.enums import AttnMaskType

-class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
+
+class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     """
     Fused operation which performs following three operations in sequence
     1. Scale the tensor.
     2. Apply upper triangular mask (typically used in gpt models).
     3. Perform softmax.
     """
@@ -38,15 +40,16 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
         softmax_results, scale_t = ctx.saved_tensors

         input_grads = \
             scaled_upper_triang_masked_softmax_cuda.backward(output_grads,
                                                              softmax_results,
                                                              scale_t[0])
         return input_grads, None

-class ScaledMaskedSoftmax(torch.autograd.Function) :
+
+class ScaledMaskedSoftmax(torch.autograd.Function):
     """
     Fused operation which performs following three operations in sequence
     1. Scale the tensor.
     2. Apply the mask.
     3. Perform softmax.
     """
@@ -71,24 +74,25 @@ class ScaledMaskedSoftmax(torch.autograd.Function) :
                                                  scale_t[0])
         return input_grads, None, None


 class FusedScaleMaskSoftmax(torch.nn.Module):
     """
     fused operation: scaling + mask + softmax
     Arguments:
         input_in_fp16: flag to indicate if input in fp16 data format.
-        upper_triang_mask: if true, apply upper triangular masking.
-                           (used in gpt family networks)
+        attn_mask_type: attention mask type (pad or causal)
         mask_func: mask function to be applied.
         softmax_in_fp32: if true, softmax in performed at fp32 precision.
         scale: scaling factor used in input tensor scaling.
     """

-    def __init__(self, input_in_fp16, upper_triang_mask_fusion,
-                 general_mask_fusion, mask_func, softmax_in_fp32, scale):
+    def __init__(self, input_in_fp16, attn_mask_type,
+                 scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
-        self.upper_triang_mask_fusion = upper_triang_mask_fusion
-        self.general_mask_fusion = general_mask_fusion
+        self.attn_mask_type = attn_mask_type
+        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
         self.mask_func = mask_func
         self.softmax_in_fp32 = softmax_in_fp32
         self.scale = scale
@@ -97,20 +101,26 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
             'softmax should be in fp32 when scaled'

     def forward(self, input, mask):
-        # [b, np, s, s]
+        # [b, np, sq, sk]
         data_size = input.size()
+        query_seq_len = data_size[-2]
+        key_seq_len = data_size[-1]
         assert input.dim() == 4

         # invoke custom kernel
-        if self.input_in_fp16 and data_size[-1] <= 2048 and \
-            (self.upper_triang_mask_fusion or self.general_mask_fusion) and \
-            input.size()[2] == input.size()[3]:
+        if self.input_in_fp16 and key_seq_len <= 2048 and \
+            query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0
-            if self.upper_triang_mask_fusion:
-                input = input.view(-1, data_size[2], data_size[3])
+            if self.attn_mask_type == AttnMaskType.causal:
+                assert query_seq_len == key_seq_len, \
+                    "causal mask is only for self attention"
+                input = input.view(-1, query_seq_len, key_seq_len)
                 probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
                 probs = probs.view(*data_size)
             else:
+                assert self.attn_mask_type == AttnMaskType.padding
                 probs = ScaledMaskedSoftmax.apply(input, mask, scale)
         else:
             if self.input_in_fp16 and self.softmax_in_fp32:
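The decision the rewritten forward() makes can be summarized as: use the upper-triangular fused kernel only for causal self-attention (square, sq == sk), the general fused kernel for padding masks, and the plain PyTorch path otherwise. A small sketch of that dispatch using only eager PyTorch, with the fused CUDA autograd functions replaced by their unfused equivalents; it is illustrative, not the module above.

    import enum
    import torch

    class AttnMaskType(enum.Enum):
        padding = 1
        causal = 2

    def scale_mask_softmax(scores, mask, attn_mask_type, scale=None):
        # scores: [b, np, sq, sk]; mask: boolean mask (True = masked) or None
        b, np_, sq, sk = scores.shape
        if scale is not None:
            scores = scores * scale
        if attn_mask_type == AttnMaskType.causal:
            # the fused path requires sq == sk ("causal mask is only for
            # self attention"); emulate the upper-triangular masking here
            assert sq == sk
            causal = torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=1)
            scores = scores.masked_fill(causal, float('-inf'))
        elif mask is not None:
            scores = scores.masked_fill(mask.bool(), float('-inf'))
        return torch.softmax(scores, dim=-1)

    probs = scale_mask_softmax(torch.randn(2, 4, 8, 8), None, AttnMaskType.causal)
    print(probs[0, 0].triu(1).abs().sum())   # ~0: no weight on future positions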
megatron/model/gpt2_model.py  (View file @ 4b506832)

@@ -21,6 +21,7 @@ from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
+from .enums import AttnMaskType
 from .language_model import parallel_lm_logits
 from .language_model import get_language_model
 from .utils import init_method_normal
@@ -75,6 +76,7 @@ class GPT2ModelBase(MegatronModule):
             attention_mask_func=gpt2_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
+            self_attn_mask_type=AttnMaskType.causal,
             init_method=init_method_normal(args.init_method_std),
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
megatron/model/language_model.py  (View file @ 4b506832)

@@ -21,6 +21,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
+from megatron.model.enums import LayerType, AttnMaskType
 from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal, scaled_init_method_normal
@@ -43,7 +44,9 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
 def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
-                       init_method=None, scaled_init_method=None):
+                       add_decoder=False,
+                       init_method=None, scaled_init_method=None,
+                       self_attn_mask_type=AttnMaskType.padding):
     """Build language model and return along with the key to save."""
     args = get_args()
@@ -51,7 +54,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
         init_method = init_method_normal(args.init_method_std)
     if scaled_init_method is None:
         scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                        args.num_layers)

     # Language model.
     args = [attention_mask_func, init_method, scaled_init_method]
@@ -60,6 +64,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
         cls = TransformerLanguageModel
         kwargs['num_tokentypes'] = num_tokentypes
+        kwargs['self_attn_mask_type'] = self_attn_mask_type
+        kwargs['add_decoder'] = add_decoder
         kwargs['add_pooler'] = add_pooler
     elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
         cls = TransformerLanguageModelFirstStage
@@ -186,8 +192,6 @@ class Embedding(MegatronModule):
         if tokentype_ids is not None:
             assert self.tokentype_embeddings is not None
             embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
-        else:
-            assert self.tokentype_embeddings is None

         # Dropout.
         embeddings = self.embedding_dropout(embeddings)
@@ -281,6 +285,8 @@ class TransformerLanguageModelBase(MegatronModule):
                  init_method,
                  output_layer_init_method,
                  num_tokentypes=0,
+                 self_attn_mask_type=AttnMaskType.padding,
+                 add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModelBase, self).__init__()
         args = get_args()
@@ -288,6 +294,8 @@ class TransformerLanguageModelBase(MegatronModule):
         self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
         self.init_method = init_method
+        self.self_attn_mask_type = self_attn_mask_type
+        self.add_decoder = add_decoder
         self.add_pooler = add_pooler

         # Embeddings.
@@ -301,41 +309,87 @@ class TransformerLanguageModelBase(MegatronModule):
             self._embedding_key = 'embedding'

-        # Transformer.
-        self.transformer = ParallelTransformer(
-            attention_mask_func, self.init_method,
-            output_layer_init_method)
-        self._transformer_key = 'transformer'
-
-        # Pooler.
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            self.pooler = Pooler(self.hidden_size, self.init_method)
-            self._pooler_key = 'pooler'
-
-    def forward(self, language_model_input, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
+        self.encoder = ParallelTransformer(
+            attention_mask_func,
+            self.init_method,
+            output_layer_init_method,
+            self_attn_mask_type=self_attn_mask_type)
+        self._encoder_key = 'encoder'
+
+        # assuming pooler and decoder are in the last stage
+        # of the pipeline(to be revised)
+        if mpu.is_pipeline_last_stage():
+            # decoder
+            if self.add_decoder:
+                self.decoder = ParallelTransformer(
+                    attention_mask_func,
+                    self.init_method,
+                    output_layer_init_method,
+                    layer_type=LayerType.decoder,
+                    self_attn_mask_type=AttnMaskType.causal)
+                self._decoder_key = 'decoder'
+
+            # Pooler.
+            if self.add_pooler:
+                self.pooler = Pooler(self.hidden_size, self.init_method)
+                self._pooler_key = 'pooler'
+
+    def forward(self, enc_language_model_input, enc_attention_mask,
+                dec_language_model_input=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None, tokentype_ids=None,
+                layer_past=None, get_key_value=False,
+                pooling_sequence_index=0,
+                enc_hidden_states=None, output_enc_hidden=False):

         # Embeddings.
         if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = language_model_input
+            (input_ids, position_ids) = enc_language_model_input
             embedding_output = self.embedding(input_ids, position_ids,
                                               tokentype_ids=tokentype_ids)
-            transformer_input = embedding_output
+            encoder_input = embedding_output
         else:
-            transformer_input = language_model_input
-
-        # Transformer.
-        transformer_output = self.transformer(transformer_input,
-                                              attention_mask,
-                                              layer_past=layer_past,
-                                              get_key_value=get_key_value)
-
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            pooled_output = self.pooler(transformer_output,
-                                        pooling_sequence_index)
-            return transformer_output, pooled_output
-
-        return transformer_output
+            encoder_input = enc_language_model_input
+
+        # encoder.
+        if enc_hidden_states is None:
+            encoder_output = self.encoder(encoder_input,
+                                          enc_attention_mask,
+                                          layer_past=layer_past,
+                                          get_key_value=get_key_value)
+        else:
+            encoder_output = enc_hidden_states.to(encoder_input.dtype)
+
+        if mpu.is_pipeline_last_stage():
+            if self.add_pooler:
+                pooled_output = self.pooler(encoder_output,
+                                            pooling_sequence_index)
+
+            # output_enc_hidden refers to when we just need the encoder's
+            # output. For example, it is helpful to compute
+            # similarity between two sequences by average pooling
+            if not self.add_decoder or output_enc_hidden:
+                if self.add_pooler:
+                    return encoder_output, pooled_output
+                else:
+                    return encoder_output
+
+            # Decoder Embedding
+            (dec_input_ids, dec_position_ids) = dec_language_model_input
+            dec_embedding_output = self.embedding(dec_input_ids,
+                                                  dec_position_ids)
+            # decoder
+            decoder_output = self.decoder(dec_embedding_output,
+                                          dec_attn_mask,
+                                          layer_past=layer_past,
+                                          get_key_value=get_key_value,
+                                          encoder_output=encoder_output,
+                                          enc_dec_attn_mask=enc_dec_attn_mask)
+
+            if self.add_pooler:
+                return decoder_output, encoder_output, pooled_output
+            else:
+                return decoder_output, encoder_output
+
+        return encoder_output

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
@@ -346,13 +400,18 @@ class TransformerLanguageModelBase(MegatronModule):
         state_dict_[self._embedding_key] \
             = self.embedding.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        state_dict_[self._transformer_key] \
-            = self.transformer.state_dict_for_save_checkpoint(
+        state_dict_[self._encoder_key] \
+            = self.encoder.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            state_dict_[self._pooler_key] \
-                = self.pooler.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+        if mpu.is_pipeline_last_stage():
+            if self.add_pooler:
+                state_dict_[self._pooler_key] \
+                    = self.pooler.state_dict_for_save_checkpoint(
+                        destination, prefix, keep_vars)
+            if self.add_decoder:
+                state_dict_[self._decoder_key] \
+                    = self.decoder.state_dict_for_save_checkpoint(
+                        destination, prefix, keep_vars)

         return state_dict_
@@ -371,23 +430,44 @@ class TransformerLanguageModelBase(MegatronModule):
                 state_dict_[key] = state_dict[key]
         self.embedding.load_state_dict(state_dict_, strict=strict)

-        # Transformer.
-        if self._transformer_key in state_dict:
-            state_dict_ = state_dict[self._transformer_key]
+        # Encoder.
+        if self._encoder_key in state_dict:
+            state_dict_ = state_dict[self._encoder_key]
+        # for backward compatibility.
+        elif 'transformer' in state_dict:
+            state_dict_ = state_dict['transformer']
         else:
             # for backward compatibility.
             state_dict_ = {}
             for key in state_dict.keys():
-                if 'transformer.' in key:
-                    state_dict_[key.split('transformer.')[1]] = state_dict[key]
-        self.transformer.load_state_dict(state_dict_, strict=strict)
-
-        # Pooler.
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            assert 'pooler' in state_dict, \
-                'could not find data for pooler in the checkpoint'
-            self.pooler.load_state_dict(state_dict[self._pooler_key],
-                                        strict=strict)
+                if 'encoder.' in key:
+                    state_dict_[key.split('encoder.')[1]] = state_dict[key]
+
+        # for backward compatibility.
+        state_dict_self_attention = {}
+        for key in state_dict_.keys():
+            if '.attention.' in key:
+                state_dict_self_attention[key.replace(".attention.",
+                    ".self_attention.")] = state_dict_[key]
+            else:
+                state_dict_self_attention[key] = state_dict_[key]
+        state_dict_ = state_dict_self_attention
+
+        self.encoder.load_state_dict(state_dict_, strict=strict)
+
+        if mpu.is_pipeline_last_stage():
+            # pooler
+            if self.add_pooler:
+                assert 'pooler' in state_dict, \
+                    'could not find data for pooler in the checkpoint'
+                self.pooler.load_state_dict(state_dict[self._pooler_key],
+                                            strict=strict)
+            # decoder
+            if self.add_decoder:
+                assert 'decoder' in state_dict, \
+                    'could not find data for pooler in the checkpoint'
+                self.decoder.load_state_dict(state_dict[self._decoder_key],
+                                             strict=strict)


 class TransformerLanguageModel(TransformerLanguageModelBase):
@@ -400,24 +480,35 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
                  init_method,
                  output_layer_init_method,
                  num_tokentypes=0,
+                 self_attn_mask_type=AttnMaskType.padding,
+                 add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModel, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
             num_tokentypes=num_tokentypes,
+            self_attn_mask_type=self_attn_mask_type,
+            add_decoder=add_decoder,
             add_pooler=add_pooler)

-    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
+    def forward(self, enc_input_ids, enc_position_ids, enc_attention_mask,
+                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None, tokentype_ids=None,
+                layer_past=None, get_key_value=False,
+                pooling_sequence_index=0,
+                enc_hidden_states=None, output_enc_hidden=False):
         return super(TransformerLanguageModel, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
+            (enc_input_ids, enc_position_ids),
+            enc_attention_mask,
+            dec_language_model_input=(dec_input_ids, dec_position_ids),
+            dec_attn_mask=dec_attn_mask,
+            enc_dec_attn_mask=enc_dec_attn_mask,
             tokentype_ids=tokentype_ids,
             layer_past=layer_past,
             get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index
+            pooling_sequence_index=pooling_sequence_index,
+            enc_hidden_states=enc_hidden_states,
+            output_enc_hidden=output_enc_hidden
         )
@@ -430,12 +521,14 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
-                 num_tokentypes=0):
+                 num_tokentypes=0,
+                 self_attn_mask_type=AttnMaskType.padding):
         super(TransformerLanguageModelFirstStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
-            num_tokentypes=num_tokentypes)
+            num_tokentypes=num_tokentypes,
+            self_attn_mask_type=self_attn_mask_type)

     def forward(self, input_ids, position_ids, attention_mask,
                 tokentype_ids=None, layer_past=None, get_key_value=False):
@@ -456,11 +549,13 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
     def __init__(self,
                  attention_mask_func,
                  init_method,
-                 output_layer_init_method):
+                 output_layer_init_method,
+                 self_attn_mask_type=AttnMaskType.padding):
         super(TransformerLanguageModelIntermediateStage, self).__init__(
             attention_mask_func,
             init_method,
-            output_layer_init_method)
+            output_layer_init_method,
+            self_attn_mask_type=self_attn_mask_type)

     def forward(self, hidden_states, attention_mask,
                 layer_past=None, get_key_value=False):
@@ -481,20 +576,31 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
+                 self_attn_mask_type=AttnMaskType.padding,
+                 add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModelLastStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
+            self_attn_mask_type=AttnMaskType.padding,
+            add_decoder=add_decoder,
             add_pooler=add_pooler)

-    def forward(self, hidden_states, attention_mask,
-                layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
+    def forward(self, hidden_states, enc_attention_mask,
+                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None,
+                layer_past=None, get_key_value=False,
+                pooling_sequence_index=0,
+                enc_hidden_states=None, output_enc_hidden=False):
         return super(TransformerLanguageModelLastStage, self).forward(
             hidden_states,
-            attention_mask,
+            enc_attention_mask,
+            dec_language_input=(dec_input_ids, dec_position_ids),
+            dec_attn_mask=dec_attn_mask,
+            enc_dec_attn_mask=enc_dec_attn_mask,
             layer_past=layer_past,
             get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index
+            pooling_sequence_index=pooling_sequence_index,
+            enc_hidden_states=enc_hidden_states,
+            ouput_enc_hidden=output_enc_hidden
         )
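With add_decoder=True the new forward() effectively turns TransformerLanguageModel into an encoder-decoder model: embed and encode the encoder inputs (or reuse precomputed enc_hidden_states), optionally return only the encoder output, otherwise embed the decoder inputs and run the decoder stack with cross attention over the encoder output. The control flow, stripped of pipeline, pooler, and mask handling, is roughly the following sketch; the argument names follow the diff, while the class and the Identity stand-ins for ParallelTransformer are purely illustrative.

    import torch

    class TinyEncDec(torch.nn.Module):
        """Stand-in for TransformerLanguageModelBase's new control flow."""

        def __init__(self, hidden, add_decoder=True):
            super().__init__()
            self.add_decoder = add_decoder
            self.embedding = torch.nn.Embedding(1000, hidden)
            self.encoder = torch.nn.Identity()   # stand-in for ParallelTransformer
            self.decoder = torch.nn.Identity()   # stand-in (LayerType.decoder)

        def forward(self, enc_input_ids, enc_attention_mask,
                    dec_input_ids=None, dec_attn_mask=None,
                    enc_dec_attn_mask=None,
                    enc_hidden_states=None, output_enc_hidden=False):
            enc_in = self.embedding(enc_input_ids)
            # reuse precomputed encoder states when provided, as the diff does
            encoder_output = self.encoder(enc_in) if enc_hidden_states is None \
                else enc_hidden_states.to(enc_in.dtype)
            if not self.add_decoder or output_enc_hidden:
                return encoder_output
            dec_in = self.embedding(dec_input_ids)
            # the real decoder also consumes encoder_output and enc_dec_attn_mask
            decoder_output = self.decoder(dec_in)
            return decoder_output, encoder_output

    model = TinyEncDec(hidden=16)
    out = model(torch.randint(0, 1000, (2, 8)), None,
                dec_input_ids=torch.randint(0, 1000, (2, 4)))
    print(out[0].shape, out[1].shape)   # [2, 4, 16] and [2, 8, 16]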
megatron/model/transformer.py
View file @
4b506832
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
# limitations under the License.
# limitations under the License.
"""Transformer."""
"""Transformer."""
import
enum
import
math
import
math
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -23,6 +23,7 @@ from megatron import get_args
...
@@ -23,6 +23,7 @@ from megatron import get_args
from
megatron
import
mpu
from
megatron
import
mpu
from
.module
import
MegatronModule
from
.module
import
MegatronModule
from
megatron.checkpointing
import
get_checkpoint_version
from
megatron.checkpointing
import
get_checkpoint_version
from
megatron.model.enums
import
AttnMaskType
,
LayerType
,
AttnType
from
megatron.model
import
import_layernorm
from
megatron.model
import
import_layernorm
from
megatron.model.fused_softmax
import
FusedScaleMaskSoftmax
from
megatron.model.fused_softmax
import
FusedScaleMaskSoftmax
from
megatron.model.fused_bias_gelu
import
bias_gelu_impl
from
megatron.model.fused_bias_gelu
import
bias_gelu_impl
...
@@ -71,7 +72,7 @@ class ParallelMLP(MegatronModule):
...
@@ -71,7 +72,7 @@ class ParallelMLP(MegatronModule):
# Project to 4h.
# Project to 4h.
self
.
dense_h_to_4h
=
mpu
.
ColumnParallelLinear
(
self
.
dense_h_to_4h
=
mpu
.
ColumnParallelLinear
(
args
.
hidden_size
,
args
.
hidden_size
,
4
*
args
.
hidden_size
,
args
.
ffn_
hidden_size
,
gather_output
=
False
,
gather_output
=
False
,
init_method
=
init_method
,
init_method
=
init_method
,
skip_bias_add
=
True
)
skip_bias_add
=
True
)
...
@@ -85,12 +86,11 @@ class ParallelMLP(MegatronModule):
...
@@ -85,12 +86,11 @@ class ParallelMLP(MegatronModule):
# Project back to h.
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
self.dense_4h_to_h = mpu.RowParallelLinear(
4 * args.hidden_size,
args.ffn_hidden_size,
args.hidden_size,
args.hidden_size,
input_is_parallel=True,
input_is_parallel=True,
init_method=output_layer_init_method,
init_method=output_layer_init_method,
skip_bias_add=True)
skip_bias_add=True)
def forward(self, hidden_states):
def forward(self, hidden_states):
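The only functional change to ParallelMLP is that the hard-coded 4x expansion becomes the configurable args.ffn_hidden_size. Ignoring tensor parallelism and the fused bias-GeLU path (both simplifications of this sketch), the block behaves like:

import torch
import torch.nn.functional as F

class SimpleMLP(torch.nn.Module):
    """Single-GPU analogue of ParallelMLP with a configurable FFN width."""
    def __init__(self, hidden_size, ffn_hidden_size):
        super().__init__()
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, ffn_hidden_size)
        self.dense_4h_to_h = torch.nn.Linear(ffn_hidden_size, hidden_size)

    def forward(self, hidden_states):
        # [s, b, h] -> [s, b, ffn] -> [s, b, h]
        return self.dense_4h_to_h(F.gelu(self.dense_h_to_4h(hidden_states)))

# A 4x expansion is the conventional transformer default for the FFN width.
mlp = SimpleMLP(hidden_size=1024, ffn_hidden_size=4096)
out = mlp(torch.randn(8, 2, 1024))  # [s=8, b=2, h=1024]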
...
@@ -109,7 +109,7 @@ class ParallelMLP(MegatronModule):
...
@@ -109,7 +109,7 @@ class ParallelMLP(MegatronModule):
return output, output_bias
return output, output_bias
class ParallelSelfAttention(MegatronModule):
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
Self-attention layer takes input with size [b, s, h]
...
@@ -117,8 +117,10 @@ class ParallelSelfAttention(MegatronModule):
...
@@ -117,8 +117,10 @@ class ParallelSelfAttention(MegatronModule):
"""
"""
def __init__(self, attention_mask_func, init_method,
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
output_layer_init_method, layer_number,
super(ParallelSelfAttention, self).__init__()
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
super(ParallelAttention, self).__init__()
args = get_args()
args = get_args()
self.fp16 = args.fp16
self.fp16 = args.fp16
...
@@ -128,22 +130,40 @@ class ParallelSelfAttention(MegatronModule):
...
@@ -128,22 +130,40 @@ class ParallelSelfAttention(MegatronModule):
if self.apply_query_key_layer_scaling:
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(args.hidden_size,
self.hidden_size_per_partition = mpu.divide(projection_size,
world_size)
world_size)
self.hidden_size_per_attention_head = mpu.divide(
self.hidden_size_per_attention_head = mpu.divide(
args.hidden_size, args.num_attention_heads)
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
self.num_attention_heads_per_partition = mpu.divide(
args.num_attention_heads, world_size)
args.num_attention_heads, world_size)
# Strided linear layer.
# Strided linear layer.
self.query_key_value = mpu.ColumnParallelLinear(
args.hidden_size,
3 * args.hidden_size,
gather_output=False,
init_method=init_method)
if attention_type == AttnType.self_attn:
self.query_key_value = mpu.ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
gather_output=False,
init_method=init_method)
else:
assert attention_type == AttnType.cross_attn
self.query = mpu.ColumnParallelLinear(
args.hidden_size,
projection_size,
gather_output=False,
init_method=init_method)
self.key_value = mpu.ColumnParallelLinear(
args.hidden_size,
2 * projection_size,
gather_output=False,
init_method=init_method)
coeff = None
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
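The projection layout is the heart of the change: self-attention keeps a fused query-key-value projection, while cross-attention projects queries from the current hidden states and keys/values from the encoder output, and everything is sized by kv_channels * num_attention_heads instead of hidden_size. A single-GPU sketch with plain nn.Linear standing in for ColumnParallelLinear (an assumption of the illustration):

import torch

hidden_size, kv_channels, num_heads = 1024, 64, 16
projection_size = kv_channels * num_heads  # decoupled from hidden_size

# Self-attention: one fused projection producing Q, K and V together.
query_key_value = torch.nn.Linear(hidden_size, 3 * projection_size)

# Cross-attention: Q from the decoder states, K/V from the encoder output.
query = torch.nn.Linear(hidden_size, projection_size)
key_value = torch.nn.Linear(hidden_size, 2 * projection_size)

decoder_states = torch.randn(10, 2, hidden_size)  # [sq, b, h]
encoder_output = torch.randn(20, 2, hidden_size)  # [sk, b, h]

q = query(decoder_states)                          # [sq, b, projection_size]
k, v = key_value(encoder_output).chunk(2, dim=-1)  # each [sk, b, projection_size]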
...
@@ -153,8 +173,8 @@ class ParallelSelfAttention(MegatronModule):
...
@@ -153,8 +173,8 @@ class ParallelSelfAttention(MegatronModule):
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16,
self.fp16,
args.scaled_upper_triang_masked_softmax_fusion,
self.attn_mask_type,
args.scaled_masked_softmax_fusion,
args.masked_softmax_fusion,
self.attention_mask_func,
self.attention_mask_func,
self.attention_softmax_in_fp32,
self.attention_softmax_in_fp32,
coeff)
coeff)
...
@@ -166,18 +186,18 @@ class ParallelSelfAttention(MegatronModule):
...
@@ -166,18 +186,18 @@ class ParallelSelfAttention(MegatronModule):
# Output.
# Output.
self.dense = mpu.RowParallelLinear(
self.dense = mpu.RowParallelLinear(
args.hidden_size,
projection_size,
args.hidden_size,
args.hidden_size,
input_is_parallel=True,
input_is_parallel=True,
init_method=output_layer_init_method,
init_method=output_layer_init_method,
skip_bias_add=True)
skip_bias_add=True)
def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
input_shape = mixed_layer.size();
input_shape = mixed_layer.size()
if num_splits_first:
if num_splits_first:
"""[s, b, num_splits * np * hn]
"""[s, b, num_splits * np * hn]
-->(view) [s, b, num_splits, np, hn]
-->(view) [s, b, num_splits, np, hn]
-->(tranpose) [s, b, np, num_splits, hn]
-->(tranpose) [s, b, np, num_splits, hn]
-->(view) [s, b, np * num_splits * hn] """
-->(view) [s, b, np * num_splits * hn] """
intermediate_shape = input_shape[:-1] + \
intermediate_shape = input_shape[:-1] + \
...
@@ -188,8 +208,8 @@ class ParallelSelfAttention(MegatronModule):
...
@@ -188,8 +208,8 @@ class ParallelSelfAttention(MegatronModule):
mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
else:
else:
"""[s, b, np * hn * num_splits]
"""[s, b, np * hn * num_splits]
-->(view) [s, b, np, hn, num_splits]
-->(view) [s, b, np, hn, num_splits]
-->(tranpose) [s, b, np, num_splits, hn]
-->(tranpose) [s, b, np, num_splits, hn]
-->(view) [s, b, np * num_splits * hn] """
-->(view) [s, b, np * num_splits * hn] """
intermediate_shape = input_shape[:-1] + \
intermediate_shape = input_shape[:-1] + \
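_transpose_last_dim only changes how the fused projection is interleaved along its last dimension, which matters when loading checkpoints written with an older layout. A standalone sketch of the num_splits_first branch with concrete sizes (all names here are illustrative):

import torch

s, b, np_, hn, num_splits = 4, 2, 3, 5, 3
# A tensor packed as [s, b, num_splits * np * hn] ...
x = torch.arange(s * b * num_splits * np_ * hn, dtype=torch.float32)
x = x.view(s, b, num_splits * np_ * hn)

# ... re-packed as [s, b, np * num_splits * hn], mirroring the docstring above.
y = (x.view(s, b, num_splits, np_, hn)
       .transpose(-2, -3)
       .contiguous()
       .view(s, b, np_ * num_splits * hn))
assert y.shape == x.shape  # same storage size, different interleaving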
...
@@ -199,39 +219,70 @@ class ParallelSelfAttention(MegatronModule):
...
@@ -199,39 +219,70 @@ class ParallelSelfAttention(MegatronModule):
mixed_layer = mixed_layer.view(*intermediate_shape)
mixed_layer = mixed_layer.view(*intermediate_shape)
mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
mixed_layer = mixed_layer.view(*input_shape)
mixed_layer = mixed_layer.view(*input_shape)
return mixed_layer
return mixed_layer
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False, encoder_output=None):
# hidden_states: [sq, b, h]
# hidden_states: [sq, b, h]
# =====================
# =====================
# Query, Key, and Value
# Query, Key, and Value
# =====================
# =====================
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
checkpoint_version = get_checkpoint_version()
if checkpoint_version is not None:
if checkpoint_version == 0:
# [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
elif checkpoint_version == 1.0:
# [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer,
key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
checkpoint_version = get_checkpoint_version()
if checkpoint_version is not None:
if checkpoint_version == 0:
# [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
elif checkpoint_version == 1.0:
# [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer,
key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
checkpoint_version = get_checkpoint_version()
if checkpoint_version is not None:
if checkpoint_version == 0:
# [s, b, (2 * np * hn)] --> [s, b, (np * 2 * hn)]
mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
elif checkpoint_version == 1.0:
# [s, b, (np * hn * 2)] --> [s, b, (np * 2 * hn)]
mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, False)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape)
# ==================================
# ==================================
# Adjust key and value for inference
# Adjust key and value for inference
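The inference-time adjustment named above falls between the hunks shown on this page; in GPT-style decoding it amounts to concatenating the cached keys and values from layer_past with the freshly computed ones along the sequence dimension. A hedged sketch of that pattern, not the module's actual code:

import torch

def adjust_kv_for_inference(key_layer, value_layer, layer_past=None,
                            get_key_value=False):
    """key_layer/value_layer: [sk, b, np, hn]; layer_past: optional (key, value)."""
    if layer_past is not None:
        past_key, past_value = layer_past
        key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
        value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0)
    present = (key_layer, value_layer) if get_key_value else None
    return key_layer, value_layer, present

k = torch.randn(1, 2, 16, 64)
v = torch.randn(1, 2, 16, 64)
past = (torch.randn(7, 2, 16, 64), torch.randn(7, 2, 16, 64))
k, v, present = adjust_kv_for_inference(k, v, layer_past=past, get_key_value=True)
assert k.shape[0] == 8  # 7 cached positions + 1 new one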
...
@@ -246,41 +297,41 @@ class ParallelSelfAttention(MegatronModule):
...
@@ -246,41 +297,41 @@ class ParallelSelfAttention(MegatronModule):
if get_key_value:
if get_key_value:
present = (key_layer, value_layer)
present = (key_layer, value_layer)
# ===================================
# ===================================
# Raw attention scores. [b, np, s, s]
# Raw attention scores. [b, np, s, s]
# ===================================
# ===================================
# [b, np, sq, sk]
# [b, np, sq, sk]
output_size = (query_layer.size(1),
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(2),
query_layer.size(0),
query_layer.size(0),
key_layer.size(0))
key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2],
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3],
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
output_size[0] * output_size[1], -1)
# preallocting result tensor: [b * np, sq, sk]
# preallocting result tensor: [b * np, sq, sk]
matmul_result = torch.empty(
matmul_result = torch.empty(
output_size[0] * output_size[1],
output_size[0] * output_size[1],
output_size[2],
output_size[2],
output_size[3],
output_size[3],
dtype=query_layer.dtype,
dtype=query_layer.dtype,
device=torch.cuda.current_device())
device=torch.cuda.current_device())
# Raw attention scores. [b * np, sq, sk]
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(matmul_result,
matmul_result = torch.baddbmm(matmul_result,
query_layer.transpose(0, 1),  # [b * np, sq, hn]
query_layer.transpose(0, 1),  # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2),  #[b * np, hn, sk]
key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
beta=0.0, alpha=(1.0 / self.norm_factor))
beta=0.0, alpha=(1.0 / self.norm_factor))
# change view to [b, np, sq, sk]
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
attention_scores = matmul_result.view(*output_size)
# ==================================================
# ==================================================
# Update attention mask for inference. [b, np, sq, sk]
# Update attention mask for inference. [b, np, sq, sk]
# ==================================================
# ==================================================
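The score computation itself is unchanged and worth unpacking: baddbmm writes the scaled Q.K^T product straight into the preallocated buffer, with alpha = 1/sqrt(hn) and beta = 0.0 so the buffer's uninitialized contents are ignored. A standalone sketch with concrete shapes (plain CPU tensors stand in for the projected layers):

import math
import torch

b, np_, sq, sk, hn = 2, 16, 10, 12, 64
query_layer = torch.randn(sq, b * np_, hn)  # [sq, b * np, hn]
key_layer = torch.randn(sk, b * np_, hn)    # [sk, b * np, hn]

matmul_result = torch.empty(b * np_, sq, sk)
matmul_result = torch.baddbmm(matmul_result,
                              query_layer.transpose(0, 1),                # [b * np, sq, hn]
                              key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                              beta=0.0, alpha=1.0 / math.sqrt(hn))
attention_scores = matmul_result.view(b, np_, sq, sk)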
...
@@ -298,7 +349,6 @@ class ParallelSelfAttention(MegatronModule):
...
@@ -298,7 +349,6 @@ class ParallelSelfAttention(MegatronModule):
:attention_scores.size(3),
:attention_scores.size(3),
:attention_scores.size(3)]
:attention_scores.size(3)]
# ===========================
# ===========================
# Attention probs and dropout
# Attention probs and dropout
# ===========================
# ===========================
...
@@ -312,7 +362,6 @@ class ParallelSelfAttention(MegatronModule):
...
@@ -312,7 +362,6 @@ class ParallelSelfAttention(MegatronModule):
with mpu.get_cuda_rng_tracker().fork():
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
attention_probs = self.attention_dropout(attention_probs)
# =========================
# =========================
# Context layer. [sq, b, hp]
# Context layer. [sq, b, hp]
# =========================
# =========================
...
@@ -321,21 +370,21 @@ class ParallelSelfAttention(MegatronModule):
...
@@ -321,21 +370,21 @@ class ParallelSelfAttention(MegatronModule):
# [sk, b, np, hn] --> [b, np, sq, hn]
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
output_size = (value_layer.size(1),
value_layer.size(2),
value_layer.size(2),
query_layer.size(0),
query_layer.size(0),
value_layer.size(3))
value_layer.size(3))
# change view [sk, b * np, hn]
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0),
value_layer = value_layer.view(value_layer.size(0),
output_size[0] * output_size[1], -1)
output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
output_size[2], -1)
# matmul: [b * np, sq, hn]
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
context_layer = context_layer.view(*output_size)
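Continuing that sketch, the probabilities are applied to the values with a batched matmul and the result is folded back to [sq, b, hp], roughly:

import torch

b, np_, sq, sk, hn = 2, 16, 10, 12, 64
attention_probs = torch.softmax(torch.randn(b * np_, sq, sk), dim=-1)
value_layer = torch.randn(sk, b * np_, hn)

context = torch.bmm(attention_probs, value_layer.transpose(0, 1))  # [b * np, sq, hn]
context = context.view(b, np_, sq, hn).permute(2, 0, 1, 3).contiguous()
context = context.view(sq, b, np_ * hn)  # [sq, b, hp]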
...
@@ -348,7 +397,6 @@ class ParallelSelfAttention(MegatronModule):
...
@@ -348,7 +397,6 @@ class ParallelSelfAttention(MegatronModule):
(self.hidden_size_per_partition,)
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# =================
# Output. [sq, b, h]
# Output. [sq, b, h]
# =================
# =================
...
@@ -389,16 +437,19 @@ def bias_dropout_add_fused_inference(x, bias, residual, prob) :
...
@@ -389,16 +437,19 @@ def bias_dropout_add_fused_inference(x, bias, residual, prob) :
class ParallelTransformerLayer(MegatronModule):
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
"""A single transformer layer.
Transformore layer takes input with size [b, s, h] and returns an
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
output of the same size.
"""
"""
def __init__(self, attention_mask_func, init_method,
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
output_layer_init_method, layer_number,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
args = get_args()
args = get_args()
super(ParallelTransformerLayer, self).__init__()
super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
= args.apply_residual_connection_post_layernorm
...
@@ -410,45 +461,62 @@ class ParallelTransformerLayer(MegatronModule):
...
@@ -410,45 +461,62 @@ class ParallelTransformerLayer(MegatronModule):
eps=args.layernorm_epsilon)
eps=args.layernorm_epsilon)
# Self attention.
# Self attention.
self.attention = ParallelSelfAttention(attention_mask_func, init_method,
output_layer_init_method,
layer_number)
self.self_attention = ParallelAttention(
attention_mask_func,
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
self.bias_dropout_fusion = args.bias_dropout_fusion
# Layernorm on the input data.
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
args.hidden_size,
eps=args.layernorm_epsilon)
eps=args.layernorm_epsilon)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
attention_mask_func,
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
# MLP
# MLP
self.mlp = ParallelMLP(init_method,
self.mlp = ParallelMLP(init_method,
output_layer_init_method)
output_layer_init_method)
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
layer_past=None, get_key_value=False):
# hidden_states: [b, s, h]
# hidden_states: [b, s, h]
# Layer norm at the begining of the transformer layer.
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
# Self attention.
attention_output, attention_bias = \
attention_output, attention_bias = \
self.attention(layernorm_output,
self.self_attention(layernorm_output,
attention_mask,
attention_mask,
layer_past=layer_past,
layer_past=layer_past,
get_key_value=get_key_value)
get_key_value=get_key_value)
if get_key_value:
if get_key_value:
attention_output, presents = attention_output
attention_output, presents = attention_output
# Residual connection.
# Residual connection.
if self.apply_residual_connection_post_layernorm:
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
residual = layernorm_output
else:
else:
residual = hidden_states
residual = hidden_states
# jit scripting for a nn.module (with dropout) is not
# jit scripting for a nn.module (with dropout) is not
# trigerring the fusion kernel. For now, we use two
# trigerring the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.bias_dropout_fusion:
...
@@ -459,7 +527,7 @@ class ParallelTransformerLayer(MegatronModule):
...
@@ -459,7 +527,7 @@ class ParallelTransformerLayer(MegatronModule):
else:
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
bias_dropout_add_func = get_bias_dropout_add(self.training)
#re-enable torch grad to enable fused optimization.
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
layernorm_input = bias_dropout_add_func(
attention_output,
attention_output,
...
@@ -470,16 +538,38 @@ class ParallelTransformerLayer(MegatronModule):
...
@@ -470,16 +538,38 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the self attention.
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = \
self.inter_attention(layernorm_output,
enc_dec_attn_mask,
encoder_output=encoder_output)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
residual = layernorm_output
else:
else:
residual = layernorm_input
residual = layernorm_input
#re-enable torch grad to enable fused optimization.
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
with torch.enable_grad():
output = bias_dropout_add_func(
output = bias_dropout_add_func(
mlp_output,
mlp_output,
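A decoder layer therefore runs self-attention, cross-attention over the encoder output, and the MLP, each wrapped in the same residual-plus-layernorm pattern. A compact single-GPU sketch of that ordering, with stock nn.MultiheadAttention standing in for ParallelAttention and attention masks omitted (both simplifications of the sketch):

import torch
from torch import nn

class TinyDecoderLayer(nn.Module):
    def __init__(self, hidden_size=256, num_heads=8, ffn_hidden_size=1024):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.self_attention = nn.MultiheadAttention(hidden_size, num_heads)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.inter_attention = nn.MultiheadAttention(hidden_size, num_heads)
        self.post_inter_attention_layernorm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(nn.Linear(hidden_size, ffn_hidden_size),
                                 nn.GELU(),
                                 nn.Linear(ffn_hidden_size, hidden_size))

    def forward(self, hidden_states, encoder_output):
        # Self attention plus residual.
        x = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + self.self_attention(x, x, x)[0]
        # Cross attention over the encoder output plus residual.
        x = self.post_attention_layernorm(hidden_states)
        hidden_states = hidden_states + self.inter_attention(
            x, encoder_output, encoder_output)[0]
        # MLP plus residual.
        x = self.post_inter_attention_layernorm(hidden_states)
        return hidden_states + self.mlp(x)

layer = TinyDecoderLayer()
out = layer(torch.randn(10, 2, 256), torch.randn(20, 2, 256))  # [sq, b, h], [sk, b, h]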
...
@@ -497,7 +587,9 @@ class ParallelTransformer(MegatronModule):
...
@@ -497,7 +587,9 @@ class ParallelTransformer(MegatronModule):
"""Transformer class."""
"""Transformer class."""
def __init__(self, attention_mask_func,
def __init__(self, attention_mask_func,
init_method, output_layer_init_method):
init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
super(ParallelTransformer, self).__init__()
super(ParallelTransformer, self).__init__()
args = get_args()
args = get_args()
...
@@ -516,7 +608,9 @@ class ParallelTransformer(MegatronModule):
...
@@ -516,7 +608,9 @@ class ParallelTransformer(MegatronModule):
def build_layer(layer_number):
def build_layer(layer_number):
return ParallelTransformerLayer(
return ParallelTransformerLayer(
attention_mask_func, init_method,
attention_mask_func, init_method,
output_layer_init_method, layer_number)
output_layer_init_method, layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type)
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList(
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
...
@@ -531,14 +625,18 @@ class ParallelTransformer(MegatronModule):
...
@@ -531,14 +625,18 @@ class ParallelTransformer(MegatronModule):
def _get_layer(self, layer_number):
def _get_layer(self, layer_number):
return self.layers[layer_number]
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask):
def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing."""
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom(start, end):
def custom_forward(*inputs):
def custom_forward(*inputs):
x_ = inputs[0]
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
for index in range(start, end):
layer = self._get_layer(index)
layer = self._get_layer(index)
x_ = layer(x_, inputs[1])
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return x_
return custom_forward
return custom_forward
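Activation checkpointing now has to thread the encoder output and its mask through every checkpointed chunk, which is why custom_forward unpacks four inputs. A minimal sketch of the same chunked pattern using torch.utils.checkpoint in place of mpu.checkpoint (that substitution is an assumption of the sketch):

import torch
from torch.utils.checkpoint import checkpoint

layers = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(8)])
checkpoint_num_layers = 2

def custom(start, end):
    def custom_forward(x_):
        for index in range(start, end):
            x_ = layers[index](x_)
        return x_
    return custom_forward

hidden_states = torch.randn(4, 16, requires_grad=True)
l = 0
while l < len(layers):
    # Only chunk boundaries are stored; each chunk is recomputed in backward.
    hidden_states = checkpoint(custom(l, l + checkpoint_num_layers), hidden_states)
    l += checkpoint_num_layers
hidden_states.sum().backward()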
...
@@ -548,13 +646,13 @@ class ParallelTransformer(MegatronModule):
...
@@ -548,13 +646,13 @@ class ParallelTransformer(MegatronModule):
while l < self.num_layers:
while l < self.num_layers:
hidden_states = mpu.checkpoint(
hidden_states = mpu.checkpoint(
custom(l, l + self.checkpoint_num_layers),
custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask)
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.checkpoint_num_layers
l += self.checkpoint_num_layers
return hidden_states
return hidden_states
def forward(self, hidden_states, attention_mask, layer_past=None,
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
# Checks.
# Checks.
if layer_past is not None:
if layer_past is not None:
...
@@ -577,7 +675,9 @@ class ParallelTransformer(MegatronModule):
...
@@ -577,7 +675,9 @@ class ParallelTransformer(MegatronModule):
if self.checkpoint_activations:
if self.checkpoint_activations:
hidden_states = self._checkpointed_forward(hidden_states,
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask)
attention_mask, encoder_output, enc_dec_attn_mask)
else:
else:
if get_key_value:
if get_key_value:
presents = []
presents = []
...
@@ -588,12 +688,14 @@ class ParallelTransformer(MegatronModule):
...
@@ -588,12 +688,14 @@ class ParallelTransformer(MegatronModule):
past = layer_past[index]
past = layer_past[index]
hidden_states = layer(hidden_states,
hidden_states = layer(hidden_states,
attention_mask,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
layer_past=past,
layer_past=past,
get_key_value=get_key_value)
get_key_value=get_key_value)
if get_key_value:
if get_key_value:
hidden_states, present = hidden_states
hidden_states, present = hidden_states
presents.append(present)
presents.append(present)
# Final layer norm.
# Final layer norm.
if mpu.is_pipeline_last_stage():
if mpu.is_pipeline_last_stage():
# Reverting data format change [s b h] --> [b s h].
# Reverting data format change [s b h] --> [b s h].
...
...
pretrain_gpt2.py
View file @
4b506832
...
@@ -141,5 +141,4 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
...
@@ -141,5 +141,4 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
'scaled_upper_triang_masked_softmax_fusion': True})
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})