Unverified Commit 607fcc43 authored by buptzyb, committed by GitHub

[PyTorch] Support bf16+fp8 cudagraph (#2098)



* support bf16+fp8 model
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update
Signed-off-by: Robin Zhang <robinz@nvidia.com>

---------
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 4285874d
```diff
@@ -850,7 +850,7 @@ def make_graphed_callables(
     num_warmup_iters: int = 3,
     allow_unused_input: bool = False,
     sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
-    fp8_enabled: bool = False,
+    fp8_enabled: SingleOrTuple[bool] = False,
     fp8_calibrating: bool = False,
     fp8_recipe: Optional[Recipe] = None,
     fp8_group: Optional[dist_group_type] = None,
@@ -896,8 +896,9 @@ def make_graphed_callables(
     FP8-related parameters
     ----------------------
-    fp8_enabled: bool, default = `True`
-                 whether or not to enable fp8
+    fp8_enabled: (tuple of) bool, default = `False`
+                 whether or not to enable fp8.
+                 If tuple, the length must match the number of modules.
     fp8_calibrating: bool, default = `False`
                      calibration mode allows collecting statistics such as amax and scale
                      data of fp8 tensors even when executing without fp8 enabled. This is
@@ -919,17 +920,25 @@ def make_graphed_callables(
     """
     set_capture_start()
 
-    if fp8_enabled and fp8_recipe is None:
-        fp8_recipe = get_default_fp8_recipe()
-    elif not fp8_enabled:
-        fp8_recipe = None
-
     # Handle single module.
     just_one_callable = False
     if not isinstance(modules, tuple):
         just_one_callable = True
         modules = (modules,)
 
+    if not isinstance(fp8_enabled, tuple):
+        assert isinstance(fp8_enabled, bool), "fp8_enabled must be a bool or a tuple of bools"
+        fp8_enabled = (fp8_enabled,) * len(modules)
+    else:
+        assert len(fp8_enabled) == len(
+            modules
+        ), f"fp8_enabled length ({len(fp8_enabled)}) must match modules length ({len(modules)})"
+    if any(fp8_enabled) and fp8_recipe is None:
+        fp8_recipe = get_default_fp8_recipe()
+    elif not any(fp8_enabled):
+        fp8_recipe = None
+    module_uses_fp8 = dict(zip((id(m) for m in modules), fp8_enabled))
+
     # Store FP8 tensors to reset later.
     saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)
@@ -944,15 +953,15 @@ def make_graphed_callables(
         old_call_funcs[block_cls] = block_cls.__call__
 
         # Wrap the original call function of the module class.
-        def call_func(*args, **kwargs):
+        def call_func(self, *args, **kwargs):
             with fp8_autocast(
-                enabled=fp8_enabled,
+                enabled=module_uses_fp8.get(id(self), False),
                 calibrating=fp8_calibrating,
                 fp8_recipe=fp8_recipe,
                 fp8_group=fp8_group,
                 _graph=True,
             ):
-                outputs = old_call_funcs[block_cls](*args, **kwargs)
+                outputs = old_call_funcs[block_cls](self, *args, **kwargs)
             return outputs
 
         block_cls.__call__ = call_func
```
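For reference, a minimal usage sketch of the new per-module flag. The layer shapes, batch size, and sample inputs below are illustrative, not taken from this PR; it assumes Transformer Engine's public `make_graphed_callables` API and a GPU with FP8 support:

```python
# Illustrative sketch: graph-capture two TE modules together,
# keeping the first in bf16 while the second runs under fp8.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

bf16_layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
fp8_layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)

# One tuple of sample args per module.
sample_args = (
    (torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda"),),
    (torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda"),),
)

graphed_bf16, graphed_fp8 = te.make_graphed_callables(
    (bf16_layer, fp8_layer),
    sample_args,
    fp8_enabled=(False, True),    # new in this commit: one flag per module
    fp8_recipe=DelayedScaling(),  # consulted only by fp8-enabled modules
)
```

Since the wrapper replaces `__call__` on the module *class* rather than on each instance, the flag is resolved per instance at call time via `module_uses_fp8.get(id(self), False)`, which is why `call_func` now takes `self` explicitly.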