Unverified commit 607fcc43, authored by buptzyb, committed by GitHub
Browse files

[PyTorch] Support bf16+fp8 cudagraph (#2098)



* support bf16+fp8 model
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update
Signed-off-by: Robin Zhang <robinz@nvidia.com>

---------
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 4285874d
......@@ -850,7 +850,7 @@ def make_graphed_callables(
num_warmup_iters: int = 3,
allow_unused_input: bool = False,
sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
fp8_enabled: bool = False,
fp8_enabled: SingleOrTuple[bool] = False,
fp8_calibrating: bool = False,
fp8_recipe: Optional[Recipe] = None,
fp8_group: Optional[dist_group_type] = None,
......@@ -896,8 +896,9 @@ def make_graphed_callables(
FP8-related parameters
----------------------
fp8_enabled: bool, default = `True`
whether or not to enable fp8
fp8_enabled: (tuple of) bool, default = `False`
whether or not to enable fp8.
If tuple, the length must match the number of modules.
fp8_calibrating: bool, default = `False`
calibration mode allows collecting statistics such as amax and scale
data of fp8 tensors even when executing without fp8 enabled. This is
......@@ -919,17 +920,25 @@ def make_graphed_callables(
"""
set_capture_start()
if fp8_enabled and fp8_recipe is None:
fp8_recipe = get_default_fp8_recipe()
elif not fp8_enabled:
fp8_recipe = None
# Handle single module.
just_one_callable = False
if not isinstance(modules, tuple):
just_one_callable = True
modules = (modules,)
if not isinstance(fp8_enabled, tuple):
assert isinstance(fp8_enabled, bool), "fp8_enabled must be a bool or a tuple of bools"
fp8_enabled = (fp8_enabled,) * len(modules)
else:
assert len(fp8_enabled) == len(
modules
), f"fp8_enabled length ({len(fp8_enabled)}) must match modules length ({len(modules)})"
if any(fp8_enabled) and fp8_recipe is None:
fp8_recipe = get_default_fp8_recipe()
elif not any(fp8_enabled):
fp8_recipe = None
module_uses_fp8 = dict(zip((id(m) for m in modules), fp8_enabled))
# Store FP8 tensors to reset later.
saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)
......@@ -944,15 +953,15 @@ def make_graphed_callables(
old_call_funcs[block_cls] = block_cls.__call__
# Wrap the original call function of the module class.
def call_func(*args, **kwargs):
def call_func(self, *args, **kwargs):
with fp8_autocast(
enabled=fp8_enabled,
enabled=module_uses_fp8.get(id(self), False),
calibrating=fp8_calibrating,
fp8_recipe=fp8_recipe,
fp8_group=fp8_group,
_graph=True,
):
outputs = old_call_funcs[block_cls](*args, **kwargs)
outputs = old_call_funcs[block_cls](self, *args, **kwargs)
return outputs
block_cls.__call__ = call_func
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment