Commit f257d0ee authored by Mohammad's avatar Mohammad
Browse files

added parameters sharing

parent 80f90dcd
......@@ -405,7 +405,6 @@ class ParallelTransformer(MegatronModule):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
layers_ = self.layers[start:end]
x_ = inputs[0]
for index in range(start, end):
layer = self._get_layer(index)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment