Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
a7bc7cf7
Unverified
Commit
a7bc7cf7
authored
Jul 07, 2023
by
Ming-Xu Huang
Committed by
GitHub
Jul 06, 2023
Browse files
[JAX] Support arbitrary dimensinos of fp8 meta. (#309)
Signed-off-by:
Ming Huang
<
mingh@nvidia.com
>
parent
a7a1a070
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
16 deletions
+17
-16
tests/jax/test_helper.py
tests/jax/test_helper.py
+11
-10
transformer_engine/jax/fp8.py
transformer_engine/jax/fp8.py
+6
-6
No files found.
tests/jax/test_helper.py
View file @
a7bc7cf7
...
@@ -64,7 +64,7 @@ class TestFP8Helper(unittest.TestCase):
...
@@ -64,7 +64,7 @@ class TestFP8Helper(unittest.TestCase):
def
select_amax
(
amaxes
):
def
select_amax
(
amaxes
):
if
FP8Helper
.
AMAX_COMPUTE_ALGO
==
AmaxComputeAlgo
.
MAX
:
if
FP8Helper
.
AMAX_COMPUTE_ALGO
==
AmaxComputeAlgo
.
MAX
:
return
jnp
.
max
(
amaxes
,
axis
=
1
,
keepdims
=
True
)
return
jnp
.
max
(
amaxes
,
axis
=
-
1
,
keepdims
=
True
)
return
amaxes
[:,
0
:
1
]
return
amaxes
[:,
0
:
1
]
def
get_fp8_scale
(
fp8_max
,
amax
,
scale
):
def
get_fp8_scale
(
fp8_max
,
amax
,
scale
):
...
@@ -78,15 +78,16 @@ class TestFP8Helper(unittest.TestCase):
...
@@ -78,15 +78,16 @@ class TestFP8Helper(unittest.TestCase):
sf
=
np
.
where
(
np
.
isfinite
(
amax
),
sf
,
scale
)
sf
=
np
.
where
(
np
.
isfinite
(
amax
),
sf
,
scale
)
return
np
.
where
(
exp
<
0
,
1
/
sf
,
sf
)
return
np
.
where
(
exp
<
0
,
1
/
sf
,
sf
)
meta_shape
=
(
num_of_meta
,
FP8Helper
.
AMAX_HISTORY_LEN
)
amax_meta_shape
=
(
num_of_meta
,
FP8Helper
.
AMAX_HISTORY_LEN
)
scale_meta_shape
=
(
num_of_meta
,
1
)
fp8_max_array
=
FP8Helper
.
generate_fp8_max_array
(
num_of_meta
)
fp8_max_array
=
FP8Helper
.
generate_fp8_max_array
(
num_of_meta
)
fp8_amax_array1
=
jax
.
random
.
uniform
(
key1
,
shape
=
meta_shape
)
fp8_amax_array1
=
jax
.
random
.
uniform
(
key1
,
shape
=
amax_
meta_shape
)
fp8_scale_array1
=
get_fp8_scale
(
fp8_max_array
,
select_amax
(
fp8_amax_array1
),
fp8_scale_array1
=
get_fp8_scale
(
fp8_max_array
,
select_amax
(
fp8_amax_array1
),
jnp
.
ones
(
meta_shape
))
jnp
.
ones
(
scale_
meta_shape
))
fp8_scale_inv_array1
=
1
/
fp8_scale_array1
fp8_scale_inv_array1
=
1
/
fp8_scale_array1
fp8_amax_array2
=
jax
.
random
.
uniform
(
key2
,
shape
=
meta_shape
)
fp8_amax_array2
=
jax
.
random
.
uniform
(
key2
,
shape
=
amax_
meta_shape
)
fp8_scale_array2
=
get_fp8_scale
(
fp8_max_array
,
select_amax
(
fp8_amax_array2
),
fp8_scale_array2
=
get_fp8_scale
(
fp8_max_array
,
select_amax
(
fp8_amax_array2
),
jnp
.
ones
(
meta_shape
))
jnp
.
ones
(
scale_
meta_shape
))
fp8_scale_inv_array2
=
1
/
fp8_scale_array2
fp8_scale_inv_array2
=
1
/
fp8_scale_array2
state
=
flax
.
core
.
frozen_dict
.
FrozenDict
({
state
=
flax
.
core
.
frozen_dict
.
FrozenDict
({
...
@@ -94,14 +95,14 @@ class TestFP8Helper(unittest.TestCase):
...
@@ -94,14 +95,14 @@ class TestFP8Helper(unittest.TestCase):
"test_update_fp8_metas1"
:
{
"test_update_fp8_metas1"
:
{
FP8Helper
.
FP8_MAX_NAME
:
fp8_max_array
,
FP8Helper
.
FP8_MAX_NAME
:
fp8_max_array
,
FP8Helper
.
FP8_AMAX_NAME
:
fp8_amax_array1
,
FP8Helper
.
FP8_AMAX_NAME
:
fp8_amax_array1
,
FP8Helper
.
FP8_SCALE_NAME
:
jnp
.
ones
(
meta_shape
),
FP8Helper
.
FP8_SCALE_NAME
:
jnp
.
ones
(
scale_
meta_shape
),
FP8Helper
.
FP8_SCALE_INV_NAME
:
jnp
.
ones
(
meta_shape
)
FP8Helper
.
FP8_SCALE_INV_NAME
:
jnp
.
ones
(
scale_
meta_shape
)
},
},
"test_update_fp8_metas2"
:
{
"test_update_fp8_metas2"
:
{
FP8Helper
.
FP8_MAX_NAME
:
fp8_max_array
,
FP8Helper
.
FP8_MAX_NAME
:
fp8_max_array
,
FP8Helper
.
FP8_AMAX_NAME
:
fp8_amax_array2
,
FP8Helper
.
FP8_AMAX_NAME
:
fp8_amax_array2
,
FP8Helper
.
FP8_SCALE_NAME
:
jnp
.
ones
(
meta_shape
),
FP8Helper
.
FP8_SCALE_NAME
:
jnp
.
ones
(
scale_
meta_shape
),
FP8Helper
.
FP8_SCALE_INV_NAME
:
jnp
.
ones
(
meta_shape
)
FP8Helper
.
FP8_SCALE_INV_NAME
:
jnp
.
ones
(
scale_
meta_shape
)
}
}
}
}
})
})
...
...
transformer_engine/jax/fp8.py
View file @
a7bc7cf7
...
@@ -305,9 +305,9 @@ class FP8Helper:
...
@@ -305,9 +305,9 @@ class FP8Helper:
fp8_max
=
fp8_meta_arrays
[
fp8_max_idx
]
fp8_max
=
fp8_meta_arrays
[
fp8_max_idx
]
if
FP8Helper
.
AMAX_COMPUTE_ALGO
is
AmaxComputeAlgo
.
MAX
:
if
FP8Helper
.
AMAX_COMPUTE_ALGO
is
AmaxComputeAlgo
.
MAX
:
amax
=
jnp
.
max
(
fp8_meta_arrays
[
fp8_amax_idx
],
axis
=
1
,
keepdims
=
True
)
amax
=
jnp
.
max
(
fp8_meta_arrays
[
fp8_amax_idx
],
axis
=
-
1
,
keepdims
=
True
)
else
:
else
:
amax
=
fp8_meta_arrays
[
fp8_amax_idx
][
:
,
0
:
1
]
amax
=
fp8_meta_arrays
[
fp8_amax_idx
][
...
,
0
:
1
]
scale
=
fp8_meta_arrays
[
fp8_scale_idx
]
scale
=
fp8_meta_arrays
[
fp8_scale_idx
]
exp
=
jnp
.
floor
(
jnp
.
log2
(
fp8_max
/
amax
))
-
FP8Helper
.
MARGIN
exp
=
jnp
.
floor
(
jnp
.
log2
(
fp8_max
/
amax
))
-
FP8Helper
.
MARGIN
...
@@ -366,14 +366,14 @@ def fp8_autocast(enabled: bool = False,
...
@@ -366,14 +366,14 @@ def fp8_autocast(enabled: bool = False,
if
fp8_recipe
is
None
:
if
fp8_recipe
is
None
:
fp8_recipe
=
DelayedScaling
()
fp8_recipe
=
DelayedScaling
()
assert
fp8_recipe
.
amax_compute_algo
in
[
"max"
,
"most_recent"
],
(
assert
fp8_recipe
.
amax_compute_algo
in
[
"DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX."
)
"max"
,
"most_recent"
],
(
"DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX."
)
assert
fp8_recipe
.
scaling_factor_compute_algo
is
None
,
(
assert
fp8_recipe
.
scaling_factor_compute_algo
is
None
,
(
"DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX."
)
"DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX."
)
assert
fp8_recipe
.
override_linear_precision
==
(
False
,
False
,
False
),
(
assert
fp8_recipe
.
override_linear_precision
==
(
False
,
False
,
False
),
(
"DelayedScaling override_linear_precision isn't supported by TE/JAX."
)
"DelayedScaling override_linear_precision isn't supported by TE/JAX."
)
assert
fp8_recipe
.
reduce_amax
,
(
assert
fp8_recipe
.
reduce_amax
,
(
"DelayedScaling reduce_amax should be enabled for TE/JAX."
)
"DelayedScaling reduce_amax should be enabled for TE/JAX."
)
if
sharding_resource
is
None
:
if
sharding_resource
is
None
:
sharding_resource
=
ShardingResource
()
sharding_resource
=
ShardingResource
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment