OpenDAS / Megatron-LM · Commits

Commit 4b506832, authored Jan 08, 2021 by Vijay Korthikanti
Parent: f5eac3d1

    decoder support in transformers

Showing 9 changed files with 520 additions and 261 deletions (+520 −261)
megatron/arguments.py                                  +31  −17
megatron/fused_kernels/scaled_masked_softmax.h         +43  −47
megatron/fused_kernels/scaled_masked_softmax_cuda.cu   +13  −11
megatron/model/enums.py                                +28  −0
megatron/model/fused_softmax.py                        +31  −21
megatron/model/gpt2_model.py                           +2   −0
megatron/model/language_model.py                       +172 −66
megatron/model/transformer.py                          +199 −97
pretrain_gpt2.py                                       +1   −2
megatron/arguments.py

@@ -164,6 +164,20 @@ def parse_args(extra_args_provider=None, defaults={},
         _check_arg_is_not_none(args, req_arg)
 
     # Checks.
+    if args.ffn_hidden_size is None:
+        args.ffn_hidden_size = 4 * args.hidden_size
+
+    if args.kv_channels is None:
+        assert args.hidden_size % args.num_attention_heads == 0
+        args.kv_channels = args.hidden_size // args.num_attention_heads
+
+    if args.seq_length is not None:
+        assert args.encoder_seq_length is None
+        args.encoder_seq_length = args.seq_length
+    else:
+        assert args.encoder_seq_length is not None
+        args.seq_length = args.encoder_seq_length
+
     assert args.hidden_size % args.num_attention_heads == 0
     if args.seq_length is not None:
         assert args.max_position_embeddings >= args.seq_length

@@ -182,16 +196,11 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.checkpoint_activations, \
             'for distribute-checkpointed-activations to work you '\
             'need to enable checkpoint-activations'
 
-    if args.scaled_masked_softmax_fusion:
-        if args.scaled_upper_triang_masked_softmax_fusion:
-            fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
-        else:
-            fused_kernels.load_scaled_masked_softmax_fusion_kernel()
-    else:
-        # This argument will eventually go away, for now make sure it is off
-        # if scaled_masked_softmax_fusion is off.
-        args.scaled_upper_triang_masked_softmax_fusion = False
+    # Load scaled_masked_softmax_fusion_kernels
+    if args.masked_softmax_fusion:
+        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
+        fused_kernels.load_scaled_masked_softmax_fusion_kernel()
 
     # Load mixed precision fused layer norm.
     if args.fp32_residual_connection:

@@ -227,8 +236,14 @@ def _add_network_size_args(parser):
                        help='Number of transformer layers.')
     group.add_argument('--hidden-size', type=int, default=None,
                        help='Tansformer hidden size.')
+    group.add_argument('--ffn-hidden-size', type=int, default=None,
+                       help='Transformer Feed-Forward Network hidden size. '
+                       'This is set to 4*hidden-size if not provided')
     group.add_argument('--num-attention-heads', type=int, default=None,
                        help='Number of transformer attention heads.')
+    group.add_argument('--kv-channels', type=int, default=None,
+                       help='Projection weights dimension in multi-head attention. '
+                       'This is set to args.hidden_size // args.num_attention_heads if not provided.')
     group.add_argument('--max-position-embeddings', type=int, default=None,
                        help='Maximum number of position embeddings to use. '
                        'This is the size of position embedding.')

@@ -330,16 +345,11 @@ def _add_training_args(parser):
                        help='Exit the program after this many minutes.')
     group.add_argument('--tensorboard-dir', type=str, default=None,
                        help='Write TensorBoard logs to this directory.')
-    group.add_argument('--no-scaled-masked-softmax-fusion',
+    group.add_argument('--no-masked-softmax-fusion',
                        action='store_false',
                        help='Disable fusion of query_key_value scaling, '
                        'masking, and softmax.',
-                       dest='scaled_masked_softmax_fusion')
-    group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
-                       type=bool,
-                       help='Use upper triangular version of fused '
-                       'scale, mask, softmax fusion kernel (default for GPT). '
-                       '- DEPRECATED')
+                       dest='masked_softmax_fusion')
     group.add_argument('--no-bias-gelu-fusion', action='store_false',
                        help='Disable bias and gelu fusion.',
                        dest='bias_gelu_fusion')

@@ -530,6 +540,10 @@ def _add_data_args(parser):
                        help='Path to the BPE merge file.')
     group.add_argument('--seq-length', type=int, default=None,
                        help="Maximum sequence length to process.")
+    group.add_argument('--encoder-seq-length', type=int, default=None,
+                       help="Maximum encoder sequence length to process.")
+    group.add_argument('--decoder-seq-length', type=int, default=None,
+                       help="Maximum decoder sequence length to process.")
     group.add_argument('--mask-prob', type=float, default=0.15,
                        help='Probability of replacing a token with mask.')
     group.add_argument('--short-seq-prob', type=float, default=0.1,
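The new checks above derive ffn_hidden_size, kv_channels and the encoder sequence length from whatever the user did supply. Below is a minimal, standalone sketch of that derivation; the argparse.Namespace is a stand-in for Megatron's parsed args, not the real parser.

import argparse

def derive_defaults(args):
    # Mirror of the checks added to parse_args(): fill in derived sizes
    # when the user did not pass them explicitly.
    if args.ffn_hidden_size is None:
        args.ffn_hidden_size = 4 * args.hidden_size
    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
        args.kv_channels = args.hidden_size // args.num_attention_heads
    # --seq-length and --encoder-seq-length are aliases; exactly one is given.
    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length
    return args

args = argparse.Namespace(hidden_size=1024, num_attention_heads=16,
                          ffn_hidden_size=None, kv_channels=None,
                          seq_length=1024, encoder_seq_length=None)
print(derive_defaults(args))   # ffn_hidden_size=4096, kv_channels=64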
megatron/fused_kernels/scaled_masked_softmax.h

@@ -81,7 +81,6 @@ __global__ void scaled_masked_softmax_warp_forward(
     const uint8_t *mask,
     const acc_t scale,
     int micro_batch_size,
-    int stride,
     int element_count,
     int pad_batches)
 {

@@ -111,9 +110,9 @@ __global__ void scaled_masked_softmax_warp_forward(
     // there might be multiple batches per warp. compute the index within the batch
     int local_idx = threadIdx.x;
 
-    src += first_batch * stride + local_idx;
-    dst += first_batch * stride + local_idx;
-    mask += pad_first_batch * stride + local_idx;
+    src += first_batch * element_count + local_idx;
+    dst += first_batch * element_count + local_idx;
+    mask += pad_first_batch * element_count + local_idx;
 
     // load data from global memory
     acc_t elements[WARP_BATCH][WARP_ITERATIONS];

@@ -185,7 +184,6 @@ __global__ void scaled_masked_softmax_warp_backward(
     const input_t *output,
     acc_t scale,
     int micro_batch_size,
-    int stride,
     int element_count)
 {
     // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and

@@ -209,7 +207,7 @@ __global__ void scaled_masked_softmax_warp_backward(
     int local_idx = threadIdx.x;
 
     // the first element to process by the current thread
-    int thread_offset = first_batch * stride + local_idx;
+    int thread_offset = first_batch * element_count + local_idx;
     grad += thread_offset;
     output += thread_offset;
     gradInput += thread_offset;

@@ -277,20 +275,19 @@ void dispatch_scaled_masked_softmax_forward(
     const input_t *src,
     const uint8_t *mask,
     const input_t scale,
-    int softmax_elements,
-    int softmax_elements_stride,
+    int query_seq_len,
+    int key_seq_len,
     int batches,
     int attn_heads,
     int pad_batches)
 {
-    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
-    if (softmax_elements == 0) {
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
+    if (key_seq_len == 0) {
         return;
     } else {
-        int log2_elements = log2_ceil(softmax_elements);
+        int log2_elements = log2_ceil(key_seq_len);
         const int next_power_of_two = 1 << log2_elements;
-        int seq_len = softmax_elements;
-        int batch_count = batches * attn_heads * seq_len;
+        int batch_count = batches * attn_heads * query_seq_len;
 
         // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
         int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

@@ -302,59 +299,59 @@ void dispatch_scaled_masked_softmax_forward(
         constexpr int threads_per_block = 128;
         int warps_per_block = (threads_per_block / warp_size);
-        int batches_per_block = warps_per_block * batches_per_warp;
-        TORCH_INTERNAL_ASSERT(seq_len % batches_per_block == 0);
-        dim3 blocks(seq_len/batches_per_block, attn_heads, batches);
+        int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
+        dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
         dim3 threads(warp_size, warps_per_block, 1);
         // Launch code would be more elegant if C++ supported FOR CONSTEXPR
         switch (log2_elements) {
             case 0: // 1
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 1: // 2
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 2: // 4
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 3: // 8
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 4: // 16
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 5: // 32
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 6: // 64
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 7: // 128
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 8: // 256
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 9: // 512
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 10: // 1024
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 11: // 2048
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             default:
                 break;

@@ -368,19 +365,18 @@ void dispatch_scaled_masked_softmax_backward(
     input_t *grad,
     const input_t *output,
     const acc_t scale,
-    int softmax_elements,
-    int softmax_elements_stride,
+    int query_seq_len,
+    int key_seq_len,
     int batches,
     int attn_heads)
 {
-    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
-    if (softmax_elements == 0) {
+    TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
+    if (key_seq_len == 0) {
         return;
     } else {
-        int log2_elements = log2_ceil(softmax_elements);
+        int log2_elements = log2_ceil(key_seq_len);
         const int next_power_of_two = 1 << log2_elements;
-        int seq_len = softmax_elements;
-        int batch_count = batches * attn_heads * seq_len;
+        int batch_count = batches * attn_heads * query_seq_len;
 
         // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
         int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

@@ -399,51 +395,51 @@ void dispatch_scaled_masked_softmax_backward(
         switch (log2_elements) {
             case 0: // 1
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 1: // 2
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 2: // 4
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 3: // 8
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 4: // 16
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 5: // 32
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 6: // 64
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 7: // 128
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 8: // 256
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 9: // 512
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 10: // 1024
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 11: // 2048
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             default:
                 break;
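The header change above replaces the single softmax_elements/stride pair with separate query and key sequence lengths, so the fused kernel can be dispatched for the rectangular [b, np, sq, sk] score tensors that cross attention produces. For reference, the computation being fused can be written unfused in a few lines of PyTorch; this sketch only checks shapes, and the -10000.0 fill value and uint8 "suppress" convention for the mask are assumptions, not read from this header.

import torch

def reference_scaled_masked_softmax(scores, mask, scale):
    # scores: [b, np, sq, sk] attention logits; mask: [b or 1, 1, sq, sk]
    # with nonzero entries marking positions to suppress.
    scores = scores * scale
    scores = scores.masked_fill(mask.bool(), -10000.0)
    return torch.softmax(scores, dim=-1)

b, np_, sq, sk = 2, 4, 8, 16                    # sq != sk is now allowed
scores = torch.randn(b, np_, sq, sk)
mask = torch.zeros(b, 1, sq, sk, dtype=torch.uint8)
probs = reference_scaled_masked_softmax(scores, mask, 1.0)
print(probs.shape)                              # torch.Size([2, 4, 8, 16])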
megatron/fused_kernels/scaled_masked_softmax_cuda.cu

@@ -37,17 +37,19 @@ torch::Tensor fwd_cuda(
   const int batches = input.size(0);
   const int pad_batches = mask.size(0);
   const int attn_heads = input.size(1);
-  const int seq_len = input.size(2);
-  TORCH_INTERNAL_ASSERT(seq_len <= 2048);
+  const int query_seq_len = input.size(2);
+  const int key_seq_len = input.size(3);
+  TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
+  TORCH_INTERNAL_ASSERT(query_seq_len > 1);
   TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
   TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
-  TORCH_INTERNAL_ASSERT(mask.size(2) == seq_len);
-  TORCH_INTERNAL_ASSERT(mask.size(3) == seq_len);
+  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
+  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
 
   // Output
   auto act_options = input.options().requires_grad(false);
   torch::Tensor softmax_results =
-      torch::empty({batches, attn_heads, seq_len, seq_len}, act_options);
+      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
 
   // Softmax Intermediate Result Ptr
   void* input_ptr = static_cast<void*>(input.data_ptr());

@@ -59,8 +61,8 @@ torch::Tensor fwd_cuda(
       reinterpret_cast<const half*>(input_ptr),
       reinterpret_cast<const uint8_t*>(mask_ptr),
       scale_factor,
-      seq_len,
-      seq_len,
+      query_seq_len,
+      key_seq_len,
       batches,
       attn_heads,
       pad_batches);

@@ -78,8 +80,8 @@ torch::Tensor bwd_cuda(
   //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
   const int batches = output_grads.size(0);
   const int attn_heads = output_grads.size(1);
-  const int seq_len = output_grads.size(2);
-  TORCH_INTERNAL_ASSERT(output_grads.size(2) == output_grads.size(3));
+  const int query_seq_len = output_grads.size(2);
+  const int key_seq_len = output_grads.size(3);
 
   void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

@@ -89,8 +91,8 @@ torch::Tensor bwd_cuda(
       reinterpret_cast<half*>(output_grads_ptr),
      reinterpret_cast<half const*>(softmax_results.data_ptr()),
       scale_factor,
-      seq_len,
-      seq_len,
+      query_seq_len,
+      key_seq_len,
       batches,
       attn_heads);
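A pure-Python restatement of the input checks fwd_cuda() now performs, as documentation of the new contract (scores of shape [b, np, sq, sk], mask of shape [pad_batches, 1, sq, sk]); this is not the real binding, just a sketch that runs standalone.

import torch

def check_fwd_shapes(scores, mask):
    # Mirrors the assertions added to fwd_cuda() in this commit.
    batches, attn_heads, query_seq_len, key_seq_len = scores.shape
    assert key_seq_len <= 2048
    assert query_seq_len > 1
    assert mask.size(0) in (1, batches)      # pad_batches
    assert mask.size(1) == 1
    assert mask.size(2) == query_seq_len
    assert mask.size(3) == key_seq_len
    return batches, attn_heads, query_seq_len, key_seq_len

print(check_fwd_shapes(torch.randn(2, 4, 8, 16),
                       torch.zeros(2, 1, 8, 16, dtype=torch.uint8)))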
megatron/model/enums.py  (new file, 0 → 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum


class LayerType(enum.Enum):
    encoder = 1
    decoder = 2


class AttnType(enum.Enum):
    self_attn = 1
    cross_attn = 2


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
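A small sketch of how these enums are consumed elsewhere in this commit: decoder self-attention gets a causal mask, while everything else defaults to padding. The Enum classes below are stand-ins with the same members so the snippet runs on its own.

from enum import Enum

class LayerType(Enum):       # same members as megatron/model/enums.py
    encoder = 1
    decoder = 2

class AttnMaskType(Enum):
    padding = 1
    causal = 2

def pick_mask(layer_type):
    # Convention followed by the rest of the commit.
    return AttnMaskType.causal if layer_type == LayerType.decoder \
        else AttnMaskType.padding

print(pick_mask(LayerType.decoder))   # AttnMaskType.causal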
megatron/model/fused_softmax.py

@@ -14,11 +14,13 @@
 # limitations under the License.
 
 import torch
 
+from megatron.model.enums import AttnMaskType
+
 
-class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
+class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     """
     Fused operation which performs following three operations in sequence
-    1. Scale the tensor. 
+    1. Scale the tensor.
     2. Apply upper triangular mask (typically used in gpt models).
     3. Perform softmax.
     """

@@ -38,15 +40,16 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
         softmax_results, scale_t = ctx.saved_tensors
 
         input_grads = \
-            scaled_upper_triang_masked_softmax_cuda.backward(output_grads, 
-                                                             softmax_results, 
-                                                             scale_t[0])
+            scaled_upper_triang_masked_softmax_cuda.backward(output_grads,
+                                                             softmax_results,
+                                                             scale_t[0])
 
         return input_grads, None
 
 
-class ScaledMaskedSoftmax(torch.autograd.Function) :
+class ScaledMaskedSoftmax(torch.autograd.Function):
     """
     Fused operation which performs following three operations in sequence
-    1. Scale the tensor. 
+    1. Scale the tensor.
     2. Apply the mask.
     3. Perform softmax.
     """

@@ -71,24 +74,25 @@ class ScaledMaskedSoftmax(torch.autograd.Function) :
                                                   scale_t[0])
         return input_grads, None, None
 
 
 class FusedScaleMaskSoftmax(torch.nn.Module):
     """
     fused operation: scaling + mask + softmax
     Arguments:
         input_in_fp16: flag to indicate if input in fp16 data format.
-        upper_triang_mask: if true, apply upper triangular masking.
-                           (used in gpt family networks)
+        attn_mask_type: attention mask type (pad or causal)
         mask_func: mask function to be applied.
         softmax_in_fp32: if true, softmax in performed at fp32 precision.
         scale: scaling factor used in input tensor scaling.
     """
-    def __init__(self, input_in_fp16, upper_triang_mask_fusion,
-                 general_mask_fusion, mask_func, softmax_in_fp32, scale):
+    def __init__(self, input_in_fp16, attn_mask_type,
+                 scaled_masked_softmax_fusion,
+                 mask_func, softmax_in_fp32, scale):
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
-        self.upper_triang_mask_fusion = upper_triang_mask_fusion
-        self.general_mask_fusion = general_mask_fusion
+        self.attn_mask_type = attn_mask_type
+        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
         self.mask_func = mask_func
         self.softmax_in_fp32 = softmax_in_fp32
         self.scale = scale

@@ -97,20 +101,26 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
             'softmax should be in fp32 when scaled'
 
     def forward(self, input, mask):
-        # [b, np, s, s]
+        # [b, np, sq, sk]
         data_size = input.size()
-        assert input.dim() == 4
+        query_seq_len = data_size[-2]
+        key_seq_len = data_size[-1]
+        assert input.dim() == 4
 
         # invoke custom kernel
-        if self.input_in_fp16 and data_size[-1] <= 2048 and \
-           (self.upper_triang_mask_fusion or self.general_mask_fusion) and \
-           input.size()[2] == input.size()[3]:
-            scale = self.scale if self.scale is not None else 1.0
-            if self.upper_triang_mask_fusion:
-                input = input.view(-1, data_size[2], data_size[3])
+        if self.input_in_fp16 and key_seq_len <= 2048 and \
+           query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
+            scale = self.scale if self.scale is not None else 1.0
+            if self.attn_mask_type == AttnMaskType.causal:
+                assert query_seq_len == key_seq_len, \
+                    "causal mask is only for self attention"
+                input = input.view(-1, query_seq_len, key_seq_len)
                 probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
                 probs = probs.view(*data_size)
             else:
+                assert self.attn_mask_type == AttnMaskType.padding
                 probs = ScaledMaskedSoftmax.apply(input, mask, scale)
         else:
             if self.input_in_fp16 and self.softmax_in_fp32:
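The dispatch inside the rewritten FusedScaleMaskSoftmax.forward() can be summarized as follows. This is a pure-Python paraphrase of the new branching; the string return values are illustrative labels, not part of the module.

from enum import Enum

class AttnMaskType(Enum):      # same members as megatron/model/enums.py
    padding = 1
    causal = 2

def softmax_path(input_in_fp16, scaled_masked_softmax_fusion,
                 attn_mask_type, query_seq_len, key_seq_len):
    # Re-statement of the routing in FusedScaleMaskSoftmax.forward()
    # after this commit.
    if input_in_fp16 and key_seq_len <= 2048 and \
       query_seq_len % 4 == 0 and scaled_masked_softmax_fusion:
        if attn_mask_type == AttnMaskType.causal:
            assert query_seq_len == key_seq_len, \
                "causal mask is only for self attention"
            return "fused upper-triangular kernel"
        return "fused general masked kernel"
    return "unfused torch softmax"

print(softmax_path(True, True, AttnMaskType.causal, 1024, 1024))
print(softmax_path(True, True, AttnMaskType.padding, 512, 1024))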
megatron/model/gpt2_model.py

@@ -21,6 +21,7 @@ from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
+from .enums import AttnMaskType
 from .language_model import parallel_lm_logits
 from .language_model import get_language_model
 from .utils import init_method_normal

@@ -75,6 +76,7 @@ class GPT2ModelBase(MegatronModule):
             attention_mask_func=gpt2_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
+            self_attn_mask_type=AttnMaskType.causal,
             init_method=init_method_normal(args.init_method_std),
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
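GPT-2 now requests causal masking through this keyword rather than through the old global scaled_upper_triang_masked_softmax_fusion default. The stub below copies the updated get_language_model() keyword interface from language_model.py (next file) so the new call pattern can be shown without importing megatron; the body is obviously not the real builder.

from enum import Enum

class AttnMaskType(Enum):      # mirrors megatron/model/enums.py
    padding = 1
    causal = 2

def get_language_model_stub(attention_mask_func, num_tokentypes, add_pooler,
                            add_decoder=False,
                            init_method=None, scaled_init_method=None,
                            self_attn_mask_type=AttnMaskType.padding):
    # Keyword interface copied from the updated get_language_model();
    # body stubbed so the GPT-2 call site can be exercised standalone.
    return {'self_attn_mask_type': self_attn_mask_type,
            'add_decoder': add_decoder, 'add_pooler': add_pooler}

# GPT2ModelBase now passes the causal mask type explicitly:
print(get_language_model_stub(attention_mask_func=None, num_tokentypes=0,
                              add_pooler=False,
                              self_attn_mask_type=AttnMaskType.causal))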
megatron/model/language_model.py

@@ -21,6 +21,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
+from megatron.model.enums import LayerType, AttnMaskType
 from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal, scaled_init_method_normal

@@ -43,7 +44,9 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
 def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
-                       init_method=None, scaled_init_method=None):
+                       add_decoder=False,
+                       init_method=None, scaled_init_method=None,
+                       self_attn_mask_type=AttnMaskType.padding):
     """Build language model and return along with the key to save."""
     args = get_args()

@@ -51,7 +54,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
         init_method = init_method_normal(args.init_method_std)
 
     if scaled_init_method is None:
-        scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
+        scaled_init_method = scaled_init_method_normal(args.init_method_std,
+                                                       args.num_layers)
 
     # Language model.
     args = [attention_mask_func, init_method, scaled_init_method]

@@ -60,6 +64,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
         cls = TransformerLanguageModel
         kwargs['num_tokentypes'] = num_tokentypes
+        kwargs['self_attn_mask_type'] = self_attn_mask_type
+        kwargs['add_decoder'] = add_decoder
         kwargs['add_pooler'] = add_pooler
     elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
         cls = TransformerLanguageModelFirstStage

@@ -186,8 +192,6 @@ class Embedding(MegatronModule):
         if tokentype_ids is not None:
             assert self.tokentype_embeddings is not None
             embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
-        else:
-            assert self.tokentype_embeddings is None
 
         # Dropout.
         embeddings = self.embedding_dropout(embeddings)

@@ -281,6 +285,8 @@ class TransformerLanguageModelBase(MegatronModule):
                  init_method,
                  output_layer_init_method,
                  num_tokentypes=0,
+                 self_attn_mask_type=AttnMaskType.padding,
+                 add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModelBase, self).__init__()
         args = get_args()

@@ -288,6 +294,8 @@ class TransformerLanguageModelBase(MegatronModule):
         self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
         self.init_method = init_method
+        self.self_attn_mask_type = self_attn_mask_type
+        self.add_decoder = add_decoder
         self.add_pooler = add_pooler
 
         # Embeddings.

@@ -301,41 +309,87 @@ class TransformerLanguageModelBase(MegatronModule):
             self._embedding_key = 'embedding'
 
         # Transformer.
-        self.transformer = ParallelTransformer(
-            attention_mask_func, self.init_method, output_layer_init_method)
-        self._transformer_key = 'transformer'
+        self.encoder = ParallelTransformer(
+            attention_mask_func, self.init_method, output_layer_init_method,
+            self_attn_mask_type=self_attn_mask_type)
+        self._encoder_key = 'encoder'
 
-        # Pooler.
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            self.pooler = Pooler(self.hidden_size, self.init_method)
-            self._pooler_key = 'pooler'
+        # assuming pooler and decoder are in the last stage
+        # of the pipeline(to be revised)
+        if mpu.is_pipeline_last_stage():
+            # decoder
+            if self.add_decoder:
+                self.decoder = ParallelTransformer(
+                    attention_mask_func,
+                    self.init_method,
+                    output_layer_init_method,
+                    layer_type=LayerType.decoder,
+                    self_attn_mask_type=AttnMaskType.causal)
+                self._decoder_key = 'decoder'
 
-    def forward(self, language_model_input, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
+            # Pooler.
+            if self.add_pooler:
+                self.pooler = Pooler(self.hidden_size, self.init_method)
+                self._pooler_key = 'pooler'
+
+    def forward(self, enc_language_model_input, enc_attention_mask,
+                dec_language_model_input=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
+                get_key_value=False, pooling_sequence_index=0,
+                enc_hidden_states=None, output_enc_hidden=False):
 
         # Embeddings.
         if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = language_model_input
+            (input_ids, position_ids) = enc_language_model_input
             embedding_output = self.embedding(input_ids, position_ids,
                                               tokentype_ids=tokentype_ids)
-            transformer_input = embedding_output
+            encoder_input = embedding_output
         else:
-            transformer_input = language_model_input
-
-        # Transformer.
-        transformer_output = self.transformer(transformer_input,
-                                              attention_mask,
-                                              layer_past=layer_past,
-                                              get_key_value=get_key_value)
-
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            pooled_output = self.pooler(transformer_output,
-                                        pooling_sequence_index)
-            return transformer_output, pooled_output
+            encoder_input = enc_language_model_input
+
+        # encoder.
+        if enc_hidden_states is None:
+            encoder_output = self.encoder(encoder_input,
+                                          enc_attention_mask,
+                                          layer_past=layer_past,
+                                          get_key_value=get_key_value)
+        else:
+            encoder_output = enc_hidden_states.to(encoder_input.dtype)
+
+        if mpu.is_pipeline_last_stage():
+            if self.add_pooler:
+                pooled_output = self.pooler(encoder_output,
+                                            pooling_sequence_index)
+
+            # output_enc_hidden refers to when we just need the encoder's
+            # output. For example, it is helpful to compute
+            # similarity between two sequences by average pooling
+            if not self.add_decoder or output_enc_hidden:
+                if self.add_pooler:
+                    return encoder_output, pooled_output
+                else:
+                    return encoder_output
+
+            # Decoder Embedding
+            (dec_input_ids, dec_position_ids) = dec_language_model_input
+            dec_embedding_output = self.embedding(dec_input_ids,
+                                                  dec_position_ids)
+            # decoder
+            decoder_output = self.decoder(dec_embedding_output,
+                                          dec_attn_mask,
+                                          layer_past=layer_past,
+                                          get_key_value=get_key_value,
+                                          encoder_output=encoder_output,
+                                          enc_dec_attn_mask=enc_dec_attn_mask)
+
+            if self.add_pooler:
+                return decoder_output, encoder_output, pooled_output
+            else:
+                return decoder_output, encoder_output
 
-        return transformer_output
+        return encoder_output
 
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):

@@ -346,13 +400,18 @@ class TransformerLanguageModelBase(MegatronModule):
             state_dict_[self._embedding_key] \
                 = self.embedding.state_dict_for_save_checkpoint(
                     destination, prefix, keep_vars)
-        state_dict_[self._transformer_key] \
-            = self.transformer.state_dict_for_save_checkpoint(
+        state_dict_[self._encoder_key] \
+            = self.encoder.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            state_dict_[self._pooler_key] \
-                = self.pooler.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+        if mpu.is_pipeline_last_stage():
+            if self.add_pooler:
+                state_dict_[self._pooler_key] \
+                    = self.pooler.state_dict_for_save_checkpoint(
+                        destination, prefix, keep_vars)
+            if self.add_decoder:
+                state_dict_[self._decoder_key] \
+                    = self.decoder.state_dict_for_save_checkpoint(
+                        destination, prefix, keep_vars)
 
         return state_dict_

@@ -371,23 +430,44 @@ class TransformerLanguageModelBase(MegatronModule):
                     state_dict_[key] = state_dict[key]
         self.embedding.load_state_dict(state_dict_, strict=strict)
 
-        # Transformer.
-        if self._transformer_key in state_dict:
-            state_dict_ = state_dict[self._transformer_key]
+        # Encoder.
+        if self._encoder_key in state_dict:
+            state_dict_ = state_dict[self._encoder_key]
+        # for backward compatibility.
+        elif 'transformer' in state_dict:
+            state_dict_ = state_dict['transformer']
         else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
-               if 'transformer.' in key:
-                   state_dict_[key.split('transformer.')[1]] = state_dict[key]
-        self.transformer.load_state_dict(state_dict_, strict=strict)
-
-        # Pooler.
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            assert 'pooler' in state_dict, \
-                'could not find data for pooler in the checkpoint'
-            self.pooler.load_state_dict(state_dict[self._pooler_key],
-                                        strict=strict)
+               if 'encoder.' in key:
+                   state_dict_[key.split('encoder.')[1]] = state_dict[key]
+
+        # for backward compatibility.
+        state_dict_self_attention = {}
+        for key in state_dict_.keys():
+            if '.attention.' in key:
+                state_dict_self_attention[key.replace(".attention.",
+                    ".self_attention.")] = state_dict_[key]
+            else:
+                state_dict_self_attention[key] = state_dict_[key]
+        state_dict_ = state_dict_self_attention
+
+        self.encoder.load_state_dict(state_dict_, strict=strict)
+
+        if mpu.is_pipeline_last_stage():
+            # pooler
+            if self.add_pooler:
+                assert 'pooler' in state_dict, \
+                    'could not find data for pooler in the checkpoint'
+                self.pooler.load_state_dict(state_dict[self._pooler_key],
+                                            strict=strict)
+            # decoder
+            if self.add_decoder:
+                assert 'decoder' in state_dict, \
+                    'could not find data for pooler in the checkpoint'
+                self.decoder.load_state_dict(state_dict[self._decoder_key],
+                                             strict=strict)
 
 
 class TransformerLanguageModel(TransformerLanguageModelBase):

@@ -400,24 +480,35 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
                  init_method,
                  output_layer_init_method,
                  num_tokentypes=0,
+                 self_attn_mask_type=AttnMaskType.padding,
+                 add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModel, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
             num_tokentypes=num_tokentypes,
+            self_attn_mask_type=self_attn_mask_type,
+            add_decoder=add_decoder,
             add_pooler=add_pooler)
 
-    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
+    def forward(self, enc_input_ids, enc_position_ids, enc_attention_mask,
+                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
+                get_key_value=False, pooling_sequence_index=0,
+                enc_hidden_states=None, output_enc_hidden=False):
         return super(TransformerLanguageModel, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
+            (enc_input_ids, enc_position_ids),
+            enc_attention_mask,
+            dec_language_model_input=(dec_input_ids, dec_position_ids),
+            dec_attn_mask=dec_attn_mask,
+            enc_dec_attn_mask=enc_dec_attn_mask,
             tokentype_ids=tokentype_ids,
             layer_past=layer_past,
            get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index
+            pooling_sequence_index=pooling_sequence_index,
+            enc_hidden_states=enc_hidden_states,
+            output_enc_hidden=output_enc_hidden
        )

@@ -430,12 +521,14 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
-                 num_tokentypes=0):
+                 num_tokentypes=0,
+                 self_attn_mask_type=AttnMaskType.padding):
         super(TransformerLanguageModelFirstStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
-            num_tokentypes=num_tokentypes)
+            num_tokentypes=num_tokentypes,
+            self_attn_mask_type=self_attn_mask_type)
 
     def forward(self, input_ids, position_ids, attention_mask,
                 tokentype_ids=None, layer_past=None, get_key_value=False):

@@ -456,11 +549,13 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
     def __init__(self,
                  attention_mask_func,
                  init_method,
-                 output_layer_init_method):
+                 output_layer_init_method,
+                 self_attn_mask_type=AttnMaskType.padding):
         super(TransformerLanguageModelIntermediateStage, self).__init__(
             attention_mask_func,
             init_method,
-            output_layer_init_method)
+            output_layer_init_method,
+            self_attn_mask_type=self_attn_mask_type)
 
     def forward(self, hidden_states, attention_mask,
                 layer_past=None, get_key_value=False):

@@ -481,20 +576,31 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
+                 self_attn_mask_type=AttnMaskType.padding,
+                 add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModelLastStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
+            self_attn_mask_type=AttnMaskType.padding,
+            add_decoder=add_decoder,
             add_pooler=add_pooler)
 
-    def forward(self, hidden_states, attention_mask,
-                layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
+    def forward(self, hidden_states, enc_attention_mask,
+                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None, layer_past=None, get_key_value=False,
+                pooling_sequence_index=0, enc_hidden_states=None,
+                output_enc_hidden=False):
         return super(TransformerLanguageModelLastStage, self).forward(
             hidden_states,
-            attention_mask,
+            enc_attention_mask,
+            dec_language_input=(dec_input_ids, dec_position_ids),
+            dec_attn_mask=dec_attn_mask,
+            enc_dec_attn_mask=enc_dec_attn_mask,
             layer_past=layer_past,
             get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index
+            pooling_sequence_index=pooling_sequence_index,
+            enc_hidden_states=enc_hidden_states,
+            ouput_enc_hidden=output_enc_hidden
         )
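Depending on add_decoder, add_pooler and output_enc_hidden, the rewritten TransformerLanguageModelBase.forward() returns different tuples on the last pipeline stage. A compact paraphrase of that return contract (the strings are labels standing in for the tensors, and the sketch ignores the non-last-stage case):

def language_model_outputs(add_decoder, add_pooler, output_enc_hidden):
    # Which tensors forward() returns after this commit, last stage only.
    if not add_decoder or output_enc_hidden:
        return ("encoder_output", "pooled_output") if add_pooler \
            else "encoder_output"
    return ("decoder_output", "encoder_output", "pooled_output") if add_pooler \
        else ("decoder_output", "encoder_output")

print(language_model_outputs(add_decoder=False, add_pooler=False,
                             output_enc_hidden=False))   # GPT-style
print(language_model_outputs(add_decoder=True, add_pooler=False,
                             output_enc_hidden=False))   # encoder-decoder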
megatron/model/transformer.py

This diff is collapsed.
pretrain_gpt2.py

@@ -141,5 +141,4 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 if __name__ == "__main__":
 
     pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
-             args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
-                            'scaled_upper_triang_masked_softmax_fusion': True})
+             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})