Unverified commit 4a1efe89, authored by Ming-Xu Huang and committed by GitHub
Browse files

[JAX] Fix the wrong shape of bias when fusing GEMMs. (#152)



* Allow update_collections and update_fp8_metas to return both Dict and FrozenDict.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the wrong bias shape when fusing QKV or KV.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Reuse the tuplized features when creating the bias.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Replace get_args with a more readable alternative.
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
parent a41bf711
...@@ -199,24 +199,32 @@ class FP8Helper: ...@@ -199,24 +199,32 @@ class FP8Helper:
""" """
Update the collections Update the collections
""" """
if not isinstance(original, FrozenDict): assert isinstance(original, (dict, FrozenDict))
original = FrozenDict(original) assert isinstance(new, (dict, FrozenDict))
frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original
for key in new: for key in new:
if key in original: if key in frozen_original:
original, _ = original.pop(key) frozen_original, _ = frozen_original.pop(key)
return FrozenDict({**new, **original}) new_coll = FrozenDict({**new, **frozen_original})
if not isinstance(original, FrozenDict):
new_coll = new_coll.unfreeze()
return new_coll
@staticmethod
def update_fp8_metas(state: Collection) -> Collection:
    """
    Update the FP8 metas stored in a variable collection.

    Runs ``FP8Helper._update_fp8_metas_impl`` on the sub-collection stored
    under ``FP8Helper.FP8_COLLECTION_NAME`` and returns a new collection
    with the refreshed metas spliced back in.  The return value mirrors the
    container type of ``state``: a ``FrozenDict`` input yields a
    ``FrozenDict``, a plain ``dict`` input yields a plain ``dict``.

    Args:
        state: a ``dict`` or ``FrozenDict`` variable collection.

    Returns:
        The collection with updated FP8 metas, or ``state`` unchanged when
        it holds no FP8 collection.

    Raises:
        TypeError: if ``state`` is neither a ``dict`` nor a ``FrozenDict``.
    """
    # Validate with an explicit raise instead of `assert`, which is
    # silently stripped when Python runs with -O.
    if not isinstance(state, (dict, FrozenDict)):
        raise TypeError(f"state must be a dict or FrozenDict, got {type(state)}")

    # Guard clause: nothing to update when no FP8 collection is present.
    if FP8Helper.FP8_COLLECTION_NAME not in state:
        return state

    # Work on a frozen view so `.pop` returns (remainder, value) without
    # mutating the caller's collection.
    frozen_state = state if isinstance(state, FrozenDict) else FrozenDict(state)
    others, fp8_metas = frozen_state.pop(FP8Helper.FP8_COLLECTION_NAME)
    fp8_metas = FP8Helper._update_fp8_metas_impl(fp8_metas)
    new_state = FrozenDict({**others, FP8Helper.FP8_COLLECTION_NAME: fp8_metas})

    # Preserve the input container type for the caller.
    if not isinstance(state, FrozenDict):
        new_state = new_state.unfreeze()
    return new_state
@staticmethod @staticmethod
......
...@@ -425,7 +425,8 @@ class DenseGeneral(TransformerEngineBase): ...@@ -425,7 +425,8 @@ class DenseGeneral(TransformerEngineBase):
if self.use_bias: if self.use_bias:
bias = nn_partitioning.param_with_axes('bias', bias = nn_partitioning.param_with_axes('bias',
self.bias_init, (self.features,), self.bias_init,
features,
self.dtype, self.dtype,
axes=self.bias_axes) axes=self.bias_axes)
else: else:
...@@ -446,7 +447,8 @@ class DenseGeneral(TransformerEngineBase): ...@@ -446,7 +447,8 @@ class DenseGeneral(TransformerEngineBase):
y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
if bias is not None: if bias is not None:
y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) bais_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
y += jnp.reshape(bias, bais_shape)
return y return y
...@@ -651,12 +653,14 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -651,12 +653,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
bias = None bias = None
if self.use_bias: if self.use_bias:
bias = nn_partitioning.param_with_axes('bias', bias = nn_partitioning.param_with_axes('bias',
self.bias_init, (self.features,), self.bias_init,
features,
self.dtype, self.dtype,
axes=self.bias_axes) axes=self.bias_axes)
if bias is not None: if bias is not None:
z += jnp.reshape(bias, (1,) * (z.ndim - 1) + (-1,)) bais_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
z += jnp.reshape(bias, bais_shape)
if self.depth_scaling is not None: if self.depth_scaling is not None:
z = z / self.depth_scaling z = z / self.depth_scaling
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment