Commit 56ea6d78 authored by Michael Carilli

saving for carl to review

parent 3c7a0e44
@@ -82,17 +82,20 @@ def casted_args(cast_fn, args, kwargs):
         return new_args

 def cached_cast(cast_fn, x, cache):
+    print("Calling cached_cast")
     if is_nested(x):
         return type(x)([cached_cast(y) for y in x])
     if x in cache:
         cached_x = cache[x]
-        # During eval, it's possible to end up caching casted weights
-        # with requires_grad == False. This is then a problem when they
-        # get reused on the next train iter. So we ensure that cached
-        # weights have same requires_grad flag of most recent request.
+        if x.requires_grad and cached_x.requires_grad:
+            # Check to make sure x is actually cached_x's autograd parent.
+            if cached_x.grad_fn.next_functions[1][0].variable is not x:
+                raise RuntimeError("x and cache[x] both require grad, but x is not "
+                                   "cache[x]'s parent. This is likely an error.")
         if x.requires_grad != cached_x.requires_grad:
-            cached_x.requires_grad_(x.requires_grad)
-        return cache[x]
+            del cache[x]
+        else:
+            return cached_x
     casted_x = cast_fn(x)
     cache[x] = casted_x
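A note on the new parent check: when a leaf tensor that requires grad is cast, the casted tensor's grad_fn holds a reference back to the original leaf through an AccumulateGrad node, which is what `cached_x.grad_fn.next_functions[1][0].variable is not x` relies on. A minimal standalone sketch of that relationship (not part of this commit; the exact index into next_functions can vary across PyTorch versions, so this searches rather than hard-coding [1][0]):

```python
import torch

# Standalone illustration: a leaf that requires grad, and its casted copy.
x = torch.randn(4, requires_grad=True)
casted = x.half()

# The cast records an autograd edge back to x via an AccumulateGrad node.
# Search the immediate parents instead of assuming a fixed position.
parents = [fn for fn, _ in casted.grad_fn.next_functions if fn is not None]
is_parent = any(getattr(fn, "variable", None) is x for fn in parents)
print("x is casted's autograd parent:", is_parent)  # expected: True
```

If the cached copy's graph no longer points back at x, the cache entry is stale, which is the condition the new RuntimeError guards against.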
@@ -292,7 +292,8 @@ class DistributedDataParallel(Module):
         # Sanity checks that all the buckets were kicked off
         if self.next_bucket != self.num_buckets:
-            raise RuntimeError("In epilogue, next_bucket != num_buckets. "
+            raise RuntimeError("In epilogue, next_bucket ({}) != num_buckets ({}). ".format(
+                               self.next_bucket, self.num_buckets),
                                "This probably indicates some buckets were not allreduced.")

         for actual, expected in zip(self.buckets_ready_size, self.bucket_sizes):
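A small aside on the new message: RuntimeError called with two arguments keeps both in its args tuple, so the printed text renders as a tuple rather than one sentence; formatting a single string first avoids that. A standalone sketch with illustrative values (1 and 3 are made up, not from the commit):

```python
# Two-argument form: str() of the exception shows the args tuple.
two_args = RuntimeError(
    "In epilogue, next_bucket ({}) != num_buckets ({}). ".format(1, 3),
    "This probably indicates some buckets were not allreduced.")
print(str(two_args))

# Single-string form: the adjacent literals are concatenated, then formatted,
# so the exception text reads as one sentence.
one_string = RuntimeError(
    "In epilogue, next_bucket ({}) != num_buckets ({}). "
    "This probably indicates some buckets were not allreduced.".format(1, 3))
print(str(one_string))
```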
@@ -389,6 +390,8 @@ class DistributedDataParallel(Module):
     def allreduce_fallback(self):
         grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
+
+        print("In allreduce_fallback: {}".format(len(grads)))
         split_buckets = split_half_float_double(grads)

         # If retain_allreduce_buffers is True and delay_allreduce is False,
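For orientation, the fallback path groups the gathered gradients by element type before flattening and allreducing each group. A rough standalone sketch of that kind of dtype bucketing (illustrative only; `split_half_float_double` is apex's own helper and may differ in detail):

```python
import torch
from collections import defaultdict

def split_by_dtype(tensors):
    # Group tensors by dtype so each group can be flattened into one
    # contiguous buffer and allreduced with a single collective call.
    buckets = defaultdict(list)
    for t in tensors:
        buckets[t.dtype].append(t)
    order = (torch.float16, torch.float32, torch.float64)
    return [buckets[dt] for dt in order if dt in buckets]

grads = [torch.ones(3, dtype=torch.float16), torch.ones(2), torch.ones(4, dtype=torch.float64)]
print([len(b) for b in split_by_dtype(grads)])  # [1, 1, 1]
```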
@@ -413,6 +416,7 @@ class DistributedDataParallel(Module):
         self.buckets[bucket_idx][bucket_loc] = param.grad.data
         self.buckets_ready_size[bucket_idx] += 1
+        print(self.buckets_ready_size)

         if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
             if bucket_idx == self.next_bucket:
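The counters printed here feed the in-order launch logic: buckets can fill in any order, but an allreduce is only kicked off for the bucket that matches next_bucket, and buckets that complete early wait in ready_buckets_not_reduced. A toy standalone sketch of that ordering (names and structure are illustrative, not apex's actual code):

```python
# Buckets become "ready" out of order, but allreduces launch strictly in order.
def mark_bucket_ready(bucket_idx, state):
    state["ready"].add(bucket_idx)
    # Drain every consecutive ready bucket starting from next_bucket.
    while state["next_bucket"] in state["ready"]:
        idx = state["next_bucket"]
        state["ready"].discard(idx)
        print("launch allreduce for bucket", idx)  # stand-in for the real collective
        state["next_bucket"] += 1

state = {"next_bucket": 0, "ready": set()}
for idx in (2, 0, 1):  # buckets finish out of order
    mark_bucket_ready(idx, state)
# launches bucket 0 first, then 1 and 2 once the gap is filled
```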
@@ -472,6 +476,9 @@ class DistributedDataParallel(Module):
         self.allreduce_buffers = [None for _ in range(self.num_buckets)]
         self.next_bucket = 0
         self.ready_buckets_not_reduced = set()
+
+        print(len(param_list), len(self.active_params), [len(b) for b in self.buckets],
+              self.needs_refresh)

         self.active_params = param_list