Commit 2c493fb3 authored by Jared Casper
Browse files

Merge branch 'activation-checkpoint-fix' into 'main'

fix bug in uniform activation recompute

See merge request ADLR/megatron-lm!551
parents 035cae2e fc2c81d3
......@@ -1036,9 +1036,10 @@ class ParallelTransformer(MegatronModule):
"""Forward method with activation checkpointing."""
def custom(start, end, is_transformer_engine=False):
def custom_forward(*args, **kwargs):
x_, *args = args
for index in range(start, end):
layer = self._get_layer(index)
-                x_ = layer(*args, **kwargs)
+                x_ = layer(x_, *args, **kwargs)
return x_
def custom_forward_transformer_engine(*args, **kwargs):
return custom_forward(*args, is_first_microbatch=is_first_microbatch, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment