OpenDAS / Megatron-LM / Commits

Commit 6b5ae488
Authored Sep 01, 2021 by rprenger

    Fixing merge conflicts with main

Parents: 3d718bfc, 3860e995

Showing 6 changed files with 72 additions and 78 deletions (+72 −78).
megatron/arguments.py              +7  −1
megatron/initialize.py             +23 −9
megatron/model/fused_bias_gelu.py  +0  −4
megatron/model/transformer.py      +19 −7
megatron/mpu/__init__.py           +0  −2
megatron/mpu/random.py             +23 −55
megatron/arguments.py  (+7 −1)

@@ -240,9 +240,15 @@ def parse_args(extra_args_provider=None, defaults={},
         'residual connection in fp32 only supported when using fp16 or bf16.'
 
     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
+        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
+            'checkpointed activations only across tensor model ' \
+            'parallel groups'
         assert args.activations_checkpoint_method is not None, \
             'for distribute-checkpointed-activations to work you '\
-            'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')'
+            'need to use an activation-checkpoint method'
+        assert args.num_layers_per_virtual_pipeline_stage is None, \
+            'currently distributed checkpoint activations only supported for ' \
+            'non-interleaved pipeline parallelism'
 
     _print_args(args)
     return args
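For reference, a minimal standalone sketch (not Megatron's parser; the helper name and the SimpleNamespace stand-in for the parsed args are assumptions) of the constraints the assertions above enforce whenever --distribute-checkpointed-activations is set:

from types import SimpleNamespace

def validate_distribute_checkpointed_activations(args):
    """Mirror of the three asserts added in parse_args() above."""
    if not args.distribute_checkpointed_activations:
        return
    # Checkpointed activations can only be sharded across a tensor-model-parallel group.
    assert args.tensor_model_parallel_size > 1
    # An activation-checkpoint method ('uniform' or 'block') must be selected.
    assert args.activations_checkpoint_method in ('uniform', 'block')
    # Interleaved (virtual) pipeline schedules are not supported with this option.
    assert args.num_layers_per_virtual_pipeline_stage is None

# Example: a configuration that satisfies all three constraints.
validate_distribute_checkpointed_activations(SimpleNamespace(
    distribute_checkpointed_activations=True,
    tensor_model_parallel_size=2,
    activations_checkpoint_method='uniform',
    num_layers_per_virtual_pipeline_stage=None))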
megatron/initialize.py  (+23 −9)

@@ -64,6 +64,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
             print('> setting random seeds to {} ...'.format(args.seed))
         _set_random_seed(args.seed)
 
+    # Set pytorch JIT layer fusion options.
+    _set_jit_fusion_options()
+
     args = get_args()
     if args.lazy_mpu_init:
         args.use_cpu_initialization = True
@@ -78,9 +81,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         # Megatron's MPU is the master. Complete initialization right away.
         finish_mpu_init()
 
-        # Initialize memory buffers.
-        _initialize_mem_buffs()
-
         # Autoresume.
         _init_autoresume()
@@ -226,10 +226,24 @@ def write_args_to_tensorboard():
                                  global_step=args.iteration)
 
 
-def _initialize_mem_buffs():
-    """Initialize manually allocated static memory."""
-    args = get_args()
+def _set_jit_fusion_options():
+    """Set PyTorch JIT layer fusion options."""
+    # flags required to enable jit fusion kernels
+    TORCH_MAJOR = int(torch.__version__.split('.')[0])
+    TORCH_MINOR = int(torch.__version__.split('.')[1])
+    if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
+        # nvfuser
+        torch._C._jit_set_profiling_executor(True)
+        torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        torch._C._jit_set_nvfuser_enabled(True)
+        torch._C._debug_set_autodiff_subgraph_inlining(False)
+    else:
+        # legacy pytorch fuser
+        torch._C._jit_set_profiling_mode(False)
+        torch._C._jit_set_profiling_executor(False)
+        torch._C._jit_override_can_fuse_on_cpu(True)
+        torch._C._jit_override_can_fuse_on_gpu(True)
 
-    # Initialize memory for checkpointed activations.
-    if args.distribute_checkpointed_activations:
-        mpu.init_checkpointed_activations_memory_buffer()
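A small standalone sketch of the version gate used by _set_jit_fusion_options() above. It only reports which branch would be taken on the installed PyTorch; the print strings are mine, while the torch._C calls it refers to are exactly the ones shown in the diff:

import torch

# Same parsing as TORCH_MAJOR / TORCH_MINOR in _set_jit_fusion_options().
major, minor = (int(v) for v in torch.__version__.split('.')[:2])

if (major > 1) or (major == 1 and minor >= 10):
    # PyTorch >= 1.10: nvFuser path (profiling executor/mode on, CPU/GPU and
    # tensor-expression fusers off, nvFuser on, autodiff subgraph inlining off).
    print('would enable nvFuser fusion kernels')
else:
    # Older PyTorch: legacy fuser path (profiling off, CPU/GPU fusion on).
    print('would fall back to the legacy PyTorch fuser')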
megatron/model/fused_bias_gelu.py  (+0 −4)

@@ -15,10 +15,6 @@
 
 import torch
 
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)
 
 ###### BIAS GELU FUSION/ NO AUTOGRAD ################
 # 1/sqrt(2*pi)-> 0.3989423
megatron/model/transformer.py  (+19 −7)

@@ -27,11 +27,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
 
-# flags required to enable jit fusion kernels
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)
 
 """ We use the following notation throughout this file:
      h: hidden size
@@ -544,6 +539,7 @@ class ParallelTransformer(MegatronModule):
         # Store activation checkpointing flags.
         self.activations_checkpoint_method = args.activations_checkpoint_method
         self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
+        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
 
         # Number of layers.
         assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
@@ -607,8 +603,22 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward
 
-        # Make sure memory is freed.
-        mpu.reset_checkpointed_activations_memory_buffer()
+        def distribute_checkpointed_activations_helper(layer_number):
+            """Distribute checkpointed activations across the tensor model
+               parallel ranks if `distribute-checkpointed-activations` is on
+               and either of the following conditions is met:
+                 - it is not the first layer in the pipeline stage. The first
+                   layer is used in the pipeline parallelism and changing its
+                   shape throws an error in the backward pass.
+                 - we are at the first pipeline stage, so the input tensor is
+                   not used in pipeline parallelism. Note that no pipeline
+                   parallelism is a special case of this.
+            """
+            not_first_layer_in_pipeline_stage = (layer_number > 0)
+            is_first_pipeline_stage = (mpu.get_pipeline_model_parallel_rank() == 0)
+            return self.distribute_checkpointed_activations and \
+                (not_first_layer_in_pipeline_stage or is_first_pipeline_stage)
 
         if self.activations_checkpoint_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
@@ -618,6 +628,7 @@ class ParallelTransformer(MegatronModule):
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
                     custom(l, l + self.activations_checkpoint_num_layers),
+                    distribute_checkpointed_activations_helper(l),
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.activations_checkpoint_num_layers
         elif self.activations_checkpoint_method == 'block':
@@ -628,6 +639,7 @@ class ParallelTransformer(MegatronModule):
                 if l < self.activations_checkpoint_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
+                        distribute_checkpointed_activations_helper(l),
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
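A standalone sketch of the decision rule in distribute_checkpointed_activations_helper() above, with a plain integer pipeline_rank standing in for mpu.get_pipeline_model_parallel_rank() and a boolean flag standing in for self.distribute_checkpointed_activations (both stand-ins are assumptions made so the example runs on its own):

def should_distribute(layer_number, pipeline_rank, distribute_flag):
    """Should this layer's checkpointed input be sharded across tensor-parallel ranks?"""
    not_first_layer_in_pipeline_stage = layer_number > 0
    is_first_pipeline_stage = pipeline_rank == 0
    return distribute_flag and (not_first_layer_in_pipeline_stage or
                                is_first_pipeline_stage)

# Layer 0 of a non-first pipeline stage is the only case that keeps its full shape:
# its input takes part in pipeline communication, so resharding it would break the
# backward pass, as the docstring above explains.
for pipeline_rank in (0, 1):
    for layer_number in (0, 1):
        print(pipeline_rank, layer_number,
              should_distribute(layer_number, pipeline_rank, distribute_flag=True))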
megatron/mpu/__init__.py  (+0 −2)

@@ -56,9 +56,7 @@ from .mappings import scatter_to_tensor_model_parallel_region
 
 from .random import checkpoint
 from .random import get_cuda_rng_tracker
-from .random import init_checkpointed_activations_memory_buffer
 from .random import model_parallel_cuda_manual_seed
-from .random import reset_checkpointed_activations_memory_buffer
 from .random import gather_split_1d_tensor
 from .random import split_tensor_into_1d_equal_chunks
megatron/mpu/random.py  (+23 −55)

@@ -37,46 +37,6 @@ from .initialize import get_tensor_model_parallel_world_size
 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
 
-# Whether to apply model parallelism to checkpointed hidden states.
-_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
-
-
-def init_checkpointed_activations_memory_buffer():
-    """Initialize the memory buffer for the checkpointed activations."""
-    args = get_args()
-
-    per_layer = args.micro_batch_size * args.max_position_embeddings * \
-                args.hidden_size // args.tensor_model_parallel_size
-    num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()
-    if args.virtual_pipeline_model_parallel_size is not None:
-        num_layers = num_layers // args.virtual_pipeline_model_parallel_size
-
-    if args.activations_checkpoint_method == 'uniform':
-        assert num_layers % args.activations_checkpoint_num_layers == 0, \
-            'total number of layers is not divisible by checkpoint-chunk_size'
-        num_checkpointer_layers = args.num_layers // args.activations_checkpoint_num_layers
-    elif args.activations_checkpoint_method == 'block':
-        assert args.activations_checkpoint_num_layers <= num_layers, \
-            'total number of layers is fewer than the number of layers to checkpoint'
-        num_checkpointer_layers = args.activations_checkpoint_num_layers
-
-    numel = per_layer * num_checkpointer_layers
-    dtype = torch.half
-    if not args.fp16:
-        dtype = torch.float
-
-    global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
-    assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
-        'checkpointed activations memory buffer is already allocated.'
-    _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
-        'checkpointed activations', numel, dtype, track_usage=False)
-
-
-def reset_checkpointed_activations_memory_buffer():
-    """Reset the memory used for checkpointing."""
-    if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
-        _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()
-
-
 def _set_cuda_rng_state(new_state, device=-1):
     """Sets the random number generator state of the current GPU.
@@ -110,14 +70,21 @@ def _set_cuda_rng_state(new_state, device=-1):
         _lazy_call(cb)
 
 
-def split_tensor_into_1d_equal_chunks(tensor):
+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
     """Break a tensor into equal 1D chunks."""
-    data = tensor.view(-1)
-    partition_size = torch.numel(data) // \
+    partition_size = torch.numel(tensor) // \
         get_tensor_model_parallel_world_size()
     start_index = partition_size * get_tensor_model_parallel_rank()
    end_index = start_index + partition_size
-    return data[start_index:end_index]
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data
 
 
 def gather_split_1d_tensor(tensor):
     """Opposite of above function, gather values from model parallel ranks."""
@@ -259,8 +226,10 @@ class CheckpointFunction(torch.autograd.Function):
     tracked/set/reset.
     """
     @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, distribute_checkpointed_activations, *args):
         ctx.run_function = run_function
+        ctx.distribute_checkpointed_activations \
+            = distribute_checkpointed_activations
 
         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -272,16 +241,14 @@ class CheckpointFunction(torch.autograd.Function):
         # Divide hidden states across model parallel group and only keep
         # the chunk corresponding to the current rank.
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        if distribute_checkpointed_activations:
             ctx.input_0_shape = args[0].data.shape
-            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
-            args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
-                args[0].data)
+            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
+                                                             new_buffer=True)
 
         # Store everything.
         ctx.save_for_backward(*args)
 
         return outputs
 
     @staticmethod
@@ -290,7 +257,7 @@ class CheckpointFunction(torch.autograd.Function):
             raise RuntimeError("Checkpointing is not compatible with .grad(), "
                                "please use .backward() if possible")
         inputs = ctx.saved_tensors
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        if ctx.distribute_checkpointed_activations:
             inputs[0].data = gather_split_1d_tensor(inputs[0].data)
             inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
@@ -319,10 +286,11 @@ class CheckpointFunction(torch.autograd.Function):
         torch.autograd.backward(outputs, args)
         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                       for inp in detached_inputs)
-        return (None,) + grads
+        return (None, None) + grads
 
 
-def checkpoint(function, *args):
+def checkpoint(function, distribute_checkpointed_activations, *args):
     """Checkpoint a model or part of the model.
     This has been directly copied from torch.utils.checkpoint."""
-    return CheckpointFunction.apply(function, *args)
+    return CheckpointFunction.apply(function,
+                                    distribute_checkpointed_activations, *args)
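To illustrate why CheckpointFunction.forward() now passes new_buffer=True instead of copying into the removed global memory buffer, here is a CPU-only sketch of the two split_tensor_into_1d_equal_chunks() behaviours, with fixed world_size/rank values standing in for the mpu calls (those stand-ins, and running on CPU rather than the current CUDA device, are assumptions):

import torch

world_size, rank = 4, 1            # stand-ins for the tensor-model-parallel group
tensor = torch.arange(16.0)

partition_size = torch.numel(tensor) // world_size
start = partition_size * rank
end = start + partition_size

# new_buffer=False: the chunk is a view that aliases the original storage.
view_chunk = tensor.view(-1)[start:end]

# new_buffer=True: the chunk is copied into freshly allocated storage, so the
# checkpointed activation survives even if the original tensor is overwritten.
copy_chunk = torch.empty(partition_size, dtype=tensor.dtype, requires_grad=False)
copy_chunk.copy_(tensor.view(-1)[start:end])

tensor[start] = -1.0
print(view_chunk[0].item(), copy_chunk[0].item())   # -1.0 vs. 4.0

Correspondingly, callers now invoke mpu.checkpoint(forward_fn, distribute_checkpointed_activations, *tensors), and backward() returns an extra leading None to account for the new non-tensor argument.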