Commit 53fd093d authored by Michael Carilli's avatar Michael Carilli
Browse files

Update Wil's code + typo

parent 3c53cf81
......@@ -229,7 +229,7 @@ def post_backward_with_master_weights_fused(self, scaler):
stash = self._amp_stash
stash.scale = scaler.loss_scale()
stash.grads = [[param.grad.data for param in group] for group in self.fp16_groups]
stash.output_params = [[param for param in in group] for group in self.fp16_groups]
stash.output_params = [[param for param in group] for group in self.fp16_groups]
norm_groups = []
skip = False
......
......@@ -65,7 +65,7 @@ struct AdamFunctor
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorList<DEPTH>& tl,
TensorListMetadata<DEPTH>& tl,
const float b1,
const float b2,
const float eps,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment