OpenDAS / Megatron-LM — Commits

Commit cb00a196, authored Sep 29, 2021 by Jared Casper
Merge branch 'main' into t5_pipeline_parallelism
Parents: 38a774e9, 5ab64637

Changes: 44 files in the merge; 20 changed files shown below, with 1025 additions and 637 deletions (+1025, -637).
Changed files:

  megatron/arguments.py                                         +44   -14
  megatron/fused_kernels/scaled_masked_softmax.cpp              +21    -1
  megatron/fused_kernels/scaled_masked_softmax.h                +14    -1
  megatron/fused_kernels/scaled_masked_softmax_cuda.cu           +5    -0
  megatron/fused_kernels/scaled_upper_triang_masked_softmax.h    +2    -0
  megatron/fused_kernels/tests/__init__.py                       +0    -0
  megatron/fused_kernels/tests/test_fused_kernels.py           +300    -0
  megatron/initialize.py                                        +29   -14
  megatron/model/fused_bias_gelu.py                              +0    -4
  megatron/model/fused_softmax.py                               +73   -45
  megatron/model/gpt_model.py                                    +6   -15
  megatron/model/language_model.py                              +16   -12
  megatron/model/transformer.py                                +135   -87
  megatron/mpu/__init__.py                                       +0    -2
  megatron/mpu/layers.py                                        +51    -5
  megatron/mpu/random.py                                        +23   -46
  megatron/optimizer/__init__.py                                 +2    -2
  megatron/optimizer/optimizer.py                               +10   -10
  megatron/text_generation_server.py                            +96    -0
  megatron/text_generation_utils.py                            +198  -379
megatron/arguments.py

@@ -97,6 +97,13 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.model_parallel_size is None, '--model-parallel-size is no ' \
         'longer valid, use --tensor-model-parallel-size instead'
     del args.model_parallel_size
+    if args.checkpoint_activations:
+        args.activations_checkpoint_method = 'uniform'
+        if args.rank == 0:
+            print('--checkpoint-activations is no longer valid, '
+                  'use --activation-checkpoint-method instead. '
+                  'Defaulting to activation-checkpoint-method=uniform.')
+        del args.checkpoint_activations

     # Set input defaults.
     for key in defaults:

@@ -154,16 +161,15 @@ def parse_args(extra_args_provider=None, defaults={},
         print('using {} for parameters ...'.format(args.params_dtype),
               flush=True)

     # If we do accumulation and all-reduces in fp32, we need to have local DDP
-    # and we should set the use-contiguous-buffers-in-ddp.
+    # and we should make sure use-contiguous-buffers-in-local-ddp is not off.
     if args.accumulate_allreduce_grads_in_fp32:
         assert args.DDP_impl == 'local'
-        args.use_contiguous_buffers_in_ddp = True
+        assert args.use_contiguous_buffers_in_local_ddp

-    # If we use a contiguous buffer to hold main grads, we need to have
-    # local DDP.
-    if args.use_contiguous_buffers_in_ddp:
-        assert args.DDP_impl == 'local'
+    # For torch DDP, we do not use contiguous buffer
+    if args.DDP_impl == 'torch':
+        args.use_contiguous_buffers_in_local_ddp = False

     if args.dataloader_type is None:
         args.dataloader_type = 'single'

@@ -240,9 +246,15 @@ def parse_args(extra_args_provider=None, defaults={},
         'residual connection in fp32 only supported when using fp16 or bf16.'

     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
-        assert args.checkpoint_activations, \
-            'for distribute-checkpointed-activations to work you '\
-            'need to enable checkpoint-activations'
+        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
+            'checkpointed activations only across tensor model ' \
+            'parallel groups'
+        assert args.activations_checkpoint_method is not None, \
+            'for distribute-checkpointed-activations to work you '\
+            'need to use a activation-checkpoint method '
+        assert args.num_layers_per_virtual_pipeline_stage is None, \
+            'currently distrobuted checkpoint activations only supported for '\
+            'nointerleaved pipeline parallelism'

     _print_args(args)
     return args

@@ -408,8 +420,20 @@ def _add_training_args(parser):
                        action='store_true',
                        help='If set, distribute checkpointed activations '
                        'across model parallel group.')
-    group.add_argument('--checkpoint-num-layers', type=int, default=1,
-                       help='chunk size (number of layers) for checkpointing.')
+    group.add_argument('--activations-checkpoint-method', type=str, default=None,
+                       choices=['uniform', 'block'],
+                       help='1) uniform: uniformly divide the total number of '
+                       'Transformer layers and checkpoint the input activation of '
+                       'each divided chunk, '
+                       '2) checkpoint the input activations of only a set number of '
+                       'individual Transformer layers per pipeline stage and do the '
+                       'rest without any checkpointing'
+                       'default) do not apply activations checkpoint to any layers')
+    group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
+                       help='1) uniform: the number of Transformer layers in each '
+                       'uniformly divided checkpoint unit, '
+                       '2) block: the number of individual Transformer layers '
+                       'to checkpoint within each pipeline stage.')
     group.add_argument('--train-iters', type=int, default=None,
                        help='Total number of iterations to train over all '
                        'training runs. Note that either train-iters or '

@@ -444,6 +468,11 @@ def _add_training_args(parser):
     group.add_argument('--dataloader-type', type=str, default=None,
                        choices=['single', 'cyclic'],
                        help='Single pass vs multiple pass data loader')
+    group.add_argument('--no-async-tensor-model-parallel-allreduce',
+                       action='store_true',
+                       help='Disable asynchronous execution of '
+                       'tensor-model-parallel all-reduce with weight '
+                       'gradient compuation of a column-linear layer.')
     return parser

@@ -593,9 +622,10 @@ def _add_distributed_args(parser):
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
                        'to use.')
-    group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
-                       help='If set, use contiguous buffer in DDP. Note that '
-                       'this option only works woth local DDP.')
+    group.add_argument('--no-contiguous-buffers-in-local-ddp',
+                       action='store_false', help='If set, dont use '
+                       'contiguous buffer in local DDP.',
+                       dest='use_contiguous_buffers_in_local_ddp')
     group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                        help='Use scatter/gather to optimize communication of tensors in pipeline',
                        dest='scatter_gather_tensors_in_pipeline')
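As a quick orientation to the renamed flags above, here is a minimal sketch of how the new arguments behave once parsed. The parser below is a stand-alone toy that only mirrors the option names and defaults from this diff; it is not Megatron's real parse_args(), and the sample argument values are illustrative.

    import argparse

    parser = argparse.ArgumentParser()
    # 'uniform' checkpoints every chunk of N layers; 'block' checkpoints only
    # the first N layers of each pipeline stage (N set by the second flag).
    parser.add_argument('--activations-checkpoint-method', type=str, default=None,
                        choices=['uniform', 'block'])
    parser.add_argument('--activations-checkpoint-num-layers', type=int, default=1)
    # Replaces the old opt-in --use-contiguous-buffers-in-ddp with an opt-out.
    parser.add_argument('--no-contiguous-buffers-in-local-ddp',
                        action='store_false',
                        dest='use_contiguous_buffers_in_local_ddp')

    args = parser.parse_args(['--activations-checkpoint-method', 'uniform',
                              '--activations-checkpoint-num-layers', '2'])
    print(args.activations_checkpoint_method)        # 'uniform'
    print(args.activations_checkpoint_num_layers)    # 2
    print(args.use_contiguous_buffers_in_local_ddp)  # True (flag not passed)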
megatron/fused_kernels/scaled_masked_softmax.cpp

@@ -32,6 +32,12 @@ torch::Tensor bwd_cuda(
     torch::Tensor const& softmax_results,
     float scale_factor);

+int get_batch_per_block_cuda(
+    int query_seq_len,
+    int key_seq_len,
+    int batches,
+    int attn_heads);
+
 torch::Tensor fwd(
     torch::Tensor const& input,
     torch::Tensor const& mask,

@@ -63,6 +69,14 @@ torch::Tensor bwd(
   return bwd_cuda(output_grads, softmax_results, scale_factor);
 }

+int get_batch_per_block(
+    int query_seq_len,
+    int key_seq_len,
+    int batches,
+    int attn_heads) {
+    return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
+}
+
 } // end namespace scaled_masked_softmax
 } // end namespace fused_softmax
 } // end namespace multihead_attn

@@ -71,7 +85,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("forward",
         &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
         "Self Multihead Attention scaled, time masked softmax -- Forward.");
+
   m.def("backward",
         &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
         "Self Multihead Attention scaled, time masked softmax -- Backward.");
+
+  m.def("get_batch_per_block",
+        &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
+        "Return Batch per block size.");
 }
megatron/fused_kernels/scaled_masked_softmax.h

@@ -310,9 +310,22 @@ __global__ void scaled_masked_softmax_warp_backward(
         }
     }
 }
 } // end of anonymous namespace

+int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
+    int log2_elements = log2_ceil(key_seq_len);
+    const int next_power_of_two = 1 << log2_elements;
+
+    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+    constexpr int threads_per_block = 128;
+    int warps_per_block = (threads_per_block / warp_size);
+    int batches_per_block = warps_per_block * batches_per_warp;
+
+    return batches_per_block;
+}
+
 template<typename input_t, typename output_t, typename acc_t>
 void dispatch_scaled_masked_softmax_forward(
     output_t *dst,
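For readers skimming the C++ above, the batch-per-block computation is simple enough to restate in Python. This is an illustration, not code from the commit; it assumes the usual CUDA warp size of 32 (C10_WARP_SIZE) and a 128-thread block, and note that only key_seq_len actually influences the result.

    def get_batch_per_block(key_seq_len, warp_size_max=32, threads_per_block=128):
        # log2_ceil(key_seq_len), then round up to the next power of two
        log2_elements = (key_seq_len - 1).bit_length()
        next_power_of_two = 1 << log2_elements
        warp_size = min(next_power_of_two, warp_size_max)
        batches_per_warp = 2 if next_power_of_two <= 128 else 1
        warps_per_block = threads_per_block // warp_size
        return warps_per_block * batches_per_warp

    print(get_batch_per_block(128))   # 8  (warp_size=32, 2 batches per warp)
    print(get_batch_per_block(2048))  # 4  (warp_size=32, 1 batch per warp)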
megatron/fused_kernels/scaled_masked_softmax_cuda.cu

@@ -28,6 +28,11 @@ namespace multihead_attn {
 namespace fused_softmax {
 namespace scaled_masked_softmax {

+int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
+    return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
+}
+
 torch::Tensor fwd_cuda(
     torch::Tensor const& input,
     torch::Tensor const& mask,
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h

@@ -361,6 +361,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
         int blocks_per_seq = attn_batches / batches_per_block;
         dim3 blocks(seq_len, blocks_per_seq, 1);
         dim3 threads(warp_size, warps_per_block, 1);

@@ -451,6 +452,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
         int blocks_per_seq = attn_batches / batches_per_block;
         dim3 blocks(seq_len, blocks_per_seq, 1);
         dim3 threads(warp_size, warps_per_block, 1);
megatron/fused_kernels/tests/__init__.py (new file, mode 100644, empty)
megatron/fused_kernels/tests/test_fused_kernels.py (new file, mode 100644)

import math

import torch
from torch.nn import LayerNorm

from megatron.model.enums import AttnMaskType
from megatron.model.fused_layer_norm import MixedFusedLayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.utils import attention_mask_func


def test_load_fused_kernels():
    try:
        import fused_mix_prec_layer_norm_cuda
        import scaled_masked_softmax_cuda
        import scaled_upper_triang_masked_softmax_cuda
        import torch

        print("[Success] load_fused_kernels")
    except ImportError as e:
        print("[Fail] load_fused_kernels")
        raise e


def test_fused_softmax():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    embedding_output = bert.embeddings(
        input_ids=tokens["input_ids"].cuda(),
        position_ids=None,
        token_type_ids=tokens["token_type_ids"].cuda(),
        inputs_embeds=None,
        past_key_values_length=0,
    )

    # (bsz, 1, 1, seq_len)
    mask = bert.get_extended_attention_mask(
        attention_mask=tokens["attention_mask"].cuda(),
        input_shape=tokens["input_ids"].shape,
        device=bert.device,
    )
    # (bsz, 1, seq_len, seq_len)
    mask = mask.repeat(1, 1, mask.size()[-1], 1)

    attention = bert.encoder.layer[0].attention.self
    key_layer = attention.transpose_for_scores(attention.key(embedding_output))
    query_layer = attention.transpose_for_scores(attention.query(embedding_output))

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores /= math.sqrt(key_layer.size()[-1])

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attention_scores,
        (mask != 0),
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attention_scores,
        (mask != 0),
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )


def test_fused_upper_triangle_mask_softmax():
    gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi"  # 24
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    attention_mask = tokens["attention_mask"].cuda()
    attention_mask = attention_mask.view(attention_mask.size(0), -1)
    attention_mask = attention_mask[:, None, None, :]
    attention_mask = (1.0 - attention_mask) * -10000.0
    attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)
    attn = gpt.h[0]

    hidden_states = gpt.wte(tokens["input_ids"].cuda())
    q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1)
    q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim)
    k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim)
    attn_weights = torch.matmul(q, k.transpose(-1, -2))

    sq, sk = q.size(-2), k.size(-2)
    causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool()
    total_mask = ~(causal_mask & (attention_mask == 0))
    """
    tensor([[[[False,  True,  True,  ...,  True,  True,  True],
              [False, False,  True,  ...,  True,  True,  True],
              [False, False, False,  ...,  True,  True,  True],
              ...,
              [False, False, False,  ..., False,  True,  True],
              [False, False, False,  ..., False, False,  True],
              [False, False, False,  ..., False, False, False]]]
    """

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attn_weights,
        total_mask,
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attn_weights,
        total_mask,
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )


def test_layer_norm():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    # [bsz, seq_len, d_model]
    embedding_output = (
        bert.embeddings(
            input_ids=tokens["input_ids"].cuda(),
            position_ids=None,
            token_type_ids=tokens["token_type_ids"].cuda(),
            inputs_embeds=None,
            past_key_values_length=0,
        )
        .cuda()
        .half()
    )

    fused_layernorm_layer = (
        MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    torch_layernorm_layer = (
        LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    fused_output = fused_layernorm_layer(embedding_output)
    torch_output = torch_layernorm_layer(embedding_output)
    test_result = (fused_output - torch_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_layer_norm"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_layer_norm"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )


if __name__ == "__main__":
    try:
        from transformers import BertTokenizer, GPT2Tokenizer
        from transformers.models.bert.modeling_bert import BertModel
        from transformers.models.gpt2.modeling_gpt2 import GPT2Model
        import transformers

        transformers.logging.set_verbosity(
            transformers.logging.FATAL,
        )

    except:
        print("\n[Fail] Please install `transformers` package to test fused kernels\n")
        exit(-1)

    test_load_fused_kernels()
    test_fused_softmax()
    test_fused_upper_triangle_mask_softmax()
    test_layer_norm()
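A brief note on the new test file: it is written to run stand-alone (plausibly as `python megatron/fused_kernels/tests/test_fused_kernels.py` from the repo root, assuming the CUDA extensions are built and `transformers` is installed; the invocation is an assumption, not stated in the commit). All three tests share the same comparison pattern, which is small enough to restate as a sketch:

    # Condensed restatement (not in the commit) of the shared check: reduce the
    # absolute difference between fused and reference outputs to a scalar and
    # compare against the 1e-3 tolerance the tests use.
    import torch

    def mean_abs_diff(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        d = (a - b).abs()
        while d.dim() != 1:
            d = d.mean(dim=-1)
        return d.mean(dim=-1)

    x = torch.randn(4, 12, 32, 32)
    assert mean_abs_diff(x, x + 1e-5) <= 1e-3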
megatron/initialize.py

@@ -21,6 +21,7 @@ import time
 import numpy as np
 import torch
+from datetime import timedelta

 from megatron import fused_kernels
 from megatron import get_adlr_autoresume

@@ -63,6 +64,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
             print('> setting random seeds to {} ...'.format(args.seed))
         _set_random_seed(args.seed)

+    # Set pytorch JIT layer fusion options.
+    _set_jit_fusion_options()
+
     args = get_args()
     if args.lazy_mpu_init:
         args.use_cpu_initialization = True

@@ -77,9 +81,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         # Megatron's MPU is the master. Complete initialization right away.
         finish_mpu_init()

-        # Initialize memory buffers.
-        _initialize_mem_buffs()
-
         # Autoresume.
         _init_autoresume()

@@ -175,11 +176,11 @@ def _initialize_distributed():
         else:
             args.local_rank = device
         torch.cuda.set_device(device)
     # Call the init process
     torch.distributed.init_process_group(
         backend=args.distributed_backend,
-        world_size=args.world_size, rank=args.rank)
+        world_size=args.world_size, rank=args.rank,
+        timeout=timedelta(days=7))

     # Set the tensor model-parallel, pipeline model-parallel, and
     # data-parallel communicators.

@@ -226,10 +227,24 @@ def write_args_to_tensorboard():
                         global_step=args.iteration)

-def _initialize_mem_buffs():
-    """Initialize manually allocated static memory."""
-    args = get_args()
-
-    # Initialize memory for checkpointed activations.
-    if args.distribute_checkpointed_activations:
-        mpu.init_checkpointed_activations_memory_buffer()
+def _set_jit_fusion_options():
+    """Set PyTorch JIT layer fusion options."""
+    # flags required to enable jit fusion kernels
+    TORCH_MAJOR = int(torch.__version__.split('.')[0])
+    TORCH_MINOR = int(torch.__version__.split('.')[1])
+    if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
+        # nvfuser
+        torch._C._jit_set_profiling_executor(True)
+        torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        torch._C._jit_set_nvfuser_enabled(True)
+        torch._C._debug_set_autodiff_subgraph_inlining(False)
+    else:
+        # legacy pytorch fuser
+        torch._C._jit_set_profiling_mode(False)
+        torch._C._jit_set_profiling_executor(False)
+        torch._C._jit_override_can_fuse_on_cpu(True)
+        torch._C._jit_override_can_fuse_on_gpu(True)
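The version gate in _set_jit_fusion_options() is worth calling out: the nvFuser path is taken only on PyTorch 1.10 or newer, the legacy fuser otherwise. A minimal stand-alone sketch of that check (illustration only, not Megatron code):

    import torch

    major, minor = (int(x) for x in torch.__version__.split('.')[:2])
    use_nvfuser = (major > 1) or (major == 1 and minor >= 10)
    print(f"torch {torch.__version__}: "
          f"{'nvFuser' if use_nvfuser else 'legacy fuser'} would be selected")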
megatron/model/fused_bias_gelu.py

@@ -15,10 +15,6 @@

 import torch

-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)

 ###### BIAS GELU FUSION/ NO AUTOGRAD ################
 # 1/sqrt(2*pi)-> 0.3989423
megatron/model/fused_softmax.py

@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import torch
+import torch.nn as nn
 from megatron.model.enums import AttnMaskType

@@ -30,10 +32,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
(formatting-only changes; the lines are unchanged apart from wrapping)
         import scaled_upper_triang_masked_softmax_cuda

         scale_t = torch.tensor([scale])
         softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
             inputs, scale_t[0]
         )
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

@@ -42,10 +44,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
(formatting-only changes)
         import scaled_upper_triang_masked_softmax_cuda

         softmax_results, scale_t = ctx.saved_tensors
         input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
             output_grads, softmax_results, scale_t[0]
         )
         return input_grads, None

@@ -63,9 +65,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         scale_t = torch.tensor([scale])
-        softmax_results = scaled_masked_softmax_cuda.forward(
-            inputs, mask, scale_t[0])
+        softmax_results = scaled_masked_softmax_cuda.forward(
+            inputs, mask, scale_t[0]
+        )
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

@@ -81,16 +81,18 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         return input_grads, None, None


-class FusedScaleMaskSoftmax(torch.nn.Module):
+class FusedScaleMaskSoftmax(nn.Module):
     """
     fused operation: scaling + mask + softmax

     Arguments:
         input_in_fp16: flag to indicate if input in fp16 data format.
+        input_in_bf16: flag to indicate if input in bf16 data format.
         attn_mask_type: attention mask type (pad or causal)
+        scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
         mask_func: mask function to be applied.
         softmax_in_fp32: if true, softmax in performed at fp32 precision.
         scale: scaling factor used in input tensor scaling.
     """

     def __init__(

@@ -106,8 +108,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
         self.input_in_bf16 = input_in_bf16
-        assert not (self.input_in_fp16 and self.input_in_bf16), \
-            'both fp16 and bf16 flags cannot be active at the same time.'
+        assert not (
+            self.input_in_fp16 and self.input_in_bf16
+        ), "both fp16 and bf16 flags cannot be active at the same time."
         self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
         self.attn_mask_type = attn_mask_type
         self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion

@@ -118,47 +121,72 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         assert (
             self.scale is None or softmax_in_fp32
         ), "softmax should be in fp32 when scaled"

     def forward(self, input, mask):
         # [b, np, sq, sk]
         assert input.dim() == 4
-        data_size = input.size()
-        query_seq_len = data_size[-2]
-        key_seq_len = data_size[-1]
-        attn_batch_size = data_size[0] * data_size[1]
-
-        # constraints on various tensor dimensions to enable warp based
-        # optimization and upper triangular optimization (for causal mask)
-        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
-            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
-
-        # invoke custom kernel
-        if self.input_in_float16 and mask is not None and \
-            custom_kernel_constraint and self.scaled_masked_softmax_fusion:
-            scale = self.scale if self.scale is not None else 1.0
-            if self.attn_mask_type == AttnMaskType.causal:
-                assert query_seq_len == key_seq_len, \
-                    "causal mask is only for self attention"
-                input = input.view(-1, query_seq_len, key_seq_len)
-                probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
-                probs = probs.view(*data_size)
-            else:
-                assert self.attn_mask_type == AttnMaskType.padding
-                probs = ScaledMaskedSoftmax.apply(input, mask, scale)
-        else:
-            if self.input_in_float16 and self.softmax_in_fp32:
-                input = input.float()
-
-            if self.scale is not None:
-                input = input * self.scale
-            mask_output = self.mask_func(input, mask) if mask is not None else input
-            probs = torch.nn.Softmax(dim=-1)(mask_output)
-
-            if self.input_in_float16 and self.softmax_in_fp32:
-                if self.input_in_fp16:
-                    probs = probs.half()
-                else:
-                    probs = probs.bfloat16()
-
-        return probs
+        if self.is_kernel_available(mask, *input.size()):
+            return self.forward_fused_softmax(input, mask)
+        else:
+            return self.forward_torch_softmax(input, mask)
+
+    def is_kernel_available(self, mask, b, np, sq, sk):
+        attn_batches = b * np
+
+        if (
+            self.scaled_masked_softmax_fusion  # user want to fuse
+            and self.input_in_float16  # input must be fp16
+            and mask is not None  # mask tensor must not be None
+            and 16 < sk <= 2048  # sk must be 16 ~ 2048
+            and sq % 4 == 0  # sq must be divisor of 4
+            and attn_batches % 4 == 0  # np * b must be divisor of 4
+        ):
+            if 0 <= sk <= 2048:
+                batch_per_block = self.get_batch_per_block(sq, sk, b, np)
+
+                if self.attn_mask_type == AttnMaskType.causal:
+                    if attn_batches % batch_per_block == 0:
+                        return True
+                else:
+                    if sq % batch_per_block == 0:
+                        return True
+        return False
+
+    def forward_fused_softmax(self, input, mask):
+        b, np, sq, sk = input.size()
+        scale = self.scale if self.scale is not None else 1.0
+
+        if self.attn_mask_type == AttnMaskType.causal:
+            assert sq == sk, "causal mask is only for self attention"
+
+            # input is 3D tensor (attn_batches, sq, sk)
+            input = input.view(-1, sq, sk)
+            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
+            return probs.view(b, np, sq, sk)
+        else:
+            # input is 4D tensor (b, np, sq, sk)
+            return ScaledMaskedSoftmax.apply(input, mask, scale)
+
+    def forward_torch_softmax(self, input, mask):
+        if self.input_in_float16 and self.softmax_in_fp32:
+            input = input.float()
+
+        if self.scale is not None:
+            input = input * self.scale
+        mask_output = self.mask_func(input, mask) if mask is not None else input
+        probs = torch.nn.Softmax(dim=-1)(mask_output)
+
+        if self.input_in_float16 and self.softmax_in_fp32:
+            if self.input_in_fp16:
+                probs = probs.half()
+            else:
+                probs = probs.bfloat16()
+
+        return probs
+
+    @staticmethod
+    def get_batch_per_block(sq, sk, b, np):
+        import scaled_masked_softmax_cuda
+
+        return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
View file @
cb00a196
...
@@ -29,23 +29,15 @@ from .utils import scaled_init_method_normal
...
@@ -29,23 +29,15 @@ from .utils import scaled_init_method_normal
def
post_language_model_processing
(
lm_output
,
labels
,
logit_weights
,
def
post_language_model_processing
(
lm_output
,
labels
,
logit_weights
,
get_key_value
,
parallel_output
,
parallel_output
,
forward_method_parallel_output
,
fp16_lm_cross_entropy
):
fp16_lm_cross_entropy
):
if
get_key_value
:
lm_output
,
presents
=
lm_output
# Output.
# Output.
if
forward_method_parallel_output
is
not
None
:
parallel_output
=
forward_method_parallel_output
output
=
parallel_lm_logits
(
output
=
parallel_lm_logits
(
lm_output
,
lm_output
,
logit_weights
,
logit_weights
,
parallel_output
)
parallel_output
)
if
get_key_value
:
output
=
[
output
,
presents
]
if
labels
is
None
:
if
labels
is
None
:
return
output
return
output
else
:
else
:
...
@@ -90,23 +82,22 @@ class GPTModel(MegatronModule):
...
@@ -90,23 +82,22 @@ class GPTModel(MegatronModule):
self
.
language_model
.
set_input_tensor
(
input_tensor
)
self
.
language_model
.
set_input_tensor
(
input_tensor
)
def
forward
(
self
,
input_ids
,
position_ids
,
attention_mask
,
labels
=
None
,
def
forward
(
self
,
input_ids
,
position_ids
,
attention_mask
,
labels
=
None
,
tokentype_ids
=
None
,
layer_past
=
None
,
get_key_value
=
False
,
tokentype_ids
=
None
,
forward_method_parallel_output
=
None
):
set_inference_key_value_memory
=
False
,
inference_max_sequence_len
=
None
):
lm_output
=
self
.
language_model
(
lm_output
=
self
.
language_model
(
input_ids
,
input_ids
,
position_ids
,
position_ids
,
attention_mask
,
attention_mask
,
layer_past
=
layer_past
,
set_inference_key_value_memory
=
set_inference_key_value_memory
,
get_key_value
=
get_key_value
)
inference_max_sequence_len
=
inference_max_sequence_len
)
if
self
.
post_process
:
if
self
.
post_process
:
return
post_language_model_processing
(
return
post_language_model_processing
(
lm_output
,
labels
,
lm_output
,
labels
,
self
.
word_embeddings_weight
(),
self
.
word_embeddings_weight
(),
get_key_value
,
self
.
parallel_output
,
self
.
parallel_output
,
forward_method_parallel_output
,
self
.
fp16_lm_cross_entropy
)
self
.
fp16_lm_cross_entropy
)
else
:
else
:
return
lm_output
return
lm_output
...
...
megatron/model/language_model.py

@@ -379,8 +379,10 @@ class TransformerLanguageModel(MegatronModule):

     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
-                get_key_value=False, pooling_sequence_index=0,
+                enc_dec_attn_mask=None, tokentype_ids=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None,
+                pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):

         # Encoder embedding.

@@ -393,10 +395,11 @@ class TransformerLanguageModel(MegatronModule):
         # Run encoder.
         if enc_hidden_states is None:
             if self.encoder is not None:
-                encoder_output = self.encoder(encoder_input,
-                                              enc_attn_mask,
-                                              layer_past=layer_past,
-                                              get_key_value=get_key_value)
+                encoder_output = self.encoder(
+                    encoder_input,
+                    enc_attn_mask,
+                    set_inference_key_value_memory=set_inference_key_value_memory,
+                    inference_max_sequence_len=inference_max_sequence_len)
             else:
                 encoder_output = self.encoder_hidden_state
         else:

@@ -424,12 +427,13 @@ class TransformerLanguageModel(MegatronModule):
             decoder_input = None

         # Run decoder.
-        decoder_output = self.decoder(decoder_input,
-                                      dec_attn_mask,
-                                      layer_past=layer_past,
-                                      get_key_value=get_key_value,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask)
+        decoder_output = self.decoder(
+            decoder_input,
+            dec_attn_mask,
+            encoder_output=encoder_output,
+            enc_dec_attn_mask=enc_dec_attn_mask,
+            set_inference_key_value_memory=set_inference_key_value_memory,
+            inference_max_sequence_len=inference_max_sequence_len)

         if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output
megatron/model/transformer.py

@@ -27,11 +27,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

-# flags required to enable jit fusion kernels
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)

 """ We use the following notation throughout this file:
     h: hidden size

@@ -123,6 +118,7 @@ class ParallelAttention(MegatronModule):
         self.layer_number = max(1, layer_number)
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
+        self.params_dtype = args.params_dtype

         projection_size = args.kv_channels * args.num_attention_heads

@@ -183,10 +179,53 @@ class ParallelAttention(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)

-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False, encoder_output=None):
+        # Inference key-value memory
+        self.inference_key_memory = None
+        self.inference_value_memory = None
+        self.inference_current_sequence_len = 0
+
+    def _allocate_memory(self, inference_max_sequence_len, batch_size):
+        return torch.empty(
+            inference_max_sequence_len,
+            batch_size,
+            self.num_attention_heads_per_partition,
+            self.hidden_size_per_attention_head,
+            dtype=self.params_dtype,
+            device=torch.cuda.current_device())
+
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
         # hidden_states: [sq, b, h]

+        # =================================================
+        # Pre-allocate memory for key-values for inference.
+        # =================================================
+        if set_inference_key_value_memory:
+            assert inference_max_sequence_len and inference_max_sequence_len > 0
+            self.inference_key_memory = self._allocate_memory(
+                inference_max_sequence_len, hidden_states.size(1))
+            self.inference_value_memory = self._allocate_memory(
+                inference_max_sequence_len, hidden_states.size(1))
+            self.inference_current_sequence_len = 0
+        # Some consistency check.
+        if inference_max_sequence_len:
+            assert self.inference_current_sequence_len < \
+                self.inference_key_memory.size(0)
+            assert inference_max_sequence_len == \
+                self.inference_key_memory.size(0)
+        # This is added for safety. In case inference_max_sequence_len
+        # is not provided, make sure there is no potential memory left
+        # from previous inference.
+        if not inference_max_sequence_len:
+            self.inference_key_memory = None
+            self.inference_value_memory = None
+
         # =====================
         # Query, Key, and Value
         # =====================

@@ -227,18 +266,24 @@ class ParallelAttention(MegatronModule):
                              self.hidden_size_per_attention_head)
         query_layer = query_layer.view(*new_tensor_shape)

-        # ==================================
-        # Adjust key and value for inference
-        # ==================================
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            key_layer = torch.cat((past_key.type_as(key_layer),
-                                   key_layer), dim=0)
-            value_layer = torch.cat((past_value.type_as(value_layer),
-                                     value_layer), dim=0)
-        if get_key_value:
-            present = (key_layer, value_layer)
+        # ===================================================
+        # Adjust key, value, and attention mask for inference
+        # ===================================================
+        if inference_max_sequence_len:
+            # Adjust the range variables.
+            start = self.inference_current_sequence_len
+            self.inference_current_sequence_len += key_layer.size(0)
+            end = self.inference_current_sequence_len
+            # Copy key and values.
+            self.inference_key_memory[start:end, ...] = key_layer
+            self.inference_value_memory[start:end, ...] = value_layer
+            key_layer = self.inference_key_memory[:end, ...]
+            value_layer = self.inference_value_memory[:end, ...]
+            # Adjust attention mask
+            attention_mask = attention_mask[..., start:end, :end]

         # ===================================
         # Raw attention scores. [b, np, s, s]

@@ -275,22 +320,6 @@ class ParallelAttention(MegatronModule):
         # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)

-        # ==================================================
-        # Update attention mask for inference. [b, np, sq, sk]
-        # ==================================================
-        if get_key_value:
-            with torch.no_grad():
-                if layer_past is not None:
-                    attention_mask = attention_mask[
-                        ...,
-                        attention_scores.size(3) - 1,
-                        :attention_scores.size(3)].unsqueeze(2)
-                else:
-                    attention_mask = attention_mask[
-                        ...,
-                        :attention_scores.size(3),
-                        :attention_scores.size(3)]
-
         # ===========================
         # Attention probs and dropout

@@ -346,9 +375,6 @@ class ParallelAttention(MegatronModule):
         output, bias = self.dense(context_layer)

-        if get_key_value:
-            output = [output, present]
-
         return output, bias

@@ -435,21 +461,21 @@ class ParallelTransformerLayer(MegatronModule):
             output_layer_init_method)

     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
-                layer_past=None, get_key_value=False):
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
         # hidden_states: [b, s, h]

         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
         # Self attention.
         attention_output, attention_bias = \
-            self.self_attention(layernorm_output,
-                                attention_mask,
-                                layer_past=layer_past,
-                                get_key_value=get_key_value)
+            self.self_attention(
+                layernorm_output,
+                attention_mask,
+                set_inference_key_value_memory=set_inference_key_value_memory,
+                inference_max_sequence_len=inference_max_sequence_len)

-        if get_key_value:
-            attention_output, presents = attention_output

         # Residual connection.
         if self.apply_residual_connection_post_layernorm:

@@ -519,9 +545,6 @@ class ParallelTransformerLayer(MegatronModule):
                 residual,
                 self.hidden_dropout)

-        if get_key_value:
-            output = [output, presents]
-
         return output

@@ -542,8 +565,9 @@ class ParallelTransformer(MegatronModule):
         self.input_tensor = None

         # Store activation checkpoiting flag.
-        self.checkpoint_activations = args.checkpoint_activations
-        self.checkpoint_num_layers = args.checkpoint_num_layers
+        self.activations_checkpoint_method = args.activations_checkpoint_method
+        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
+        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

         # Number of layers.
         self.num_layers = mpu.get_num_layers(

@@ -606,14 +630,49 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward

-        # Make sure memory is freed.
-        mpu.reset_checkpointed_activations_memory_buffer()
-        l = 0
-        while l < self.num_layers:
-            hidden_states = mpu.checkpoint(
-                custom(l, l + self.checkpoint_num_layers),
-                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-            l += self.checkpoint_num_layers
+        def distribute_checkpointed_activations_helper(layer_number):
+            """Distribute checkpointed activations across the tensor model
+               Parallel ranks if the `distribute-checkpointed-activations
+               is on and either of the following conditions is met:
+                 - it is not the first layer in the in the pipeline stage.
+                   The first layer is used in the pipeline parallelism
+                   and changing its shape throws error in the backward pass.
+                 - we are at the first pipline stage so the input tensor is
+                   not used in pipeline parallelism. Note that no pipeline
+                   parallelism is a special case of this.
+            """
+            not_first_layer_in_pipeline_stage = (layer_number > 0)
+            is_first_pipeline_stage = (
+                mpu.get_pipeline_model_parallel_rank() == 0)
+            return self.distribute_checkpointed_activations and \
+                (not_first_layer_in_pipeline_stage or is_first_pipeline_stage)
+
+        if self.activations_checkpoint_method == 'uniform':
+            # Uniformly divide the total number of Transformer layers and checkpoint
+            # the input activation of each divided chunk.
+            # A method to further reduce memory usage reducing checkpoints.
+            l = 0
+            while l < self.num_layers:
+                hidden_states = mpu.checkpoint(
+                    custom(l, l + self.activations_checkpoint_num_layers),
+                    distribute_checkpointed_activations_helper(l),
+                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                l += self.activations_checkpoint_num_layers
+        elif self.activations_checkpoint_method == 'block':
+            # Checkpoint the input activation of only a set number of individual
+            # Transformer layers and skip the rest.
+            # A method fully use the device memory removing redundant re-computation.
+            for l in range(self.num_layers):
+                if l < self.activations_checkpoint_num_layers:
+                    hidden_states = mpu.checkpoint(
+                        custom(l, l + 1),
+                        distribute_checkpointed_activations_helper(l),
+                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                else:
+                    hidden_states = custom(l, l + 1)(
+                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+        else:
+            raise ValueError("Invalid activation checkpoint method.")

         return hidden_states

@@ -627,18 +686,16 @@ class ParallelTransformer(MegatronModule):
            forward_step_func"""
         self.input_tensor = input_tensor

-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None, enc_dec_attn_mask=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):

         # Checks.
-        if layer_past is not None:
-            assert get_key_value, \
-                'for not None values in layer_past, ' \
-                'expected get_key_value to be set'
-        if get_key_value:
-            assert not self.checkpoint_activations, \
-                'get_key_value does not work with ' \
-                'activation checkpointing'
+        if inference_max_sequence_len:
+            assert self.activations_checkpoint_method is None, \
+                'inference does not work with activation checkpointing'

         if self.pre_process:
             # Data format change to avoid explicit tranposes : [b s h] --> [s b h].

@@ -655,28 +712,21 @@ class ParallelTransformer(MegatronModule):
         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()

-        if self.checkpoint_activations:
+        if self.activations_checkpoint_method is not None:
             hidden_states = self._checkpointed_forward(hidden_states,
                                                        attention_mask,
                                                        encoder_output,
                                                        enc_dec_attn_mask)
         else:
-            if get_key_value:
-                presents = []
             for index in range(self.num_layers):
                 layer = self._get_layer(index)
-                past = None
-                if layer_past is not None:
-                    past = layer_past[index]
-                hidden_states = layer(hidden_states,
-                                      attention_mask,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask,
-                                      layer_past=past,
-                                      get_key_value=get_key_value)
-                if get_key_value:
-                    hidden_states, present = hidden_states
-                    presents.append(present)
+                hidden_states = layer(
+                    hidden_states,
+                    attention_mask,
+                    encoder_output=encoder_output,
+                    enc_dec_attn_mask=enc_dec_attn_mask,
+                    set_inference_key_value_memory=set_inference_key_value_memory,
+                    inference_max_sequence_len=inference_max_sequence_len)

         # Final layer norm.
         if self.post_process:

@@ -685,7 +735,5 @@ class ParallelTransformer(MegatronModule):
             output = self.final_layernorm(hidden_states)
         else:
             output = hidden_states
-        if get_key_value:
-            output = [output, presents]

         return output
View file @
cb00a196
...
@@ -60,9 +60,7 @@ from .mappings import scatter_to_tensor_model_parallel_region
...
@@ -60,9 +60,7 @@ from .mappings import scatter_to_tensor_model_parallel_region
from
.random
import
checkpoint
from
.random
import
checkpoint
from
.random
import
get_cuda_rng_tracker
from
.random
import
get_cuda_rng_tracker
from
.random
import
init_checkpointed_activations_memory_buffer
from
.random
import
model_parallel_cuda_manual_seed
from
.random
import
model_parallel_cuda_manual_seed
from
.random
import
reset_checkpointed_activations_memory_buffer
from
.random
import
gather_split_1d_tensor
from
.random
import
gather_split_1d_tensor
from
.random
import
split_tensor_into_1d_equal_chunks
from
.random
import
split_tensor_into_1d_equal_chunks
...
...
megatron/mpu/layers.py  View file @ cb00a196
...
@@ -27,6 +27,7 @@ from torch.nn.parameter import Parameter
 from .initialize import get_tensor_model_parallel_rank
 from .initialize import get_tensor_model_parallel_world_size
+from .initialize import get_tensor_model_parallel_group
 from .mappings import copy_to_tensor_model_parallel_region
 from .mappings import gather_from_tensor_model_parallel_region
 from .mappings import reduce_from_tensor_model_parallel_region
...
@@ -198,6 +199,37 @@ class VocabParallelEmbedding(torch.nn.Module):
         return output


+class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
+    """
+    Column-parallel linear layer execution with asynchronous all-reduce
+    execution in backprop.
+    """
+    @staticmethod
+    def forward(ctx, input, weight, bias):
+        ctx.save_for_backward(input, weight)
+        ctx.use_bias = bias is not None
+        output = torch.matmul(input, weight.t())
+        if bias is not None:
+            output = output + bias
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight = ctx.saved_tensors
+        use_bias = ctx.use_bias
+        grad_input = grad_output.matmul(weight)
+        # Asynchronous all-reduce
+        handle = torch.distributed.all_reduce(
+            grad_input, group=get_tensor_model_parallel_group(), async_op=True)
+        # Delay the start of weight gradient computation shortly (3us) to have
+        # all-reduce scheduled first and have GPU resources allocated
+        _ = torch.empty(1, device=grad_output.device) + 1
+        grad_weight = grad_output.t().matmul(input)
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+        handle.wait()
+        return grad_input, grad_weight, grad_bias
+
+
 class ColumnParallelLinear(torch.nn.Module):
     """Linear layer with column parallelism.
...
@@ -272,16 +304,30 @@ class ColumnParallelLinear(torch.nn.Module):
             self.bias.zero_()
         else:
             self.register_parameter('bias', None)
+        self.async_tensor_model_parallel_allreduce = (
+            not args.no_async_tensor_model_parallel_allreduce and
+            world_size > 1)

     def forward(self, input_):
-        # Set up backprop all-reduce.
-        input_parallel = copy_to_tensor_model_parallel_region(input_)
-        # Matrix multiply.
-
         bias = self.bias if not self.skip_bias_add else None
-        output_parallel = F.linear(input_parallel, self.weight, bias)
+
+        if self.async_tensor_model_parallel_allreduce:
+            input_shape = input_.shape
+            input_ = input_.view(input_shape[0] * input_shape[1],
+                                 input_shape[2])
+            # Matrix multiply with asynchronous all-reduce execution
+            output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
+                input_, self.weight, bias)
+            output_parallel = output_parallel.view(
+                input_shape[0], input_shape[1], output_parallel.shape[1])
+        else:
+            # Set up backprop all-reduce.
+            input_parallel = copy_to_tensor_model_parallel_region(input_)
+            # Matrix multiply.
+            output_parallel = F.linear(input_parallel, self.weight, bias)
+
         if self.gather_output:
             # All-gather across the partitions.
             output = gather_from_tensor_model_parallel_region(output_parallel)
...
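The point of the new autograd function is to hide the cost of the tensor-parallel gradient all-reduce behind the weight-gradient GEMM in backprop. Below is a minimal sketch of the same overlap pattern using plain torch.distributed rather than Megatron's mpu helpers; the single-process gloo group exists only so the example runs, and backward_with_overlap is an illustrative stand-in, not the library's API.

import os
import torch
import torch.distributed as dist

# Single-process group purely so the example runs; real use is multi-GPU NCCL.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
if not dist.is_initialized():
    dist.init_process_group("gloo", rank=0, world_size=1)

def backward_with_overlap(grad_output, input, weight):
    """Overlap the grad-input all-reduce with the weight-gradient matmul."""
    grad_input = grad_output.matmul(weight)
    # Launch the all-reduce asynchronously ...
    handle = dist.all_reduce(grad_input, async_op=True)
    # ... and compute the weight/bias gradients while it is in flight.
    grad_weight = grad_output.t().matmul(input)
    grad_bias = grad_output.sum(dim=0)
    handle.wait()  # grad_input is only valid after the all-reduce completes
    return grad_input, grad_weight, grad_bias

g = torch.randn(4, 6)   # [tokens, output_size_per_partition]
x = torch.randn(4, 3)   # [tokens, input_size]
w = torch.randn(6, 3)   # [output_size_per_partition, input_size]
print([t.shape for t in backward_with_overlap(g, x, w)])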
megatron/mpu/random.py  View file @ cb00a196
...
@@ -37,37 +37,6 @@ from .initialize import get_tensor_model_parallel_world_size
 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'

-# Whether to apply model parallelism to checkpointed hidden states.
-_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
-
-
-def init_checkpointed_activations_memory_buffer():
-    """Initialize the memory buffer for the checkpointed activations."""
-    args = get_args()
-
-    per_layer = args.micro_batch_size * args.max_position_embeddings * \
-                args.hidden_size // args.tensor_model_parallel_size
-    assert args.num_layers % args.checkpoint_num_layers == 0, \
-        'number of layers is not divisible by checkpoint-num-layers'
-    num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
-    numel = per_layer * num_checkpointer_layers
-    dtype = torch.half
-    if not args.fp16:
-        dtype = torch.float
-
-    global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
-    assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
-        'checkpointed activations memory buffer is already allocated.'
-    _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
-        'checkpointed activations', numel, dtype, track_usage=False)
-
-
-def reset_checkpointed_activations_memory_buffer():
-    """Reset the memory used for checkpointing."""
-    if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
-        _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()
-
-
 def _set_cuda_rng_state(new_state, device=-1):
     """Sets the random number generator state of the current GPU.
...
@@ -101,14 +70,21 @@ def _set_cuda_rng_state(new_state, device=-1):
         _lazy_call(cb)


-def split_tensor_into_1d_equal_chunks(tensor):
+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
     """Break a tensor into equal 1D chunks."""
-    data = tensor.view(-1)
-    partition_size = torch.numel(data) // \
-        get_tensor_model_parallel_world_size()
+    partition_size = torch.numel(tensor) // \
+        get_tensor_model_parallel_world_size()
     start_index = partition_size * get_tensor_model_parallel_rank()
     end_index = start_index + partition_size
-    return data[start_index:end_index]
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data


 def gather_split_1d_tensor(tensor):
     """Opposite of above function, gather values from model parallel ranks."""
...
@@ -250,8 +226,10 @@ class CheckpointFunction(torch.autograd.Function):
     tracked/set/reset.
     """
     @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, distribute_checkpointed_activations, *args):
         ctx.run_function = run_function
+        ctx.distribute_checkpointed_activations \
+            = distribute_checkpointed_activations

         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
...
@@ -263,16 +241,14 @@ class CheckpointFunction(torch.autograd.Function):
         # Divide hidden states across model parallel group and only keep
         # the chunk corresponding to the current rank.
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        if distribute_checkpointed_activations:
             ctx.input_0_shape = args[0].data.shape
-            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
-            args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
-                args[0].data)
+            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
+                                                             new_buffer=True)

         # Store everything.
         ctx.save_for_backward(*args)

         return outputs

     @staticmethod
...
@@ -281,7 +257,7 @@ class CheckpointFunction(torch.autograd.Function):
             raise RuntimeError("Checkpointing is not compatible with .grad(), "
                                "please use .backward() if possible")
         inputs = ctx.saved_tensors
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        if ctx.distribute_checkpointed_activations:
             inputs[0].data = gather_split_1d_tensor(inputs[0].data)
             inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
...
@@ -310,10 +286,11 @@ class CheckpointFunction(torch.autograd.Function):
         torch.autograd.backward(outputs, args)
         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                       for inp in detached_inputs)
-        return (None,) + grads
+        return (None, None) + grads


-def checkpoint(function, *args):
+def checkpoint(function, distribute_checkpointed_activations, *args):
     """Checkpoint a model or part of the model.
     This has been directly copied from torch.utils.checkpoint."""
-    return CheckpointFunction.apply(function, *args)
+    return CheckpointFunction.apply(function,
+                                    distribute_checkpointed_activations, *args)
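With distribute_checkpointed_activations enabled, the first checkpointed input is split into equal 1D chunks across tensor-parallel ranks in forward and gathered back before recomputation in backward, replacing the old preallocated memory buffer. The sketch below re-implements that split/gather round trip in plain PyTorch with a fake world size so it runs stand-alone; the helper names and the torch.cat stand-in for the all-gather are illustrative assumptions, not the mpu functions themselves.

import torch

def split_1d_equal_chunks(tensor, rank, world_size, new_buffer=False):
    """Keep only this rank's contiguous 1D slice of `tensor` (illustrative re-implementation)."""
    partition_size = torch.numel(tensor) // world_size
    start = partition_size * rank
    end = start + partition_size
    if new_buffer:
        data = torch.empty(partition_size, dtype=tensor.dtype)
        data.copy_(tensor.view(-1)[start:end])
        return data
    return tensor.view(-1)[start:end]

def gather_split_1d(chunks):
    """Reassemble the full flat tensor from every rank's chunk (stand-in for the all-gather)."""
    return torch.cat(chunks, dim=0)

# Round trip with a pretend world size of 4: split on each "rank", gather, reshape.
full = torch.arange(24, dtype=torch.float32).view(2, 3, 4)
chunks = [split_1d_equal_chunks(full, r, 4, new_buffer=True) for r in range(4)]
restored = gather_split_1d(chunks).view(full.shape)
assert torch.equal(restored, full)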
megatron/optimizer/__init__.py  View file @ cb00a196
...
@@ -100,7 +100,7 @@ def get_megatron_optimizer(model):
                                                  args.clip_grad,
                                                  args.log_num_zeros_in_grad,
                                                  params_have_main_grad,
-                                                 args.use_contiguous_buffers_in_ddp,
+                                                 args.use_contiguous_buffers_in_local_ddp,
                                                  args.bf16,
                                                  grad_scaler)
...
@@ -108,4 +108,4 @@ def get_megatron_optimizer(model):
     return FP32Optimizer(optimizer, args.clip_grad,
                          args.log_num_zeros_in_grad,
                          params_have_main_grad,
-                         args.use_contiguous_buffers_in_ddp)
+                         args.use_contiguous_buffers_in_local_ddp)
megatron/optimizer/optimizer.py  View file @ cb00a196
...
@@ -69,7 +69,7 @@ class MegatronOptimizer(ABC):
     def __init__(self, optimizer, clip_grad,
                  log_num_zeros_in_grad,
                  params_have_main_grad,
-                 use_contiguous_buffers_in_ddp):
+                 use_contiguous_buffers_in_local_ddp):

         """Input optimizer is the base optimizer for example Adam."""
         self.optimizer = optimizer
...
@@ -78,9 +78,9 @@ class MegatronOptimizer(ABC):
         self.clip_grad = clip_grad
         self.log_num_zeros_in_grad = log_num_zeros_in_grad
         self.params_have_main_grad = params_have_main_grad
-        self.use_contiguous_buffers_in_ddp = use_contiguous_buffers_in_ddp
+        self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp

-        if self.use_contiguous_buffers_in_ddp:
+        if self.use_contiguous_buffers_in_local_ddp:
             assert self.params_have_main_grad, \
                 "use of contiguous buffer requires that params have main grad"
...
@@ -193,12 +193,12 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
     """

     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
-                 params_have_main_grad, use_contiguous_buffers_in_ddp,
+                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                  bf16, grad_scaler):

         super(Float16OptimizerWithFloat16Params, self).__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
-            params_have_main_grad, use_contiguous_buffers_in_ddp)
+            params_have_main_grad, use_contiguous_buffers_in_local_ddp)

         self.bf16 = bf16
         self.grad_scaler = grad_scaler
...
@@ -323,7 +323,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
             # persist and therefore should not be deallocated.)
             model_param.grad = None
             if self.params_have_main_grad and \
-               not self.use_contiguous_buffers_in_ddp:
+               not self.use_contiguous_buffers_in_local_ddp:
                 model_param.main_grad = None

         # For fp32 grads, we need to reset the grads to main grad.
...
@@ -335,7 +335,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                 # Safe to de-reference model's main_grad after copying.
                 # (If using contiguous buffers, main_grad's memory should
                 # persist and therefore should not be deallocated.)
-                if not self.use_contiguous_buffers_in_ddp:
+                if not self.use_contiguous_buffers_in_local_ddp:
                     model_param.main_grad = None

     def _unscale_main_grads_and_check_for_nan(self):
...
@@ -491,11 +491,11 @@ class FP32Optimizer(MegatronOptimizer):
     def __init__(self, optimizer, clip_grad,
                  log_num_zeros_in_grad,
                  params_have_main_grad,
-                 use_contiguous_buffers_in_ddp):
+                 use_contiguous_buffers_in_local_ddp):

         super(FP32Optimizer, self).__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
-            params_have_main_grad, use_contiguous_buffers_in_ddp)
+            params_have_main_grad, use_contiguous_buffers_in_local_ddp)

         self._scale = torch.cuda.FloatTensor([1.0])
...
@@ -525,7 +525,7 @@ class FP32Optimizer(MegatronOptimizer):
                 # Safe to de-reference model's main_grad after copying.
                 # (If using contiguous buffers, main_grad's memory should
                 # persist and therefore should not be deallocated.)
-                if not self.use_contiguous_buffers_in_ddp:
+                if not self.use_contiguous_buffers_in_local_ddp:
                     param.main_grad = None

         # Clip gradients.
...
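The flag rename only clarifies that the contiguous grad buffers belong to the local DDP wrapper; the behavior it guards is unchanged: when each parameter's main_grad is a view into one shared contiguous buffer, the view must persist across iterations and must not be set to None, whereas independently allocated main_grads can be dropped once copied into the fp32 main parameters. The snippet below is an illustrative sketch of that distinction, not Megatron's DDP code; the buffer layout and attribute handling are simplified assumptions.

import torch

# Illustrative only: one contiguous grad buffer with per-parameter views,
# mimicking the local-DDP "contiguous buffers" layout the flag refers to.
params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(3)]
numel = sum(p.numel() for p in params)
grad_buffer = torch.zeros(numel)

offset = 0
for p in params:
    p.main_grad = grad_buffer[offset:offset + p.numel()].view_as(p)
    offset += p.numel()

use_contiguous_buffers_in_local_ddp = True
if not use_contiguous_buffers_in_local_ddp:
    for p in params:
        p.main_grad = None  # safe to free: each grad would own its own storage
else:
    grad_buffer.zero_()     # keep the views alive; only reset the shared buffer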
megatron/text_generation_server.py  0 → 100644  View file @ cb00a196

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import torch
import json
import threading
from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron import get_args
from megatron import mpu
from megatron.text_generation_utils import generate

GENERATE_NUM = 0
lock = threading.Lock()


class MegatronGenerate(Resource):
    def __init__(self, model):
        self.model = model

    @staticmethod
    def send_do_generate():
        choice = torch.cuda.LongTensor([GENERATE_NUM])
        torch.distributed.broadcast(choice, 0)

    def put(self):
        args = get_args()
        print("request IP: " + str(request.remote_addr))
        print(json.dumps(request.get_json()), flush=True)
        print("current time: ", datetime.datetime.now())

        sentences = request.get_json()["sentences"]
        if len(sentences) > 128:
            return "Maximum number of sentences is 128", 400

        tokens_to_generate = 64  # Choosing hopefully sane default.  Full sequence is slow
        if "tokens_to_generate" in request.get_json():
            tokens_to_generate = request.get_json()["tokens_to_generate"]
            if not isinstance(tokens_to_generate, int):
                return "tokens_to_generate must be an integer greater than 0"
            if tokens_to_generate < 1:
                return "tokens_to_generate must be an integer greater than 0"

        all_probs = False
        if "all_probs" in request.get_json():
            all_probs = request.get_json()["all_probs"]
            if not isinstance(all_probs, bool):
                return "all_probs must be a boolean value"

        temperature = args.temperature
        if "temperature" in request.get_json():
            temperature = request.get_json()["temperature"]
            if not isinstance(temperature, float) or not \
               0.0 < temperature <= 100.0:
                return "temperature must be a positive float less than or equal to 100.0"

        add_BOS = False
        if "add_BOS" in request.get_json():
            add_BOS = request.get_json()["add_BOS"]
            if not isinstance(add_BOS, bool):
                return "add_BOS must be a boolean value"

        with lock:  # Need to get lock to keep multiple threads from hitting code
            MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
            resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = \
                generate(self.model, sentences, tokens_to_generate,
                         all_probs, temperature, add_BOS)

        if all_probs:
            return jsonify({"sentences": resp_sentences,
                            "segments": resp_sentences_seg,
                            "logits": output_logits,
                            "all_logits": full_logits,
                            "tokens": tokens})

        return jsonify({"sentences": resp_sentences,
                        "segments": resp_sentences_seg,
                        "logits": output_logits})


class MegatronServer(object):
    def __init__(self, model):
        self.app = Flask(__name__, static_url_path='')
        api = Api(self.app)
        api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])

    def run(self, url):
        self.app.run(url, threaded=True, debug=False)
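The new server exposes a single /generate resource that accepts PUT requests with a JSON body; only "sentences" is required, and the remaining fields fall back to the defaults shown above. A hypothetical client call might look like the following sketch; the host, port (Flask's default 5000 is assumed here), and payload values are placeholders that depend on how MegatronServer.run() is launched.

# Hypothetical client for the /generate endpoint; host/port and payload values
# are placeholders and depend on how the server is started.
import json
import requests  # assumed available; any HTTP client works

payload = {
    "sentences": ["Megatron-LM is"],
    "tokens_to_generate": 32,   # optional, defaults to 64 on the server
    "all_probs": False,         # optional
    "temperature": 0.9,         # optional, overrides args.temperature
    "add_BOS": False,           # optional
}
resp = requests.put("http://localhost:5000/generate",
                    data=json.dumps(payload),
                    headers={"Content-Type": "application/json"})
print(resp.json()["sentences"])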
megatron/text_generation_utils.py  View file @ cb00a196
This diff is collapsed.