"...git@developer.sourcefind.cn:modelzoo/stylegan2_mmcv.git" did not exist on "1401de15d079af4d9d9f995f2d57ddb6d930d7f0"
Unverified commit 160be219, authored by Phuong Nguyen and committed by GitHub
Browse files

[JAX] Backward compatible Fixes (#1631)



* expose NVTE_FP8_COLLECTION_NAME, update_collections, get_delayed_scaling

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent b0ad8ef0
...@@ -83,7 +83,8 @@ _load_library() ...@@ -83,7 +83,8 @@ _load_library()
from . import flax from . import flax
from . import quantize from . import quantize
from .quantize import fp8_autocast from .quantize import fp8_autocast, update_collections, get_delayed_scaling
from .quantize import NVTE_FP8_COLLECTION_NAME
from .sharding import MeshResource from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType from .sharding import MajorShardingType, ShardingResource, ShardingType
...@@ -101,11 +102,14 @@ ShardingResource = deprecate_wrapper( ...@@ -101,11 +102,14 @@ ShardingResource = deprecate_wrapper(
) )
__all__ = [ __all__ = [
"NVTE_FP8_COLLECTION_NAME",
"fp8_autocast", "fp8_autocast",
"update_collections",
"get_delayed_scaling",
"MeshResource", "MeshResource",
"MajorShardingType", "MajorShardingType",
"ShardingResource", "ShardingResource",
"ShardingType", "ShardingType",
"flax", "flax",
"praxis", "quantize",
] ]
...@@ -27,7 +27,14 @@ from transformer_engine.jax.sharding import global_shard_guard, MeshResource ...@@ -27,7 +27,14 @@ from transformer_engine.jax.sharding import global_shard_guard, MeshResource
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .. import cpp_extensions as tex from .. import cpp_extensions as tex
__all__ = ["QuantizeConfig", "fp8_autocast", "is_fp8_available", "update_collections"] __all__ = [
"QuantizeConfig",
"fp8_autocast",
"is_fp8_available",
"update_collections",
"get_delayed_scaling",
"NVTE_FP8_COLLECTION_NAME",
]
_is_fp8_available = None _is_fp8_available = None
_reason_for_no_fp8 = "" _reason_for_no_fp8 = ""
...@@ -178,31 +185,6 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode: ...@@ -178,31 +185,6 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
raise ValueError("Invalid fp8_recipe!") raise ValueError("Invalid fp8_recipe!")
def update_collections(new: Collection, original: Collection) -> Collection:
    """Merge ``new`` into ``original``, letting ``new`` win on shared keys.

    Args:
        new: Collection whose entries are added to, or override, those in
            ``original``.
        original: Collection being updated. Its container type (plain dict
            vs. FrozenDict) determines the container type of the result.

    Returns:
        A FrozenDict when ``original`` is a FrozenDict, otherwise the
        unfrozen form of the merged collection.

    Raises:
        AssertionError: If either collection is not a dict or FrozenDict
    """
    assert isinstance(original, (dict, FrozenDict))
    assert isinstance(new, (dict, FrozenDict))

    original_is_frozen = isinstance(original, FrozenDict)
    remaining = original if original_is_frozen else FrozenDict(original)

    # Drop every key that ``new`` overrides so the merge below has no overlap.
    for name in new:
        if name in remaining:
            remaining, _ = remaining.pop(name)

    merged = FrozenDict({**new, **remaining})
    # Hand back the same kind of container the caller passed as ``original``.
    return merged if original_is_frozen else merged.unfreeze()
class QuantizeConfig: class QuantizeConfig:
"""Configuration class for quantization settings. """Configuration class for quantization settings.
...@@ -227,7 +209,7 @@ class QuantizeConfig: ...@@ -227,7 +209,7 @@ class QuantizeConfig:
INITIALIZED = False INITIALIZED = False
MARGIN: float = 0.0 MARGIN: float = 0.0
COLLECTION_NAME: str = "quantize_meta" COLLECTION_NAME: str = "fp8_metas"
FP8_FORMAT: recipe.Format = recipe.Format.HYBRID FP8_FORMAT: recipe.Format = recipe.Format.HYBRID
FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0] FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0]
BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1] BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1]
...@@ -414,3 +396,56 @@ def fp8_autocast( ...@@ -414,3 +396,56 @@ def fp8_autocast(
yield yield
finally: finally:
Config.finalize() Config.finalize()
def get_delayed_scaling():
    r"""
    Build a DelayedScaling recipe from the values currently held in QuantizeConfig.

    .. note::
        Only :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`,
        and :attr:`amax_compute_algo` are forwarded to the constructor. Every
        other parameter of recipe.DelayedScaling is left at its default value.

    Returns
    -------
    delay_scaling : DelayedScaling
        an instance of DelayedScaling populated from QuantizeConfig.
    """
    # The recipe takes the algorithm as a string, while QuantizeConfig holds
    # an AmaxComputeAlgo member; anything other than MAX maps to "most_recent".
    if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
        algo_name = "max"
    else:
        algo_name = "most_recent"

    recipe_kwargs = {
        # MARGIN is declared as a float on QuantizeConfig; it is passed as an int here.
        "margin": int(QuantizeConfig.MARGIN),
        "fp8_format": QuantizeConfig.FP8_FORMAT,
        "amax_history_len": QuantizeConfig.AMAX_HISTORY_LEN,
        "amax_compute_algo": algo_name,
    }
    return recipe.DelayedScaling(**recipe_kwargs)
def update_collections(new: Collection, original: Collection) -> Collection:
    r"""Combine two collections, with entries from ``new`` taking precedence.

    Args:
        new: Collection of values to add or override.
        original: Collection to update. The result uses the same container
            kind (dict or FrozenDict) as this argument.

    Returns:
        The merged collection: a FrozenDict if ``original`` is one, otherwise
        the result of calling ``unfreeze()`` on the merged FrozenDict.

    Raises:
        AssertionError: If either collection is not a dict or FrozenDict
    """
    assert isinstance(original, (dict, FrozenDict))
    assert isinstance(new, (dict, FrozenDict))

    keep_frozen = isinstance(original, FrozenDict)
    leftover = original if keep_frozen else FrozenDict(original)

    # Remove keys that ``new`` also defines; FrozenDict.pop returns a new
    # FrozenDict plus the removed value, which is not needed here.
    for key in new:
        if key in leftover:
            leftover, _ = leftover.pop(key)

    combined = FrozenDict({**new, **leftover})
    if keep_frozen:
        return combined
    return combined.unfreeze()
# Module-level alias for QuantizeConfig.COLLECTION_NAME. The value is read
# once, when this module is executed, so rebinding
# QuantizeConfig.COLLECTION_NAME later does not change this name.
NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment