Unverified commit 403ade2f authored by Ming-Xu Huang, committed by GitHub

[JAX] Adapt to Flax 0.7.1 (#353)



* Cast Flax collections to FrozenDict as a workaround (WAR) to adapt to Flax 0.7.1.
Signed-off-by: Ming Huang <mingh@nvidia.com>

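A minimal sketch of the migration pattern this commit applies (illustrative only, not the exact diff; the nested "variables" dict below is a hypothetical stand-in for a real Flax variable collection):

    import flax

    # Hypothetical variable collection: Flax >= 0.7.1 returns plain dicts from
    # Module.init, while Flax <= 0.7.0 returned FrozenDict.
    variables = {'params': {'dense': {'kernel': 1.0}}, 'batch_stats': {'mean': 0.0}}

    # Old style relied on the FrozenDict method API:
    #     others, params = variables.pop('params')
    # The functional helpers in flax.core accept both container types:
    others, params = flax.core.pop(variables, 'params')  # returns (rest, popped entry)
    variables = {**others, 'params': params}             # plain-dict merge instead of FrozenDict({...})
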
* Add a minimum Flax version to requirements.txt in the examples
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix praxis tests and rename compare_frozen_dict to compare_dict
Signed-off-by: Ming Huang <mingh@nvidia.com>

* [Paddle] Refactor FP8 state (#350)

Refactor FP8 state
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Store FP8 checkpointing data in CPU (#351)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Make test_layer run on both Flax >= 0.7.1 and <= 0.7.0
Signed-off-by: Ming Huang <mingh@nvidia.com>

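A sketch of how the tests stay version-agnostic (assumption: flax.core.unfreeze accepts both plain dicts and FrozenDicts and always returns a plain dict, which is what the compare_dict helper in the diff below relies on):

    import flax
    from flax.core.frozen_dict import FrozenDict

    def normalize(tree):
        # Convert either container type to a plain nested dict so the same
        # comparison code works with Flax <= 0.7.0 and >= 0.7.1 outputs.
        return flax.core.unfreeze(tree)

    assert normalize(FrozenDict({'a': {'b': 1}})) == {'a': {'b': 1}}
    assert normalize({'a': {'b': 1}}) == {'a': {'b': 1}}
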
* Update Flax version
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tian Zheng <tizheng@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 85928d08
@@ -6,6 +6,7 @@ import argparse
import unittest
from functools import partial
+import flax
import jax
import jax.numpy as jnp
import nltk
@@ -13,7 +14,6 @@ import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
-from flax.core.frozen_dict import FrozenDict
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
@@ -80,12 +80,12 @@ def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
-var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+var_collect = {**var_collect, PARAMS_KEY: state.params}
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(var_collect)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
-var_collect, grads = grads.pop(PARAMS_KEY)
+var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
state = state.apply_gradients(grads=grads)
if use_fp8:
var_collect = te.update_fp8_metas(var_collect)
@@ -126,7 +126,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
-var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+var_collect = {**var_collect, PARAMS_KEY: state.params}
loss, logits = loss_fn(var_collect, disable_dropout=True)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
return loss, accuracy
@@ -223,8 +223,9 @@ def get_params_pspec(sharding_rules, abs_var_collect):
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes))
+params_axes_pspec = flax.core.unfreeze(params_axes_pspec)
params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
-params_pspec = FrozenDict({**params_pspec, **params_axes_pspec})
+params_pspec = {**params_pspec, **params_axes_pspec}
return params_pspec
@@ -232,9 +233,9 @@ def get_state_pspec(state, params_pspec):
"""Refer params_pspec to create state partition spec"""
def replace_params(x):
-return params_pspec if isinstance(x, FrozenDict) else None
+return params_pspec if isinstance(x, dict) else None
-state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, FrozenDict))
+state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, dict))
return state_pspec
@@ -281,13 +282,13 @@ def train_and_evaluate(args):
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
in_shardings = (None, inputs_pspec, masks_pspec)
-out_shardings = FrozenDict({key: params_pspec if key is PARAMS_KEY else None \
-for key in abs_var_collect})
+out_shardings = {key: params_pspec if key is PARAMS_KEY else None \
+for key in abs_var_collect}
pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
var_collect = pjit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
-var_collect, params = var_collect.pop(PARAMS_KEY)
+var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
state = train_state.TrainState.create(apply_fn=encoder.apply,
params=params,
tx=optimizer)
......
@@ -6,6 +6,7 @@ import argparse
import unittest
from functools import partial
+import flax
import jax
import jax.numpy as jnp
import nltk
@@ -13,7 +14,6 @@ import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
-from flax.core.frozen_dict import FrozenDict
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
@@ -71,12 +71,12 @@ def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
-var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+var_collect = {**var_collect, PARAMS_KEY: state.params}
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(var_collect)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
-var_collect, grads = grads.pop(PARAMS_KEY)
+var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
state = state.apply_gradients(grads=grads)
if use_fp8:
var_collect = te.update_fp8_metas(var_collect)
@@ -117,7 +117,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
-var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+var_collect = {**var_collect, PARAMS_KEY: state.params}
loss, logits = loss_fn(var_collect, disable_dropout=True)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
return loss, accuracy
@@ -214,8 +214,9 @@ def get_params_pspec(sharding_rules, abs_var_collect):
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes))
+params_axes_pspec = flax.core.unfreeze(params_axes_pspec)
params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
-params_pspec = FrozenDict({**params_pspec, **params_axes_pspec})
+params_pspec = {**params_pspec, **params_axes_pspec}
return params_pspec
@@ -223,9 +224,9 @@ def get_state_pspec(state, params_pspec):
"""Refer params_pspec to create state partition spec"""
def replace_params(x):
-return params_pspec if isinstance(x, FrozenDict) else None
+return params_pspec if isinstance(x, dict) else None
-state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, FrozenDict))
+state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, dict))
return state_pspec
@@ -263,13 +264,13 @@ def train_and_evaluate(args):
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
in_shardings = (None, inputs_pspec, masks_pspec)
-out_shardings = FrozenDict({key: params_pspec if key is PARAMS_KEY else None \
-for key in abs_var_collect})
+out_shardings = {key: params_pspec if key is PARAMS_KEY else None \
+for key in abs_var_collect}
pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
var_collect = pjit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
-var_collect, params = var_collect.pop(PARAMS_KEY)
+var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
state = train_state.TrainState.create(apply_fn=encoder.apply,
params=params,
tx=optimizer)
......
@@ -8,6 +8,7 @@ import os
import unittest
from functools import partial
+import flax
import jax
import jax.numpy as jnp
import nltk
@@ -15,7 +16,6 @@ import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
-from flax.core.frozen_dict import FrozenDict
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
@@ -115,12 +115,12 @@ def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
-var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+var_collect = {**var_collect, PARAMS_KEY: state.params}
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(var_collect)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
-var_collect, grads = grads.pop(PARAMS_KEY)
+var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
state = state.apply_gradients(grads=grads)
if use_fp8:
var_collect = te.update_fp8_metas(var_collect)
@@ -182,7 +182,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
-var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+var_collect = {**var_collect, PARAMS_KEY: state.params}
loss, logits = loss_fn(var_collect, disable_dropout=True)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
return loss, accuracy
@@ -297,8 +297,9 @@ def get_params_pspec(sharding_rules, abs_var_collect):
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes))
+params_axes_pspec = flax.core.unfreeze(params_axes_pspec)
params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
-params_pspec = FrozenDict({**params_pspec, **params_axes_pspec})
+params_pspec = {**params_pspec, **params_axes_pspec}
return params_pspec
@@ -306,9 +307,9 @@ def get_state_pspec(state, params_pspec):
"""Refer params_pspec to create state partition spec"""
def replace_params(x):
-return params_pspec if isinstance(x, FrozenDict) else None
+return params_pspec if isinstance(x, dict) else None
-state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, FrozenDict))
+state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, dict))
return state_pspec
@@ -362,13 +363,13 @@ def train_and_evaluate(args):
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
in_shardings = (None, inputs_pspec, masks_pspec)
-out_shardings = FrozenDict({key: params_pspec if key is PARAMS_KEY else None \
-for key in abs_var_collect})
+out_shardings = {key: params_pspec if key is PARAMS_KEY else None \
+for key in abs_var_collect}
pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
var_collect = pjit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
-var_collect, params = var_collect.pop(PARAMS_KEY)
+var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
state = train_state.TrainState.create(apply_fn=encoder.apply,
params=params,
tx=optimizer)
......
@@ -6,6 +6,7 @@ import argparse
import unittest
from functools import partial
+import flax
import jax
import jax.numpy as jnp
import nltk
@@ -13,7 +14,6 @@ import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
-from flax.core.frozen_dict import FrozenDict
from flax.training import train_state
import transformer_engine.jax as te
@@ -65,12 +65,12 @@ def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
-var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+var_collect = {**var_collect, PARAMS_KEY: state.params}
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(var_collect)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
-var_collect, grads = grads.pop(PARAMS_KEY)
+var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
state = state.apply_gradients(grads=grads)
if use_fp8:
var_collect = te.update_fp8_metas(var_collect)
@@ -112,7 +112,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
-var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+var_collect = {**var_collect, PARAMS_KEY: state.params}
loss, logits = loss_fn(var_collect, disable_dropout=True)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
return loss, accuracy
......
datasets
-flax
+flax>=0.7.1
optax
Pillow
@@ -12,7 +12,6 @@ import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
-from flax.core.frozen_dict import FrozenDict
from flax.training import train_state
import transformer_engine.jax as te
@@ -62,7 +61,7 @@ def apply_model(state, images, labels, var_collect, rngs=None):
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
-var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+var_collect = {**var_collect, PARAMS_KEY: state.params}
if rngs is not None:
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
......
@@ -295,7 +295,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
if "jax" in frameworks():
if not found_pybind11():
add_unique(setup_reqs, "pybind11")
-add_unique(install_reqs, ["jax", "flax"])
+add_unique(install_reqs, ["jax", "flax>=0.7.1"])
add_unique(test_reqs, ["numpy", "praxis"])
if "tensorflow" in frameworks():
if not found_pybind11():
......
@@ -34,20 +34,24 @@ def generate_test_rngs():
def generate_layer(layer_cls, init_rng, diff_inputs, no_diff_inputs):
layer = layer_cls()
variables = layer.init(init_rng, *diff_inputs, *no_diff_inputs)
-others, params = variables.pop('params')
+others, params = flax.core.pop(variables, 'params')
del variables
return layer, params, others
-def compare_frozen_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
+def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
+# To be compatible with both Flax>=0.7.1 or <0.7.1
+# since Flax 0.7.1 removed FrozenDict.
+ref_fd = flax.core.unfreeze(ref_fd)
+test_fd = flax.core.unfreeze(test_fd)
for key in ref_fd:
assert key in test_fd, \
f"{key} not found in test FrozenDict {test_fd}"
f"{key} not found in test dict {test_fd}"
assert isinstance(test_fd[key], type(ref_fd[key])), \
f"The data type is not match between ref and test " \
f"FrozenDict on {key=}"
if isinstance(ref_fd[key], flax.core.frozen_dict.FrozenDict):
compare_frozen_dict(ref_fd[key], test_fd[key], rtol, atol)
f"dict on {key=}"
if isinstance(ref_fd[key], dict):
compare_dict(ref_fd[key], test_fd[key], rtol, atol)
else:
assert_allclose(ref_fd[key],
test_fd[key],
@@ -126,7 +130,7 @@ class TestEncoderLayer:
def sync_params(ref, target, attrs):
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
-unfreeze_target = target.unfreeze()
+unfreeze_target = flax.core.unfreeze(target)
if fuse_qkv:
unfreeze_target['attention']['qkv']['kernel'] = \
jnp.reshape(ref['attention']['qkv']['kernel'],
@@ -142,7 +146,7 @@ class TestEncoderLayer:
jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
unfreeze_target['mlp']['wo_kernel'] = \
ref['mlp']['wo']['kernel']
-return ref, flax.core.frozen_dict.FrozenDict(unfreeze_target)
+return ref, unfreeze_target
def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
@@ -231,7 +235,7 @@ class TestEncoderLayer:
_, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,),
has_aux=False)(inputs, test_masks, test_params,
test_others, test_layer, apply_rng)
-_, fp8_meta_grad = tmp_grad[0].pop(FP8Helper.FP8_COLLECTION_NAME)
+_, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
test_others = FP8Helper.update_collections(
{FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
test_others = FP8Helper.update_fp8_metas(test_others)
@@ -251,7 +255,7 @@ class TestEncoderLayer:
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
attn_name = 'attention'
-unfreeze_test_wgrad = test_wgrad.unfreeze()
+unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
if "output_layernorm" not in attrs:
unfreeze_test_wgrad['pre_attention_layer_norm'] = {}
pre_attn_layer_key = 'qkv' if fuse_qkv else 'query'
@@ -283,12 +287,12 @@ class TestEncoderLayer:
unfreeze_test_wgrad['mlp']['wo']['kernel'] = \
unfreeze_test_wgrad['mlp']['wo_kernel']
del unfreeze_test_wgrad['mlp']['wo_kernel']
-return flax.core.frozen_dict.FrozenDict(unfreeze_test_wgrad)
+return unfreeze_test_wgrad
-compare_frozen_dict(ref_grads[1],
-reorganize_test_wgrad(test_grads[1], attrs),
-rtol=rtol,
-atol=atol) # wgrad
+compare_dict(ref_grads[1],
+reorganize_test_wgrad(test_grads[1], attrs),
+rtol=rtol,
+atol=atol) # wgrad
del data_rng, init_rng, apply_rng
@@ -333,7 +337,7 @@ class TestDecoderLayer:
def sync_params(ref, target, attrs):
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
-unfreeze_target = target.unfreeze()
+unfreeze_target = flax.core.unfreeze(target)
if fuse_qkv:
unfreeze_target['self_attention']['qkv']['kernel'] = \
jnp.reshape(ref['self_attention']['qkv']['kernel'],
@@ -354,7 +358,7 @@ class TestDecoderLayer:
jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
unfreeze_target['mlp']['wo_kernel'] = \
ref['mlp']['wo']['kernel']
-return ref, flax.core.frozen_dict.FrozenDict(unfreeze_target)
+return ref, unfreeze_target
def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
@@ -444,7 +448,7 @@ class TestDecoderLayer:
_, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,),
has_aux=False)(inputs, test_masks, test_params,
test_others, test_layer, apply_rng)
-_, fp8_meta_grad = tmp_grad[0].pop(FP8Helper.FP8_COLLECTION_NAME)
+_, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
test_others = FP8Helper.update_collections(
{FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
test_others = FP8Helper.update_fp8_metas(test_others)
@@ -464,7 +468,7 @@ class TestDecoderLayer:
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
attn_name = 'self_attention'
-unfreeze_test_wgrad = test_wgrad.unfreeze()
+unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
if "output_layernorm" not in attrs:
unfreeze_test_wgrad['pre_self_attention_layer_norm'] = {}
pre_attn_layer_key = 'qkv' if fuse_qkv else 'query'
@@ -510,12 +514,12 @@ class TestDecoderLayer:
unfreeze_test_wgrad['mlp']['wo']['kernel'] = \
unfreeze_test_wgrad['mlp']['wo_kernel']
del unfreeze_test_wgrad['mlp']['wo_kernel']
-return flax.core.frozen_dict.FrozenDict(unfreeze_test_wgrad)
+return unfreeze_test_wgrad
-compare_frozen_dict(ref_grads[1],
-reorganize_test_wgrad(test_grads[1], attrs),
-rtol=rtol,
-atol=atol) # wgrad
+compare_dict(ref_grads[1],
+reorganize_test_wgrad(test_grads[1], attrs),
+rtol=rtol,
+atol=atol) # wgrad
del data_rng, init_rng, apply_rng
......
@@ -5,6 +5,7 @@
from functools import partial
from typing import Dict
+import flax
import jax
import jax.numpy as jnp
from praxis import pax_fiddle
@@ -40,7 +41,7 @@ FP8_FORMATS = [Format.E4M3, Format.HYBRID]
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
for key in ref_fd:
assert key in test_fd, \
f"{key} not found in test FrozenDict {test_fd}"
f"{key} not found in test dict {test_fd}"
assert isinstance(test_fd[key], type(ref_fd[key])), \
f"The data type is not match between ref and test " \
f" Dict on {key=}"
@@ -89,7 +90,7 @@ class TestLayer:
lyr_name = self.get_layer_name()
synced_praxis_variables['params'][lyr_name]['cld'] = \
-flax_variables['params'].unfreeze()
+flax.core.unfreeze(flax_variables['params'])
return synced_praxis_variables, flax_variables
@@ -105,7 +106,7 @@ class TestLayer:
synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = \
synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME][lyr_name]['cld']
-return synced_praxis_grads, flax_wgrads.unfreeze()
+return synced_praxis_grads, flax.core.unfreeze(flax_wgrads)
def forward_backward_runner(self,
data_shape,
@@ -127,9 +128,10 @@ class TestLayer:
flax_layer = flax_cls()
flax_variables = flax_layer.init(init_key, *test_inputs)
if "params_axes" in flax_variables:
flax_variables, _ = flax_variables.pop("params_axes")
flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
if FP8Helper.is_fp8_enabled():
-flax_variables, _ = flax_variables.pop(FP8Helper.FP8_COLLECTION_NAME + "_axes")
+flax_variables, _ = flax.core.pop(flax_variables,
+FP8Helper.FP8_COLLECTION_NAME + "_axes")
praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
@@ -143,7 +145,7 @@ class TestLayer:
if FP8Helper.is_fp8_enabled():
praxis_wgrads.pop('params')
praxis_variables = update_collections(praxis_wgrads, praxis_variables)
-flax_wgrads, _ = flax_wgrads.pop('params')
+flax_wgrads, _ = flax.core.pop(flax_wgrads, 'params')
flax_variables = update_collections(flax_wgrads, flax_variables)
praxis_loss, praxis_wgrads, praxis_dgrad = \
@@ -642,9 +644,10 @@ class TestRelativePositionBias(TestLayer):
flax_layer = flax_cls()
flax_variables = flax_layer.init(init_key, *test_input)
if "params_axes" in flax_variables:
flax_variables, _ = flax_variables.pop("params_axes")
flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
if FP8Helper.is_fp8_enabled():
-flax_variables, _ = flax_variables.pop(FP8Helper.FP8_COLLECTION_NAME + "_axes")
+flax_variables, _ = flax.core.pop(flax_variables,
+FP8Helper.FP8_COLLECTION_NAME + "_axes")
praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
......