"docs/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "72c6bab24f398dbc583a26508dd9ee1f3dbc4fc2"
Commit 589328ff authored by Michael Carilli

Support for custom batch types

parent 533e88d7
@@ -27,7 +27,7 @@ different flags to `amp.initialize`.
 [DCGAN example coming soon...](https://github.com/NVIDIA/apex/tree/master/examples/dcgan)
-[Moving to the new Amp API](https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users) (for users of the deprecated tools formerly called "Amp" and "FP16_Optimizer")
+[Moving to the new Amp API](https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users) (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)
 ## 2. Distributed Training
@@ -12,16 +12,20 @@ from ..parallel import DistributedDataParallel as apex_DDP
 def to_type(dtype, t):
-    if not t.is_cuda:
-        # This should not be a hard error, since it may be legitimate.
-        print("Warning: An input tensor was not cuda. ")
-    if t.requires_grad:
-        # This should be a hard-ish error.
-        warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
-                    "its gradients will not be properly allreduced by DDP.")
-    if t.is_floating_point():
-        return t.to(dtype)
-    return t
+    if isinstance(t, torch.Tensor):
+        if not t.is_cuda:
+            # This should not be a hard error, since it may be legitimate.
+            print("Warning: An input tensor was not cuda. ")
+        if t.requires_grad:
+            # This should be a hard-ish error.
+            warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
+                        "its gradients will not be properly allreduced by DDP.")
+        if t.is_floating_point():
+            return t.to(dtype)
+        return t
+    else:
+        # Trust the user's custom batch type, that's all I can do here.
+        return t.to(dtype)
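For a quick sense of the new `to_type` behaviour, here is a minimal runnable sketch (not part of the commit; `warn_or_err` is stubbed with a plain print, and the tensors live on CPU so the "not cuda" warning fires): floating-point Tensors are cast to the requested dtype, non-floating Tensors pass through unchanged, and anything that is not a Tensor is trusted to implement its own `to(dtype)`.

```python
import torch

def warn_or_err(msg):
    # Stand-in for Amp's internal helper; the real one may warn or raise
    # depending on how strictly Amp is configured.
    print("Warning:", msg)

def to_type(dtype, t):
    if isinstance(t, torch.Tensor):
        if not t.is_cuda:
            print("Warning: An input tensor was not cuda.")
        if t.requires_grad:
            warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
                        "its gradients will not be properly allreduced by DDP.")
        if t.is_floating_point():
            return t.to(dtype)   # only floating-point Tensors are cast
        return t                 # integer/bool Tensors pass through unchanged
    else:
        # Anything that is not a Tensor is trusted to implement its own to(dtype).
        return t.to(dtype)

print(to_type(torch.float16, torch.randn(2)).dtype)        # torch.float16
print(to_type(torch.float16, torch.tensor([1, 2])).dtype)  # torch.int64
```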
 # Modified from torch.optim.optimizer.py. This is a bit more general than casted_args in utils.py.
@@ -34,7 +38,17 @@ def applier(value, fn):
         return {applier(k, fn) : applier(v, fn) for k, v in value.items()}
     elif isinstance(value, container_abcs.Iterable):
         return type(value)(applier(v, fn) for v in value)
+    elif hasattr(value, "to"): # Allow handling of custom batch classes
+        return fn(value)
     else:
+        # Do I want this to fire off even if someone chooses to pass something ordinary like
+        # an int or float? May be more annoying than it's worth.
+        # print("Warning: unrecognized type in applier. If your input data is a custom class, "
+        #     "provide it with a .to(dtype) method which converts its floating-point Tensors to dtype. "
+        #     "Amp will check for your custom to() and invoke it to cast the batch's "
+        #     "floating-point Tensors to the appropriate type. "
+        #     "Also, if your data is a custom class, it is your responsibility to ensure that "
+        #     "any Tensors you want to be cuda are already cuda.")
         return value
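To illustrate what the new `hasattr(value, "to")` branch enables, here is a small illustrative sketch (not part of the commit): `MyBatch` is a hypothetical user-defined batch class, the `applier` below paraphrases the function above with `collections.abc` standing in for the `container_abcs` import and the string/ndarray branches simplified, and `cast_to_half` is a stand-in for the caster Amp would supply. Any object that exposes a `to(dtype)` method is now handed to the caster instead of being returned untouched.

```python
import collections.abc as container_abcs
import torch

def applier(value, fn):
    # Paraphrase of the function above (same dispatch idea, simplified branches).
    if isinstance(value, torch.Tensor):
        return fn(value)
    elif isinstance(value, str):
        return value
    elif isinstance(value, container_abcs.Mapping):
        return {applier(k, fn): applier(v, fn) for k, v in value.items()}
    elif isinstance(value, container_abcs.Iterable):
        return type(value)(applier(v, fn) for v in value)
    elif hasattr(value, "to"):  # custom batch classes land here
        return fn(value)
    else:
        return value

# A hypothetical user-defined batch type: all it needs is a to(dtype) method
# that casts its own floating-point Tensors and leaves everything else alone.
class MyBatch(object):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def to(self, dtype):
        return MyBatch(self.images.to(dtype), self.labels)

def cast_to_half(t):
    # Stand-in for the caster Amp would supply (e.g. a partial application of to_type).
    return t.to(torch.float16)

batch = MyBatch(torch.randn(4, 3, 8, 8), torch.zeros(4, dtype=torch.long))
half_batch = applier(batch, cast_to_half)
print(half_batch.images.dtype, half_batch.labels.dtype)   # torch.float16 torch.int64

# Nested containers still work: the custom object is found inside the list.
mixed = applier([batch, torch.randn(2)], cast_to_half)
print(type(mixed[0]).__name__, mixed[1].dtype)            # MyBatch torch.float16
```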