OpenDAS / Megatron-LM · Commit 7f9a48ba

Merge branch 'main' into sequence_parallel

Authored May 16, 2022 by Vijay Korthikanti
Parents: 00129014, d2394294

Showing 5 changed files with 75 additions and 14 deletions (+75, -14)
megatron/initialize.py               +47  -4
megatron/model/fused_layer_norm.py   +12  -1
megatron/model/transformer.py        +12  -8
megatron/mpu/layers.py                +1  -1
megatron/training.py                  +3  -0
megatron/initialize.py

@@ -31,6 +31,8 @@ from megatron import mpu
 from megatron.global_vars import set_global_variables
 from megatron.mpu import (set_tensor_model_parallel_rank,
                           set_tensor_model_parallel_world_size)
+from megatron.model.transformer import bias_dropout_add_fused_train
+from megatron.model.fused_bias_gelu import bias_gelu
 
 
 def initialize_megatron(extra_args_provider=None, args_defaults={},
@@ -64,9 +66,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
             print('> setting random seeds to {} ...'.format(args.seed))
         _set_random_seed(args.seed, args.data_parallel_random_init)
 
-    # Set pytorch JIT layer fusion options.
-    _set_jit_fusion_options()
-
     args = get_args()
     if args.lazy_mpu_init:
         args.use_cpu_initialization = True
@@ -230,7 +229,7 @@ def write_args_to_tensorboard():
                             global_step=args.iteration)
 
 
-def _set_jit_fusion_options():
+def set_jit_fusion_options():
     """Set PyTorch JIT layer fusion options."""
     # flags required to enable jit fusion kernels
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
@@ -251,3 +250,47 @@ def _set_jit_fusion_options():
         torch._C._jit_override_can_fuse_on_cpu(True)
         torch._C._jit_override_can_fuse_on_gpu(True)
 
+    _warmup_jit_function()
+
+
+def _warmup_jit_function():
+    """ Compilie JIT functions before the main training steps """
+    args = get_args()
+    if args.bf16:
+        dtype = torch.bfloat16
+    elif args.fp16:
+        dtype = torch.float16
+    else:
+        dtype = torch.float32
+
+    # Warmup fused bias+gelu
+    bias = torch.rand(args.ffn_hidden_size // args.tensor_model_parallel_size,
+                      dtype=dtype, device='cuda')
+    input = torch.rand((args.seq_length, args.micro_batch_size, args.ffn_hidden_size // args.tensor_model_parallel_size),
+                       dtype=dtype, device='cuda')
+    # Warmup JIT fusions with the input grad_enable state of both forward
+    # prop and recomputation
+    for bias_grad, input_grad in zip([True, True], [False, True]):
+        bias.requires_grad, input.requires_grad = bias_grad, input_grad
+        for _ in range(5):
+            output = bias_gelu(bias, input)
+    del bias, input, output
+
+    # Warmup fused bias+dropout+add
+    input = torch.rand((args.seq_length, args.micro_batch_size, args.hidden_size),
+                       dtype=dtype, device='cuda')
+    residual = torch.rand((args.seq_length, args.micro_batch_size, args.hidden_size),
+                          dtype=dtype, device='cuda')
+    bias = torch.rand((args.hidden_size), dtype=dtype, device='cuda').expand_as(residual)
+    dropout_rate = 0.1
+    # Warmup JIT fusions with the input grad_enable state of both forward
+    # prop and recomputation
+    for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
+        input.requires_grad = input_grad
+        bias.requires_grad = bias_grad
+        residual.requires_grad = residual_grad
+        for _ in range(5):
+            output = bias_dropout_add_fused_train(input, bias, residual, dropout_rate)
+    del bias, input, residual, output
+    torch.cuda.empty_cache()
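Note on the warmup logic above: TorchScript fusions are compiled lazily and specialized on dtype, shape, and requires_grad, which is why _warmup_jit_function runs each fused kernel under both grad-enable states (forward-only recomputation vs. normal forward) before the first training step. The standalone sketch below reproduces that pattern outside Megatron; the bias_gelu here is a stand-in written to mirror megatron/model/fused_bias_gelu.py, and the sizes are hypothetical.

import torch

@torch.jit.script
def bias_gelu(bias, y):
    # Stand-in for megatron/model/fused_bias_gelu.py's jitted bias+gelu.
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

def warmup_bias_gelu(hidden, seq, batch, dtype=torch.float16, device='cuda'):
    bias = torch.rand(hidden, dtype=dtype, device=device)
    inp = torch.rand((seq, batch, hidden), dtype=dtype, device=device)
    # Run once per requires_grad combination the training loop will see,
    # so the fuser's specializations are compiled before timing starts.
    for bias_grad, input_grad in zip([True, True], [False, True]):
        bias.requires_grad, inp.requires_grad = bias_grad, input_grad
        for _ in range(5):
            out = bias_gelu(bias, inp)
    del bias, inp, out
    torch.cuda.empty_cache()

if torch.cuda.is_available():
    warmup_bias_gelu(hidden=4096, seq=8, batch=2)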
megatron/model/fused_layer_norm.py

@@ -23,6 +23,8 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib
 
+from megatron.mpu import make_viewless_tensor
+
 try:
     from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
     HAVE_PERSIST_LAYER_NORM = True
@@ -113,6 +115,15 @@ class MixedFusedLayerNorm(torch.nn.Module):
             return FusedLayerNormAffineFunction.apply(
                 input, self.weight, self.bias, self.normalized_shape, self.eps)
         else:
-            return FastLayerNormFN.apply(
+            output = FastLayerNormFN.apply(
                 input, self.weight, self.bias, self.eps)
+
+            # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
+            # a populated '_base' field). This will result in schedule.py's
+            # deallocate_output_tensor() throwing an error, so a viewless tensor is
+            # created to prevent this.
+            output = make_viewless_tensor(inp = output,
+                                          requires_grad = input.requires_grad,
+                                          keep_graph = True)
+            return output
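The make_viewless_tensor call addresses the fact that FastLayerNormFN returns a view (its ._base field is populated), which the pipeline schedule's deallocate_output_tensor() refuses to free. The sketch below only demonstrates the property being relied on, using clone() as the simplest way to break the view relationship while keeping the autograd graph; the real helper in megatron.mpu presumably avoids the extra copy, so this is illustrative, not the library implementation.

import torch

x = torch.randn(4, 8, requires_grad=True)
view = x.view(2, 16)             # a view: its internal ._base points at x
assert view._base is x

viewless = view.clone()          # same values, gradient still flows to x,
assert viewless._base is None    # but no ._base, so its storage can be freed

viewless.sum().backward()
assert x.grad is not None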
megatron/model/transformer.py

@@ -15,7 +15,7 @@
 """Transformer."""
 import math
-import contextlib
+from contextlib import nullcontext
 
 import torch
 import torch.nn.functional as F
@@ -593,6 +593,13 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             self.mlp = ParallelMLP(init_method, output_layer_init_method)
 
+        # Set bias+dropout+add fusion grad_enable execution handler.
+        TORCH_MAJOR = int(torch.__version__.split('.')[0])
+        TORCH_MINOR = int(torch.__version__.split('.')[1])
+        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
+        self.bias_dropout_add_exec_handler = \
+                nullcontext if use_nvfuser else torch.enable_grad
+
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
@@ -626,8 +633,7 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             bias_dropout_add_func = get_bias_dropout_add(self.training)
 
-        # re-enable torch grad to enable fused optimization.
-        with torch.enable_grad():
+        with self.bias_dropout_add_exec_handler():
             layernorm_input = bias_dropout_add_func(
                 attention_output,
                 attention_bias.expand_as(residual),
@@ -653,8 +659,7 @@ class ParallelTransformerLayer(MegatronModule):
             else:
                 residual = layernorm_input
 
-            # re-enable torch grad to enable fused optimization.
-            with torch.enable_grad():
+            with self.bias_dropout_add_exec_handler():
                 layernorm_input = bias_dropout_add_func(
                     attention_output,
                     attention_bias.expand_as(residual),
@@ -674,8 +679,7 @@ class ParallelTransformerLayer(MegatronModule):
             residual = layernorm_input
 
         if self.drop_path is None:
-            # re-enable torch grad to enable fused optimization.
-            with torch.enable_grad():
+            with self.bias_dropout_add_exec_handler():
                 output = bias_dropout_add_func(
                     mlp_output,
                     mlp_bias.expand_as(residual),
@@ -909,7 +913,7 @@ class ParallelTransformer(MegatronModule):
         if self.sequence_parallel:
            rng_context = mpu.get_cuda_rng_tracker().fork()
         else:
-            rng_context = contextlib.nullcontext
+            rng_context = nullcontext
 
         with rng_context:
             # Forward pass.
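The new bias_dropout_add_exec_handler picks a no-op context manager when nvfuser is available (PyTorch >= 1.10) and falls back to torch.enable_grad() for the legacy fuser, matching the "re-enable torch grad" comments the diff removes. Below is a self-contained sketch of that selection, with an un-jitted bias_dropout_add written to mirror the helper in megatron/model/transformer.py; the tensor shapes are arbitrary and only illustrate the call pattern.

from contextlib import nullcontext
import torch

# Same version check as the constructor change above.
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad

def bias_dropout_add(x, bias, residual, prob, training=True):
    # Un-jitted equivalent of Megatron's fused bias+dropout+add helper.
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + out

x = torch.randn(8, 2, 16)
bias = torch.zeros(16).expand_as(x)
residual = torch.randn(8, 2, 16)
with bias_dropout_add_exec_handler():
    y = bias_dropout_add(x, bias, residual, prob=0.1)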
megatron/mpu/layers.py

@@ -241,7 +241,6 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
-        import fused_dense_cuda
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias
@@ -296,6 +295,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         if ctx.gradient_accumulation_fusion:
+            import fused_dense_cuda
             fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
             grad_weight = None
         else:
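Moving import fused_dense_cuda inside the gradient_accumulation_fusion branch is a deferred-import pattern: the optional Apex CUDA extension is only required when the fused path actually runs, so installs without it can still take the unfused branch. A small sketch of the same idea follows; the fused call is the one from this diff, while the unfused fallback GEMM is a plausible illustration rather than code taken from the commit.

import torch

def accumulate_wgrad(total_input, grad_output, weight, fuse):
    if fuse:
        # Imported only here, so the extension is optional for everyone else.
        import fused_dense_cuda
        fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output,
                                               weight.main_grad)
        return None
    # Unfused fallback: plain GEMM producing the weight gradient.
    return grad_output.t().matmul(total_input)

w = torch.nn.Parameter(torch.zeros(16, 32))
w.main_grad = torch.zeros(16, 32)
ti = torch.randn(8, 32)   # flattened activations
go = torch.randn(8, 16)   # flattened output gradients
print(accumulate_wgrad(ti, go, w, fuse=False).shape)  # torch.Size([16, 32])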
megatron/training.py

@@ -43,6 +43,7 @@ from megatron.model import ModelType
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
+from megatron.initialize import set_jit_fusion_options
 from megatron.optimizer_param_scheduler import OptimizerParamScheduler
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.utils import check_adlr_autoresume_termination
@@ -100,6 +101,8 @@ def pretrain(train_valid_test_dataset_provider,
     # Initalize and get arguments, timers, and Tensorboard writer.
     initialize_megatron(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults)
+    # Set pytorch JIT layer fusion options and warmup JIT functions.
+    set_jit_fusion_options()
 
     # Adjust the startup time so it reflects the largest value.
     # This will be closer to what scheduler will see (outside of
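With set_jit_fusion_options() now exported from megatron.initialize, pretrain() calls it right after initialize_megatron(), so the JIT warmup can read the freshly parsed args (fp16/bf16, hidden sizes) before model construction is timed. The sketch below only shows that call order; it assumes the Megatron-LM package from this commit is importable, and pretrain_prologue is a hypothetical wrapper, not upstream code.

from megatron.initialize import initialize_megatron, set_jit_fusion_options

def pretrain_prologue(extra_args_provider=None, args_defaults={}):
    # Global args and model-parallel state must exist first.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)
    # Then configure the JIT fuser and warm up the fused kernels.
    set_jit_fusion_options()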