Unverified Commit 403ade2f authored by Ming-Xu Huang, committed by GitHub

[JAX] Adapt to Flax 0.7.1 (#353)



* Cast Flax collections to FrozenDict as a workaround (WAR) to adapt to Flax 0.7.1.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Add a minimum Flax version to requirements.txt in the examples.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the Praxis tests and rename compare_frozen_dict to compare_dict.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* [Paddle] Refactor FP8 state (#350)

Refactor the FP8 state.
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Store FP8 checkpointing data on the CPU (#351)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Make test_layer runnable on both Flax >=0.7.1 and <=0.7.0.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Update the Flax version.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tian Zheng <tizheng@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 85928d08
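For context on the diffs below: Flax 0.7.1 changed `Module.init` (and related APIs) to return plain `dict`s instead of `FrozenDict`s, so the `FrozenDict`-specific methods used in the examples and tests, `.pop()` and `.unfreeze()`, are replaced with the functional helpers `flax.core.pop` and `flax.core.unfreeze`, which accept both container types. A minimal sketch of the pattern (not code from this PR; assumes Flax >= 0.7.1 is installed):

```python
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp

model = nn.Dense(features=4)
# On Flax >= 0.7.1 this returns a plain dict; on <= 0.7.0, a FrozenDict.
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))

# The builtin dict.pop returns only the value, whereas FrozenDict.pop
# returned a (rest, value) pair; flax.core.pop restores the pair semantics
# for both container types.
rest, params = flax.core.pop(variables, 'params')

# flax.core.unfreeze converts a FrozenDict to a dict and is harmless on a
# plain dict, which is what makes it usable across both Flax versions.
mutable = flax.core.unfreeze(variables)
```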
 datasets
-flax
+flax>=0.7.1
 nltk
 optax
@@ -6,6 +6,7 @@ import argparse
 import unittest
 from functools import partial

+import flax
 import jax
 import jax.numpy as jnp
 import nltk
@@ -13,7 +14,6 @@ import numpy as np
 import optax
 from datasets import load_dataset
 from flax import linen as nn
-from flax.core.frozen_dict import FrozenDict
 from flax.linen import partitioning as nn_partitioning
 from flax.training import train_state
 from jax.experimental import mesh_utils
@@ -80,12 +80,12 @@ def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
         loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
         return loss, logits

-    var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+    var_collect = {**var_collect, PARAMS_KEY: state.params}
     grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
     (loss, logits), grads = grad_fn(var_collect)
     accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
-    var_collect, grads = grads.pop(PARAMS_KEY)
+    var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
     state = state.apply_gradients(grads=grads)

     if use_fp8:
         var_collect = te.update_fp8_metas(var_collect)
@@ -126,7 +126,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
         loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
         return loss, logits

-    var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+    var_collect = {**var_collect, PARAMS_KEY: state.params}
     loss, logits = loss_fn(var_collect, disable_dropout=True)
     accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
     return loss, accuracy
@@ -223,8 +223,9 @@ def get_params_pspec(sharding_rules, abs_var_collect):
     params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
     params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes))
+    params_axes_pspec = flax.core.unfreeze(params_axes_pspec)
     params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
-    params_pspec = FrozenDict({**params_pspec, **params_axes_pspec})
+    params_pspec = {**params_pspec, **params_axes_pspec}
     return params_pspec
@@ -232,9 +233,9 @@ def get_state_pspec(state, params_pspec):
     """Refer params_pspec to create state partition spec"""

     def replace_params(x):
-        return params_pspec if isinstance(x, FrozenDict) else None
+        return params_pspec if isinstance(x, dict) else None

-    state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, FrozenDict))
+    state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, dict))
     return state_pspec
@@ -281,13 +282,13 @@ def train_and_evaluate(args):
     masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
     in_shardings = (None, inputs_pspec, masks_pspec)
-    out_shardings = FrozenDict({key: params_pspec if key is PARAMS_KEY else None \
-                    for key in abs_var_collect})
+    out_shardings = {key: params_pspec if key is PARAMS_KEY else None \
+                    for key in abs_var_collect}
     pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
     var_collect = pjit_encoder_init(init_rngs, inputs, masks)

     optimizer = optax.adamw(args.lr)
-    var_collect, params = var_collect.pop(PARAMS_KEY)
+    var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
     state = train_state.TrainState.create(apply_fn=encoder.apply,
                                           params=params,
                                           tx=optimizer)
...
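The `get_state_pspec` hunk above now tests for plain `dict` leaves when mapping over the train state. A small illustrative sketch of that `is_leaf` pattern; the `State` container and the specs here are simplified stand-ins, not code from this PR:

```python
from collections import namedtuple

import jax
from jax.sharding import PartitionSpec

# Stand-in for flax.training.train_state.TrainState: a pytree with one
# dict-valued field holding the parameter collection.
State = namedtuple('State', ['step', 'params'])
state = State(step=0, params={'dense': {'kernel': None}})

params_pspec = {'dense': {'kernel': PartitionSpec(None, 'model')}}

def replace_params(x):
    # Whole dict subtrees count as leaves (via is_leaf below) and are mapped
    # to the params spec; every other field gets no sharding constraint.
    return params_pspec if isinstance(x, dict) else None

state_pspec = jax.tree_map(replace_params, state,
                           is_leaf=lambda x: isinstance(x, dict))
# -> State(step=None, params={'dense': {'kernel': PartitionSpec(None, 'model')}})
```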
@@ -6,6 +6,7 @@ import argparse
 import unittest
 from functools import partial

+import flax
 import jax
 import jax.numpy as jnp
 import nltk
@@ -13,7 +14,6 @@ import numpy as np
 import optax
 from datasets import load_dataset
 from flax import linen as nn
-from flax.core.frozen_dict import FrozenDict
 from flax.linen import partitioning as nn_partitioning
 from flax.training import train_state
 from jax.experimental import mesh_utils
@@ -71,12 +71,12 @@ def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
         loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
         return loss, logits

-    var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+    var_collect = {**var_collect, PARAMS_KEY: state.params}
     grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
     (loss, logits), grads = grad_fn(var_collect)
     accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
-    var_collect, grads = grads.pop(PARAMS_KEY)
+    var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
     state = state.apply_gradients(grads=grads)

     if use_fp8:
         var_collect = te.update_fp8_metas(var_collect)
@@ -117,7 +117,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
         loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
         return loss, logits

-    var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+    var_collect = {**var_collect, PARAMS_KEY: state.params}
     loss, logits = loss_fn(var_collect, disable_dropout=True)
     accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
     return loss, accuracy
@@ -214,8 +214,9 @@ def get_params_pspec(sharding_rules, abs_var_collect):
     params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
     params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes))
+    params_axes_pspec = flax.core.unfreeze(params_axes_pspec)
     params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
-    params_pspec = FrozenDict({**params_pspec, **params_axes_pspec})
+    params_pspec = {**params_pspec, **params_axes_pspec}
     return params_pspec
@@ -223,9 +224,9 @@ def get_state_pspec(state, params_pspec):
     """Refer params_pspec to create state partition spec"""

     def replace_params(x):
-        return params_pspec if isinstance(x, FrozenDict) else None
+        return params_pspec if isinstance(x, dict) else None

-    state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, FrozenDict))
+    state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, dict))
     return state_pspec
@@ -263,13 +264,13 @@ def train_and_evaluate(args):
     masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
     in_shardings = (None, inputs_pspec, masks_pspec)
-    out_shardings = FrozenDict({key: params_pspec if key is PARAMS_KEY else None \
-                    for key in abs_var_collect})
+    out_shardings = {key: params_pspec if key is PARAMS_KEY else None \
+                    for key in abs_var_collect}
     pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
     var_collect = pjit_encoder_init(init_rngs, inputs, masks)

     optimizer = optax.adamw(args.lr)
-    var_collect, params = var_collect.pop(PARAMS_KEY)
+    var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
     state = train_state.TrainState.create(apply_fn=encoder.apply,
                                           params=params,
                                           tx=optimizer)
...
@@ -8,6 +8,7 @@ import os
 import unittest
 from functools import partial

+import flax
 import jax
 import jax.numpy as jnp
 import nltk
@@ -15,7 +16,6 @@ import numpy as np
 import optax
 from datasets import load_dataset
 from flax import linen as nn
-from flax.core.frozen_dict import FrozenDict
 from flax.linen import partitioning as nn_partitioning
 from flax.training import train_state
 from jax.experimental import mesh_utils
@@ -115,12 +115,12 @@ def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
         loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
         return loss, logits

-    var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+    var_collect = {**var_collect, PARAMS_KEY: state.params}
     grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
     (loss, logits), grads = grad_fn(var_collect)
     accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
-    var_collect, grads = grads.pop(PARAMS_KEY)
+    var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
     state = state.apply_gradients(grads=grads)

     if use_fp8:
         var_collect = te.update_fp8_metas(var_collect)
@@ -182,7 +182,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
         loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
         return loss, logits

-    var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+    var_collect = {**var_collect, PARAMS_KEY: state.params}
     loss, logits = loss_fn(var_collect, disable_dropout=True)
     accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
     return loss, accuracy
@@ -297,8 +297,9 @@ def get_params_pspec(sharding_rules, abs_var_collect):
     params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
     params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes))
+    params_axes_pspec = flax.core.unfreeze(params_axes_pspec)
     params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
-    params_pspec = FrozenDict({**params_pspec, **params_axes_pspec})
+    params_pspec = {**params_pspec, **params_axes_pspec}
     return params_pspec
@@ -306,9 +307,9 @@ def get_state_pspec(state, params_pspec):
     """Refer params_pspec to create state partition spec"""

     def replace_params(x):
-        return params_pspec if isinstance(x, FrozenDict) else None
+        return params_pspec if isinstance(x, dict) else None

-    state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, FrozenDict))
+    state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, dict))
     return state_pspec
@@ -362,13 +363,13 @@ def train_and_evaluate(args):
     masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
     in_shardings = (None, inputs_pspec, masks_pspec)
-    out_shardings = FrozenDict({key: params_pspec if key is PARAMS_KEY else None \
-                    for key in abs_var_collect})
+    out_shardings = {key: params_pspec if key is PARAMS_KEY else None \
+                    for key in abs_var_collect}
     pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
     var_collect = pjit_encoder_init(init_rngs, inputs, masks)

     optimizer = optax.adamw(args.lr)
-    var_collect, params = var_collect.pop(PARAMS_KEY)
+    var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
     state = train_state.TrainState.create(apply_fn=encoder.apply,
                                           params=params,
                                           tx=optimizer)
...
@@ -6,6 +6,7 @@ import argparse
 import unittest
 from functools import partial

+import flax
 import jax
 import jax.numpy as jnp
 import nltk
@@ -13,7 +14,6 @@ import numpy as np
 import optax
 from datasets import load_dataset
 from flax import linen as nn
-from flax.core.frozen_dict import FrozenDict
 from flax.training import train_state

 import transformer_engine.jax as te
@@ -65,12 +65,12 @@ def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
         loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
         return loss, logits

-    var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+    var_collect = {**var_collect, PARAMS_KEY: state.params}
     grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
     (loss, logits), grads = grad_fn(var_collect)
     accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
-    var_collect, grads = grads.pop(PARAMS_KEY)
+    var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
     state = state.apply_gradients(grads=grads)

     if use_fp8:
         var_collect = te.update_fp8_metas(var_collect)
@@ -112,7 +112,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
         loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
         return loss, logits

-    var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+    var_collect = {**var_collect, PARAMS_KEY: state.params}
     loss, logits = loss_fn(var_collect, disable_dropout=True)
     accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
     return loss, accuracy
...
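In the `train_and_evaluate` hunks above, the `out_shardings` pytree passed to `pjit` becomes a plain dict comprehension over the variable collections. A simplified illustration of the resulting shape; the collection names and specs below are placeholders, not values from these examples:

```python
from jax.sharding import PartitionSpec

PARAMS_KEY = 'params'
params_pspec = {'dense': {'kernel': PartitionSpec(None, 'model')}}
abs_var_collect = {'params': {}, 'fp8_metas': {}}  # placeholder collections

# pjit accepts any pytree of PartitionSpecs, so with FrozenDict gone a plain
# dict keyed like the variable collection works directly as out_shardings.
out_shardings = {key: params_pspec if key == PARAMS_KEY else None
                 for key in abs_var_collect}
# -> {'params': {'dense': {'kernel': PartitionSpec(None, 'model')}}, 'fp8_metas': None}
```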
 datasets
-flax
+flax>=0.7.1
 optax
 Pillow
@@ -12,7 +12,6 @@ import numpy as np
 import optax
 from datasets import load_dataset
 from flax import linen as nn
-from flax.core.frozen_dict import FrozenDict
 from flax.training import train_state

 import transformer_engine.jax as te
@@ -62,7 +61,7 @@ def apply_model(state, images, labels, var_collect, rngs=None):
         loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
         return loss, logits

-    var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
+    var_collect = {**var_collect, PARAMS_KEY: state.params}
     if rngs is not None:
         grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
...
@@ -295,7 +295,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     if "jax" in frameworks():
         if not found_pybind11():
             add_unique(setup_reqs, "pybind11")
-        add_unique(install_reqs, ["jax", "flax"])
+        add_unique(install_reqs, ["jax", "flax>=0.7.1"])
         add_unique(test_reqs, ["numpy", "praxis"])
     if "tensorflow" in frameworks():
         if not found_pybind11():
...
@@ -34,20 +34,24 @@ def generate_test_rngs():
 def generate_layer(layer_cls, init_rng, diff_inputs, no_diff_inputs):
     layer = layer_cls()
     variables = layer.init(init_rng, *diff_inputs, *no_diff_inputs)
-    others, params = variables.pop('params')
+    others, params = flax.core.pop(variables, 'params')
     del variables
     return layer, params, others


-def compare_frozen_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
+def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
+    # To be compatible with both Flax>=0.7.1 or <0.7.1
+    # since Flax 0.7.1 removed FrozenDict.
+    ref_fd = flax.core.unfreeze(ref_fd)
+    test_fd = flax.core.unfreeze(test_fd)
     for key in ref_fd:
         assert key in test_fd, \
-            f"{key} not found in test FrozenDict {test_fd}"
+            f"{key} not found in test dict {test_fd}"
         assert isinstance(test_fd[key], type(ref_fd[key])), \
             f"The data type is not match between ref and test " \
-            f"FrozenDict on {key=}"
+            f"dict on {key=}"
-        if isinstance(ref_fd[key], flax.core.frozen_dict.FrozenDict):
-            compare_frozen_dict(ref_fd[key], test_fd[key], rtol, atol)
+        if isinstance(ref_fd[key], dict):
+            compare_dict(ref_fd[key], test_fd[key], rtol, atol)
         else:
             assert_allclose(ref_fd[key],
                             test_fd[key],
@@ -126,7 +130,7 @@ class TestEncoderLayer:
     def sync_params(ref, target, attrs):
         fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
-        unfreeze_target = target.unfreeze()
+        unfreeze_target = flax.core.unfreeze(target)
         if fuse_qkv:
             unfreeze_target['attention']['qkv']['kernel'] = \
                 jnp.reshape(ref['attention']['qkv']['kernel'],
@@ -142,7 +146,7 @@ class TestEncoderLayer:
             jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
         unfreeze_target['mlp']['wo_kernel'] = \
             ref['mlp']['wo']['kernel']
-        return ref, flax.core.frozen_dict.FrozenDict(unfreeze_target)
+        return ref, unfreeze_target

     def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
         transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
@@ -231,7 +235,7 @@ class TestEncoderLayer:
            _, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,),
                                             has_aux=False)(inputs, test_masks, test_params,
                                                            test_others, test_layer, apply_rng)
-            _, fp8_meta_grad = tmp_grad[0].pop(FP8Helper.FP8_COLLECTION_NAME)
+            _, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
            test_others = FP8Helper.update_collections(
                {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
            test_others = FP8Helper.update_fp8_metas(test_others)
@@ -251,7 +255,7 @@ class TestEncoderLayer:
            fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
            attn_name = 'attention'
-            unfreeze_test_wgrad = test_wgrad.unfreeze()
+            unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
            if "output_layernorm" not in attrs:
                unfreeze_test_wgrad['pre_attention_layer_norm'] = {}
            pre_attn_layer_key = 'qkv' if fuse_qkv else 'query'
@@ -283,9 +287,9 @@ class TestEncoderLayer:
                unfreeze_test_wgrad['mlp']['wo']['kernel'] = \
                    unfreeze_test_wgrad['mlp']['wo_kernel']
                del unfreeze_test_wgrad['mlp']['wo_kernel']
-            return flax.core.frozen_dict.FrozenDict(unfreeze_test_wgrad)
+            return unfreeze_test_wgrad

-        compare_frozen_dict(ref_grads[1],
+        compare_dict(ref_grads[1],
                            reorganize_test_wgrad(test_grads[1], attrs),
                            rtol=rtol,
                            atol=atol)    # wgrad
@@ -333,7 +337,7 @@ class TestDecoderLayer:
     def sync_params(ref, target, attrs):
         fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
-        unfreeze_target = target.unfreeze()
+        unfreeze_target = flax.core.unfreeze(target)
         if fuse_qkv:
             unfreeze_target['self_attention']['qkv']['kernel'] = \
                 jnp.reshape(ref['self_attention']['qkv']['kernel'],
@@ -354,7 +358,7 @@ class TestDecoderLayer:
             jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
         unfreeze_target['mlp']['wo_kernel'] = \
             ref['mlp']['wo']['kernel']
-        return ref, flax.core.frozen_dict.FrozenDict(unfreeze_target)
+        return ref, unfreeze_target

     def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
         transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
@@ -444,7 +448,7 @@ class TestDecoderLayer:
            _, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,),
                                             has_aux=False)(inputs, test_masks, test_params,
                                                            test_others, test_layer, apply_rng)
-            _, fp8_meta_grad = tmp_grad[0].pop(FP8Helper.FP8_COLLECTION_NAME)
+            _, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
            test_others = FP8Helper.update_collections(
                {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
            test_others = FP8Helper.update_fp8_metas(test_others)
@@ -464,7 +468,7 @@ class TestDecoderLayer:
            fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
            attn_name = 'self_attention'
-            unfreeze_test_wgrad = test_wgrad.unfreeze()
+            unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
            if "output_layernorm" not in attrs:
                unfreeze_test_wgrad['pre_self_attention_layer_norm'] = {}
            pre_attn_layer_key = 'qkv' if fuse_qkv else 'query'
@@ -510,9 +514,9 @@ class TestDecoderLayer:
                unfreeze_test_wgrad['mlp']['wo']['kernel'] = \
                    unfreeze_test_wgrad['mlp']['wo_kernel']
                del unfreeze_test_wgrad['mlp']['wo_kernel']
-            return flax.core.frozen_dict.FrozenDict(unfreeze_test_wgrad)
+            return unfreeze_test_wgrad

-        compare_frozen_dict(ref_grads[1],
+        compare_dict(ref_grads[1],
                            reorganize_test_wgrad(test_grads[1], attrs),
                            rtol=rtol,
                            atol=atol)    # wgrad
...
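A usage sketch for the reworked `compare_dict` helper above: because both arguments are passed through `flax.core.unfreeze` before the recursive checks, one assertion covers collections produced under Flax <= 0.7.0 (FrozenDict) and >= 0.7.1 (plain dict) alike. This assumes the `compare_dict` defined in the hunk is in scope:

```python
import flax
import numpy as np

# Old-style frozen collection vs. new-style plain dict with identical content.
ref = flax.core.freeze({'layer': {'kernel': np.ones((2, 2))}})
test = {'layer': {'kernel': np.ones((2, 2))}}

# Both sides are unfrozen internally, so the mixed container types compare
# cleanly; mismatched keys or values would raise an AssertionError.
compare_dict(ref, test, rtol=1e-5, atol=1e-8)
```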
@@ -5,6 +5,7 @@
 from functools import partial
 from typing import Dict

+import flax
 import jax
 import jax.numpy as jnp
 from praxis import pax_fiddle
@@ -40,7 +41,7 @@ FP8_FORMATS = [Format.E4M3, Format.HYBRID]
 def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
     for key in ref_fd:
         assert key in test_fd, \
-            f"{key} not found in test FrozenDict {test_fd}"
+            f"{key} not found in test dict {test_fd}"
         assert isinstance(test_fd[key], type(ref_fd[key])), \
             f"The data type is not match between ref and test " \
             f" Dict on {key=}"
@@ -89,7 +90,7 @@ class TestLayer:
         lyr_name = self.get_layer_name()
         synced_praxis_variables['params'][lyr_name]['cld'] = \
-            flax_variables['params'].unfreeze()
+            flax.core.unfreeze(flax_variables['params'])
         return synced_praxis_variables, flax_variables
@@ -105,7 +106,7 @@ class TestLayer:
         synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = \
             synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME][lyr_name]['cld']
-        return synced_praxis_grads, flax_wgrads.unfreeze()
+        return synced_praxis_grads, flax.core.unfreeze(flax_wgrads)

     def forward_backward_runner(self,
                                 data_shape,
@@ -127,9 +128,10 @@ class TestLayer:
         flax_layer = flax_cls()
         flax_variables = flax_layer.init(init_key, *test_inputs)
         if "params_axes" in flax_variables:
-            flax_variables, _ = flax_variables.pop("params_axes")
+            flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
         if FP8Helper.is_fp8_enabled():
-            flax_variables, _ = flax_variables.pop(FP8Helper.FP8_COLLECTION_NAME + "_axes")
+            flax_variables, _ = flax.core.pop(flax_variables,
+                                              FP8Helper.FP8_COLLECTION_NAME + "_axes")

         praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
@@ -143,7 +145,7 @@ class TestLayer:
         if FP8Helper.is_fp8_enabled():
             praxis_wgrads.pop('params')
             praxis_variables = update_collections(praxis_wgrads, praxis_variables)
-            flax_wgrads, _ = flax_wgrads.pop('params')
+            flax_wgrads, _ = flax.core.pop(flax_wgrads, 'params')
             flax_variables = update_collections(flax_wgrads, flax_variables)

         praxis_loss, praxis_wgrads, praxis_dgrad = \
@@ -642,9 +644,10 @@ class TestRelativePositionBias(TestLayer):
         flax_layer = flax_cls()
         flax_variables = flax_layer.init(init_key, *test_input)
         if "params_axes" in flax_variables:
-            flax_variables, _ = flax_variables.pop("params_axes")
+            flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
         if FP8Helper.is_fp8_enabled():
-            flax_variables, _ = flax_variables.pop(FP8Helper.FP8_COLLECTION_NAME + "_axes")
+            flax_variables, _ = flax.core.pop(flax_variables,
+                                              FP8Helper.FP8_COLLECTION_NAME + "_axes")

         praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
...