"...text-generation-inference.git" did not exist on "9d8f21cace66e8593bc559174de0eed3ecdab6a2"
Commit 589328ff authored by Michael Carilli

Support for custom batch types

parent 533e88d7
```diff
@@ -27,7 +27,7 @@ different flags to `amp.initialize`.

 [DCGAN example coming soon...](https://github.com/NVIDIA/apex/tree/master/examples/dcgan)

-[Moving to the new Amp API](https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users) (for users of the deprecated tools formerly called "Amp" and "FP16_Optimizer")
+[Moving to the new Amp API](https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users) (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)

 ## 2. Distributed Training
```
```diff
@@ -12,16 +12,20 @@ from ..parallel import DistributedDataParallel as apex_DDP

 def to_type(dtype, t):
-    if not t.is_cuda:
-        # This should not be a hard error, since it may be legitimate.
-        print("Warning: An input tensor was not cuda. ")
-    if t.requires_grad:
-        # This should be a hard-ish error.
-        warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
-                    "its gradients will not be properly allreduced by DDP.")
-    if t.is_floating_point():
+    if isinstance(t, torch.Tensor):
+        if not t.is_cuda:
+            # This should not be a hard error, since it may be legitimate.
+            print("Warning: An input tensor was not cuda. ")
+        if t.requires_grad:
+            # This should be a hard-ish error.
+            warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
+                        "its gradients will not be properly allreduced by DDP.")
+        if t.is_floating_point():
+            return t.to(dtype)
+        return t
+    else:
+        # Trust the user's custom batch type, that's all I can do here.
         return t.to(dtype)
-    return t

 # Modified from torch.optim.optimizer.py. This is a bit more general than casted_args in utils.py.
```
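In effect, `to_type` no longer assumes its argument is a Tensor: any other object is trusted to provide its own `to(dtype)` method rather than tripping over Tensor-only attributes like `is_cuda`. A minimal sketch of such a custom batch class (the class name and fields are illustrative, not part of this commit):

```python
# Illustrative custom batch type (not from this commit): to_type() now calls
# .to(dtype) on it instead of assuming it is a torch.Tensor.
import torch

class SegmentationBatch(object):
    def __init__(self, images, labels):
        self.images = images    # floating-point Tensor: should be cast
        self.labels = labels    # integer Tensor: should be left alone

    def to(self, dtype):
        # Mirror to_type()'s Tensor behavior: cast only floating-point data.
        # As the applier comment below notes, moving Tensors to cuda remains
        # the user's responsibility.
        self.images = self.images.to(dtype)
        return self

batch = SegmentationBatch(torch.randn(2, 3, 8, 8),
                          torch.zeros(2, 8, 8, dtype=torch.long))
batch = batch.to(torch.half)    # what amp invokes on the custom batch
assert batch.images.dtype == torch.half and batch.labels.dtype == torch.long
```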
```diff
@@ -34,7 +38,17 @@ def applier(value, fn):
         return {applier(k, fn) : applier(v, fn) for k, v in value.items()}
     elif isinstance(value, container_abcs.Iterable):
         return type(value)(applier(v, fn) for v in value)
+    elif hasattr(value, "to"): # Allow handling of custom batch classes
+        return fn(value)
     else:
+        # Do I want this to fire off even if someone chooses to pass something ordinary like
+        # an int or float? May be more annoying than it's worth.
+        # print("Warning: unrecognized type in applier. If your input data is a custom class, "
+        #       "provide it with a .to(dtype) method which converts its floating-point Tensors to dtype. "
+        #       "Amp will check for your custom to() and invoke it to cast the batch's "
+        #       "floating-point Tensors to the appropriate type. "
+        #       "Also, if your data is a custom class, it is your responsibility to ensure that "
+        #       "any Tensors you want to be cuda are already cuda.")
         return value
```
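To see how the new branch composes with `applier`'s recursion, here is a self-contained sketch (simplified: apex's `string_classes` and `np.ndarray` early-outs are reduced to a plain `str` check, and the cuda/requires_grad warnings are dropped). A custom batch object nested inside ordinary containers reaches `fn()` through the `hasattr(value, "to")` branch:

```python
# Simplified sketch of the dispatch above, not the exact apex source.
import collections.abc as container_abcs  # stands in for apex's container_abcs
import torch

def to_type(dtype, t):
    if isinstance(t, torch.Tensor):
        return t.to(dtype) if t.is_floating_point() else t
    # Trust the user's custom batch type, as in the hunk above.
    return t.to(dtype)

def applier(value, fn):
    if isinstance(value, torch.Tensor):
        return fn(value)
    elif isinstance(value, str):
        return value
    elif isinstance(value, dict):
        return {applier(k, fn): applier(v, fn) for k, v in value.items()}
    elif isinstance(value, container_abcs.Iterable):
        return type(value)(applier(v, fn) for v in value)
    elif hasattr(value, "to"):  # custom batch classes land here
        return fn(value)
    else:
        return value            # ints, floats, None, etc. pass through

class MyBatch(object):          # hypothetical custom batch class
    def __init__(self, data):
        self.data = data
    def to(self, dtype):
        self.data = self.data.to(dtype)
        return self

args = ([torch.randn(2)], {"x": MyBatch(torch.randn(2))}, 7)
out = applier(args, lambda t: to_type(torch.float16, t))
assert out[0][0].dtype == torch.float16
assert out[1]["x"].data.dtype == torch.float16
assert out[2] == 7
```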