Commit 53fd093d authored by Michael Carilli
Browse files

Update Wil's code + typo

parent 3c53cf81
...@@ -229,7 +229,7 @@ def post_backward_with_master_weights_fused(self, scaler): ...@@ -229,7 +229,7 @@ def post_backward_with_master_weights_fused(self, scaler):
stash = self._amp_stash stash = self._amp_stash
stash.scale = scaler.loss_scale() stash.scale = scaler.loss_scale()
stash.grads = [[param.grad.data for param in group] for group in self.fp16_groups] stash.grads = [[param.grad.data for param in group] for group in self.fp16_groups]
stash.output_params = [[param for param in in group] for group in self.fp16_groups] stash.output_params = [[param for param in group] for group in self.fp16_groups]
norm_groups = [] norm_groups = []
skip = False skip = False
......
...@@ -65,7 +65,7 @@ struct AdamFunctor ...@@ -65,7 +65,7 @@ struct AdamFunctor
__device__ __forceinline__ void operator()( __device__ __forceinline__ void operator()(
int chunk_size, int chunk_size,
volatile int* noop_gmem, volatile int* noop_gmem,
TensorList<DEPTH>& tl, TensorListMetadata<DEPTH>& tl,
const float b1, const float b1,
const float b2, const float b2,
const float eps, const float eps,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment