Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
2d73334d
Unverified
Commit
2d73334d
authored
Mar 10, 2023
by
Ming-Xu Huang
Committed by
GitHub
Mar 10, 2023
Browse files
Adding slice to fix failure with multi-devices. (#89)
Signed-off-by:
Ming-Xu Huang
<
mingh@nvidia.com
>
parent
bc9d57a3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
12 deletions
+12
-12
transformer_engine/jax/dot.py
transformer_engine/jax/dot.py
+3
-3
transformer_engine/jax/layernorm.py
transformer_engine/jax/layernorm.py
+3
-3
transformer_engine/jax/mlp.py
transformer_engine/jax/mlp.py
+6
-6
No files found.
transformer_engine/jax/dot.py
View file @
2d73334d
...
@@ -138,13 +138,13 @@ def _fp8_dot_fwd(
...
@@ -138,13 +138,13 @@ def _fp8_dot_fwd(
gemm_input_idx
,
gemm_kernel_idx
,
_
=
FP8Helper
.
get_fp8_meta_indices
(
0
)
gemm_input_idx
,
gemm_kernel_idx
,
_
=
FP8Helper
.
get_fp8_meta_indices
(
0
)
input_amax
=
amax
[
gemm_input_idx
]
input_amax
=
amax
[
gemm_input_idx
,
0
:
1
]
input_scale
=
scale
[
gemm_input_idx
]
input_scale
=
scale
[
gemm_input_idx
]
input_scale_inv
=
scale_inv
[
gemm_input_idx
]
input_scale_inv
=
scale_inv
[
gemm_input_idx
]
input_cast
,
input_cast_trans
,
input_amax
=
cast_transpose
(
inputs_
,
input_amax
,
input_scale
,
input_cast
,
input_cast_trans
,
input_amax
=
cast_transpose
(
inputs_
,
input_amax
,
input_scale
,
input_scale_inv
,
fwd_dtype
)
input_scale_inv
,
fwd_dtype
)
kernel_amax
=
amax
[
gemm_kernel_idx
]
kernel_amax
=
amax
[
gemm_kernel_idx
,
0
:
1
]
kernel_scale
=
scale
[
gemm_kernel_idx
]
kernel_scale
=
scale
[
gemm_kernel_idx
]
kernel_scale_inv
=
scale_inv
[
gemm_kernel_idx
]
kernel_scale_inv
=
scale_inv
[
gemm_kernel_idx
]
kernel_cast
,
kernel_cast_trans
,
kernel_amax
=
cast_transpose
(
kernel_
,
kernel_amax
,
kernel_scale
,
kernel_cast
,
kernel_cast_trans
,
kernel_amax
=
cast_transpose
(
kernel_
,
kernel_amax
,
kernel_scale
,
...
@@ -182,7 +182,7 @@ def _fp8_dot_bwd(
...
@@ -182,7 +182,7 @@ def _fp8_dot_bwd(
gemm_input_idx
,
gemm_kernel_idx
,
gemm_grad_idx
=
FP8Helper
.
get_fp8_meta_indices
(
0
)
gemm_input_idx
,
gemm_kernel_idx
,
gemm_grad_idx
=
FP8Helper
.
get_fp8_meta_indices
(
0
)
grad_amax
=
amax
[
gemm_grad_idx
]
grad_amax
=
amax
[
gemm_grad_idx
,
0
:
1
]
grad_scale
=
scale
[
gemm_grad_idx
]
grad_scale
=
scale
[
gemm_grad_idx
]
grad_scale_inv
=
scale_inv
[
gemm_grad_idx
]
grad_scale_inv
=
scale_inv
[
gemm_grad_idx
]
g
=
jnp
.
reshape
(
g
,
(
input_cast_trans
.
shape
[
1
],
-
1
))
g
=
jnp
.
reshape
(
g
,
(
input_cast_trans
.
shape
[
1
],
-
1
))
...
...
transformer_engine/jax/layernorm.py
View file @
2d73334d
...
@@ -285,7 +285,7 @@ def _layernorm_fp8_dot_fwd(
...
@@ -285,7 +285,7 @@ def _layernorm_fp8_dot_fwd(
gemm_input_idx
,
gemm_kernel_idx
,
_
=
FP8Helper
.
get_fp8_meta_indices
(
0
)
gemm_input_idx
,
gemm_kernel_idx
,
_
=
FP8Helper
.
get_fp8_meta_indices
(
0
)
input_amax
=
amax
[
gemm_input_idx
]
input_amax
=
amax
[
gemm_input_idx
,
0
:
1
]
input_scale
=
scale
[
gemm_input_idx
]
input_scale
=
scale
[
gemm_input_idx
]
input_scale_inv
=
scale_inv
[
gemm_input_idx
]
input_scale_inv
=
scale_inv
[
gemm_input_idx
]
if
layernorm_type
==
'layernorm'
:
if
layernorm_type
==
'layernorm'
:
...
@@ -309,7 +309,7 @@ def _layernorm_fp8_dot_fwd(
...
@@ -309,7 +309,7 @@ def _layernorm_fp8_dot_fwd(
ln_out_
=
jnp
.
reshape
(
ln_out
,
(
-
1
,
input_contracting_size
))
ln_out_
=
jnp
.
reshape
(
ln_out
,
(
-
1
,
input_contracting_size
))
kernel_
=
jnp
.
reshape
(
kernel
,
(
kernel_contracting_size
,
-
1
))
kernel_
=
jnp
.
reshape
(
kernel
,
(
kernel_contracting_size
,
-
1
))
kernel_amax
=
amax
[
gemm_kernel_idx
]
kernel_amax
=
amax
[
gemm_kernel_idx
,
0
:
1
]
kernel_scale
=
scale
[
gemm_kernel_idx
]
kernel_scale
=
scale
[
gemm_kernel_idx
]
kernel_scale_inv
=
scale_inv
[
gemm_kernel_idx
]
kernel_scale_inv
=
scale_inv
[
gemm_kernel_idx
]
kernel_cast
,
kernel_cast_trans
,
kernel_amax
=
cast_transpose
(
kernel_
,
kernel_amax
,
kernel_scale
,
kernel_cast
,
kernel_cast_trans
,
kernel_amax
=
cast_transpose
(
kernel_
,
kernel_amax
,
kernel_scale
,
...
@@ -352,7 +352,7 @@ def _layernorm_fp8_dot_bwd(
...
@@ -352,7 +352,7 @@ def _layernorm_fp8_dot_bwd(
gemm_input_idx
,
gemm_kernel_idx
,
gemm_grad_idx
=
\
gemm_input_idx
,
gemm_kernel_idx
,
gemm_grad_idx
=
\
FP8Helper
.
get_fp8_meta_indices
(
0
)
FP8Helper
.
get_fp8_meta_indices
(
0
)
grad_amax
=
amax
[
gemm_grad_idx
]
grad_amax
=
amax
[
gemm_grad_idx
,
0
:
1
]
grad_scale
=
scale
[
gemm_grad_idx
]
grad_scale
=
scale
[
gemm_grad_idx
]
grad_scale_inv
=
scale_inv
[
gemm_grad_idx
]
grad_scale_inv
=
scale_inv
[
gemm_grad_idx
]
...
...
transformer_engine/jax/mlp.py
View file @
2d73334d
...
@@ -266,7 +266,7 @@ def _fp8_mlp_fwd(
...
@@ -266,7 +266,7 @@ def _fp8_mlp_fwd(
gemm1_input_idx
,
gemm1_kernel_idx
,
_
=
FP8Helper
.
get_fp8_meta_indices
(
0
)
gemm1_input_idx
,
gemm1_kernel_idx
,
_
=
FP8Helper
.
get_fp8_meta_indices
(
0
)
input_amax
=
amax
[
gemm1_input_idx
]
input_amax
=
amax
[
gemm1_input_idx
,
0
:
1
]
input_scale
=
scale
[
gemm1_input_idx
]
input_scale
=
scale
[
gemm1_input_idx
]
input_scale_inv
=
scale_inv
[
gemm1_input_idx
]
input_scale_inv
=
scale_inv
[
gemm1_input_idx
]
if
layernorm_type
==
'layernorm'
:
if
layernorm_type
==
'layernorm'
:
...
@@ -286,7 +286,7 @@ def _fp8_mlp_fwd(
...
@@ -286,7 +286,7 @@ def _fp8_mlp_fwd(
epsilon
=
epsilon
)
epsilon
=
epsilon
)
mu
=
None
mu
=
None
kernel_1_amax
=
amax
[
gemm1_kernel_idx
]
kernel_1_amax
=
amax
[
gemm1_kernel_idx
,
0
:
1
]
kernel_1_scale
=
scale
[
gemm1_kernel_idx
]
kernel_1_scale
=
scale
[
gemm1_kernel_idx
]
kernel_1_scale_inv
=
scale_inv
[
gemm1_kernel_idx
]
kernel_1_scale_inv
=
scale_inv
[
gemm1_kernel_idx
]
kernel_1_cast
,
kernel_1_cast_trans
,
kernel_1_amax
=
cast_transpose
(
kernel_1_cast
,
kernel_1_cast_trans
,
kernel_1_amax
=
cast_transpose
(
...
@@ -297,13 +297,13 @@ def _fp8_mlp_fwd(
...
@@ -297,13 +297,13 @@ def _fp8_mlp_fwd(
gemm2_input_idx
,
gemm2_kernel_idx
,
_
=
FP8Helper
.
get_fp8_meta_indices
(
1
)
gemm2_input_idx
,
gemm2_kernel_idx
,
_
=
FP8Helper
.
get_fp8_meta_indices
(
1
)
kernel_2_amax
=
amax
[
gemm2_kernel_idx
]
kernel_2_amax
=
amax
[
gemm2_kernel_idx
,
0
:
1
]
kernel_2_scale
=
scale
[
gemm2_kernel_idx
]
kernel_2_scale
=
scale
[
gemm2_kernel_idx
]
kernel_2_scale_inv
=
scale_inv
[
gemm2_kernel_idx
]
kernel_2_scale_inv
=
scale_inv
[
gemm2_kernel_idx
]
kernel_2_cast
,
kernel_2_cast_trans
,
kernel_2_amax
=
cast_transpose
(
kernel_2_cast
,
kernel_2_cast_trans
,
kernel_2_amax
=
cast_transpose
(
kernel_2_
,
kernel_2_amax
,
kernel_2_scale
,
kernel_2_scale_inv
,
fwd_dtype
)
kernel_2_
,
kernel_2_amax
,
kernel_2_scale
,
kernel_2_scale_inv
,
fwd_dtype
)
dense_1_out_amax
=
amax
[
gemm2_input_idx
]
dense_1_out_amax
=
amax
[
gemm2_input_idx
,
0
:
1
]
dense_1_out_scale
=
scale
[
gemm2_input_idx
]
dense_1_out_scale
=
scale
[
gemm2_input_idx
]
dense_1_out_scale_inv
=
scale_inv
[
gemm2_input_idx
]
dense_1_out_scale_inv
=
scale_inv
[
gemm2_input_idx
]
gated_gelu_output_cast
,
gated_gelu_amax
=
gated_gelu_fp8
(
dense_1_output
,
dense_1_out_amax
,
gated_gelu_output_cast
,
gated_gelu_amax
=
gated_gelu_fp8
(
dense_1_output
,
dense_1_out_amax
,
...
@@ -354,7 +354,7 @@ def _fp8_mlp_bwd(
...
@@ -354,7 +354,7 @@ def _fp8_mlp_bwd(
gemm2_input_idx
,
gemm2_kernel_idx
,
gemm2_grad_idx
=
FP8Helper
.
get_fp8_meta_indices
(
1
)
gemm2_input_idx
,
gemm2_kernel_idx
,
gemm2_grad_idx
=
FP8Helper
.
get_fp8_meta_indices
(
1
)
grad_amax
=
amax
[
gemm2_grad_idx
]
grad_amax
=
amax
[
gemm2_grad_idx
,
0
:
1
]
grad_scale
=
scale
[
gemm2_grad_idx
]
grad_scale
=
scale
[
gemm2_grad_idx
]
grad_scale_inv
=
scale_inv
[
gemm2_grad_idx
]
grad_scale_inv
=
scale_inv
[
gemm2_grad_idx
]
...
@@ -372,7 +372,7 @@ def _fp8_mlp_bwd(
...
@@ -372,7 +372,7 @@ def _fp8_mlp_bwd(
gemm1_input_idx
,
gemm1_kernel_idx
,
gemm1_grad_idx
=
FP8Helper
.
get_fp8_meta_indices
(
0
)
gemm1_input_idx
,
gemm1_kernel_idx
,
gemm1_grad_idx
=
FP8Helper
.
get_fp8_meta_indices
(
0
)
dgrad_2_amax
=
amax
[
gemm1_grad_idx
]
dgrad_2_amax
=
amax
[
gemm1_grad_idx
,
0
:
1
]
dgrad_2_scale
=
scale
[
gemm1_grad_idx
]
dgrad_2_scale
=
scale
[
gemm1_grad_idx
]
dgrad_2_scale_inv
=
scale_inv
[
gemm1_grad_idx
]
dgrad_2_scale_inv
=
scale_inv
[
gemm1_grad_idx
]
dgelu
,
dgelu_trans
,
dgelu_amax
=
dgated_gelu_cast_transpose
(
dgrad_2
,
dense_1_output
,
dgelu
,
dgelu_trans
,
dgelu_amax
=
dgated_gelu_cast_transpose
(
dgrad_2
,
dense_1_output
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment