OpenDAS / Megatron-LM

Commit c7fef593
Authored Mar 24, 2022 by Sangkug Lym
Parent: 15f6bb1b

JIT function warmups to (1) match fprop and recompute results and (2) remove grad_enable
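Context for the message: TorchScript specializes a scripted function on properties of its inputs, including requires_grad, so a warmup that does not reproduce the grad states seen during forward prop and activation recomputation still leaves compilation work for the first real training steps. A minimal standalone illustration of that warmup pattern (not code from this repository; the function name and shapes below are made up):

import torch

# Hypothetical demo of the warmup pattern this commit adds: call a scripted
# elementwise function a few times under each requires_grad configuration it
# will later see, so the JIT compiles/fuses it before any timed training step.
@torch.jit.script
def fused_like(bias, y):
    x = bias + y
    return x * torch.sigmoid(x)      # stand-in fusion candidate

bias = torch.rand(16, requires_grad=True)
y = torch.rand(8, 16)                # activation input, no grad yet

for y_grad in (False, True):         # False ~ forward prop, True ~ recompute
    y.requires_grad = y_grad
    for _ in range(5):               # repeated calls let the profiling executor specialize
        out = fused_like(bias, y)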
Showing 3 changed files with 57 additions and 21 deletions (+57, -21):
megatron/initialize.py           +40  -0
megatron/model/transformer.py    +15  -21
megatron/training.py             +2   -0
megatron/initialize.py
...
@@ -31,6 +31,8 @@ from megatron import mpu
 from megatron.global_vars import set_global_variables
 from megatron.mpu import (set_tensor_model_parallel_rank,
                           set_tensor_model_parallel_world_size)
+from megatron.model.transformer import bias_dropout_add_fused_train
+from megatron.model.fused_bias_gelu import bias_gelu


 def initialize_megatron(extra_args_provider=None, args_defaults={},
...
@@ -251,3 +253,41 @@ def _set_jit_fusion_options():
         torch._C._jit_override_can_fuse_on_cpu(True)
         torch._C._jit_override_can_fuse_on_gpu(True)
+
+
+def warmup_jit_function():
+    """ Compile JIT functions before the main training steps """
+    args = get_args()
+    if args.bf16:
+        p = torch.bfloat16
+    elif args.fp16:
+        p = torch.float16
+    else:
+        p = torch.float32
+
+    # Warmup fused bias+gelu
+    b = torch.rand(int(args.hidden_size * 4 / args.tensor_model_parallel_size),
+                   dtype=p, device='cuda')
+    x = torch.rand((args.seq_length, args.micro_batch_size,
+                    int(args.hidden_size * 4 / args.tensor_model_parallel_size)),
+                   dtype=p, device='cuda')
+    # Warmup JIT fusions with the input grad_enable state at both forward
+    # prop and recomputation
+    for b_grad, x_grad in zip([True, True], [False, True]):
+        b.requires_grad, x.requires_grad = b_grad, x_grad
+        for _ in range(5):
+            y = bias_gelu(b, x)
+    del b, x, y
+
+    # Warmup fused bias+dropout+add
+    input_size = (args.seq_length, args.micro_batch_size, args.hidden_size)
+    x = torch.rand(input_size, dtype=p, device='cuda')
+    r = torch.rand(input_size, dtype=p, device='cuda')
+    b = torch.rand((args.hidden_size), dtype=p, device='cuda').expand_as(r)
+    # Warmup JIT fusions with the input grad_enable state at both forward
+    # prop and recomputation
+    for x_grad, b_grad, r_grad in zip([False, True], [True, True], [True, True]):
+        x.requires_grad, b.requires_grad, r.requires_grad = x_grad, b_grad, r_grad
+        for _ in range(5):
+            y = bias_dropout_add_fused_train(x, b, r, 0.1)
+    del b, x, r, y
+    torch.cuda.empty_cache()
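The bias_gelu being warmed up here is imported from megatron/model/fused_bias_gelu.py, which is not part of this diff. For context, it is roughly a TorchScript-compiled tanh-approximation GeLU applied to input-plus-bias; the sketch below is reconstructed from memory of the Megatron-LM source, so treat the exact form and constants as an assumption:

import torch

# Approximate sketch of bias_gelu from megatron/model/fused_bias_gelu.py
# (not shown in this commit): a scripted elementwise function that the JIT
# can fuse into a single kernel.
@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    # tanh approximation of GeLU
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

Because the function is scripted, the profiling executor specializes it on the requires_grad of its inputs, which is why the warmup loops over zip([True, True], [False, True]) rather than using a single grad configuration.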
megatron/model/transformer.py
...
@@ -564,8 +564,6 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             bias_dropout_add_func = get_bias_dropout_add(self.training)

-        # re-enable torch grad to enable fused optimization.
-        with torch.enable_grad():
-            layernorm_input = bias_dropout_add_func(
-                attention_output,
-                attention_bias.expand_as(residual),
+        layernorm_input = bias_dropout_add_func(
+            attention_output,
+            attention_bias.expand_as(residual),
...
@@ -591,8 +589,6 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = layernorm_input

-        # re-enable torch grad to enable fused optimization.
-        with torch.enable_grad():
-            layernorm_input = bias_dropout_add_func(
-                attention_output,
-                attention_bias.expand_as(residual),
+        layernorm_input = bias_dropout_add_func(
+            attention_output,
+            attention_bias.expand_as(residual),
...
@@ -612,8 +608,6 @@ class ParallelTransformerLayer(MegatronModule):
             residual = layernorm_input

         if self.drop_path is None:
-            # re-enable torch grad to enable fused optimization.
-            with torch.enable_grad():
-                output = bias_dropout_add_func(
-                    mlp_output,
-                    mlp_bias.expand_as(residual),
+            output = bias_dropout_add_func(
+                mlp_output,
+                mlp_bias.expand_as(residual),
...
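The bias_dropout_add_func called at these sites resolves either to the scripted bias_dropout_add_fused_train / bias_dropout_add_fused_inference variants or to an unfused closure returned by get_bias_dropout_add; those helpers live earlier in megatron/model/transformer.py and are not part of this diff. A rough sketch, reconstructed from memory of the surrounding source and therefore an approximation:

import torch

# Approximate sketch of the helpers behind bias_dropout_add_func
# (defined elsewhere in megatron/model/transformer.py, not in this diff).
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor,
                     prob: float, training: bool) -> torch.Tensor:
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + out

def get_bias_dropout_add(training):
    # unfused fallback: a plain closure over the training flag
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add

@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor,
                                 residual: torch.Tensor, prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, True)

With the warmup in initialize.py now exercising bias_dropout_add_fused_train under both grad states the layer will encounter, the call sites no longer need to force grad on just to steer JIT compilation, which appears to be why the torch.enable_grad() wrappers are removed here.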
megatron/training.py
...
@@ -42,6 +42,7 @@ from megatron.model import ModelType
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
+from megatron.initialize import warmup_jit_function
 from megatron.optimizer_param_scheduler import OptimizerParamScheduler
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.utils import check_adlr_autoresume_termination
...
@@ -99,6 +100,7 @@ def pretrain(train_valid_test_dataset_provider,
     # Initalize and get arguments, timers, and Tensorboard writer.
     initialize_megatron(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults)
+    warmup_jit_function()

     # Adjust the startup time so it reflects the largest value.
     # This will be closer to what scheduler will see (outside of
...
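For reference, the resulting call order is simply initialization first, then the warmup; a trimmed sketch of the idea (pretrain_prologue is a made-up name, the real code sits at the top of pretrain() as shown in the hunk above):

from megatron.initialize import initialize_megatron
from megatron.initialize import warmup_jit_function

# Hypothetical trimmed-down version of the start of pretrain(): the warmup must
# run after initialize_megatron() because it reads get_args() (seq_length,
# micro_batch_size, hidden_size, tensor_model_parallel_size, fp16/bf16 flags).
def pretrain_prologue(extra_args_provider=None, args_defaults={}):
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)
    warmup_jit_function()   # compile JIT fusions before any timed training step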