Unverified Commit 2d73334d authored by Ming-Xu Huang, committed by GitHub
Browse files

Adding slice to fix failure with multi-devices. (#89)


Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
parent bc9d57a3
......@@ -138,13 +138,13 @@ def _fp8_dot_fwd(
gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
input_amax = amax[gemm_input_idx]
input_amax = amax[gemm_input_idx, 0:1]
input_scale = scale[gemm_input_idx]
input_scale_inv = scale_inv[gemm_input_idx]
input_cast, input_cast_trans, input_amax = cast_transpose(inputs_, input_amax, input_scale,
input_scale_inv, fwd_dtype)
kernel_amax = amax[gemm_kernel_idx]
kernel_amax = amax[gemm_kernel_idx, 0:1]
kernel_scale = scale[gemm_kernel_idx]
kernel_scale_inv = scale_inv[gemm_kernel_idx]
kernel_cast, kernel_cast_trans, kernel_amax = cast_transpose(kernel_, kernel_amax, kernel_scale,
......@@ -182,7 +182,7 @@ def _fp8_dot_bwd(
gemm_input_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
grad_amax = amax[gemm_grad_idx]
grad_amax = amax[gemm_grad_idx, 0:1]
grad_scale = scale[gemm_grad_idx]
grad_scale_inv = scale_inv[gemm_grad_idx]
g = jnp.reshape(g, (input_cast_trans.shape[1], -1))
......
......@@ -285,7 +285,7 @@ def _layernorm_fp8_dot_fwd(
gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
input_amax = amax[gemm_input_idx]
input_amax = amax[gemm_input_idx, 0:1]
input_scale = scale[gemm_input_idx]
input_scale_inv = scale_inv[gemm_input_idx]
if layernorm_type == 'layernorm':
......@@ -309,7 +309,7 @@ def _layernorm_fp8_dot_fwd(
ln_out_ = jnp.reshape(ln_out, (-1, input_contracting_size))
kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1))
kernel_amax = amax[gemm_kernel_idx]
kernel_amax = amax[gemm_kernel_idx, 0:1]
kernel_scale = scale[gemm_kernel_idx]
kernel_scale_inv = scale_inv[gemm_kernel_idx]
kernel_cast, kernel_cast_trans, kernel_amax = cast_transpose(kernel_, kernel_amax, kernel_scale,
......@@ -352,7 +352,7 @@ def _layernorm_fp8_dot_bwd(
gemm_input_idx, gemm_kernel_idx, gemm_grad_idx = \
FP8Helper.get_fp8_meta_indices(0)
grad_amax = amax[gemm_grad_idx]
grad_amax = amax[gemm_grad_idx, 0:1]
grad_scale = scale[gemm_grad_idx]
grad_scale_inv = scale_inv[gemm_grad_idx]
......
......@@ -266,7 +266,7 @@ def _fp8_mlp_fwd(
gemm1_input_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
input_amax = amax[gemm1_input_idx]
input_amax = amax[gemm1_input_idx, 0:1]
input_scale = scale[gemm1_input_idx]
input_scale_inv = scale_inv[gemm1_input_idx]
if layernorm_type == 'layernorm':
......@@ -286,7 +286,7 @@ def _fp8_mlp_fwd(
epsilon=epsilon)
mu = None
kernel_1_amax = amax[gemm1_kernel_idx]
kernel_1_amax = amax[gemm1_kernel_idx, 0:1]
kernel_1_scale = scale[gemm1_kernel_idx]
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
kernel_1_cast, kernel_1_cast_trans, kernel_1_amax = cast_transpose(
......@@ -297,13 +297,13 @@ def _fp8_mlp_fwd(
gemm2_input_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
kernel_2_amax = amax[gemm2_kernel_idx]
kernel_2_amax = amax[gemm2_kernel_idx, 0:1]
kernel_2_scale = scale[gemm2_kernel_idx]
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
kernel_2_cast, kernel_2_cast_trans, kernel_2_amax = cast_transpose(
kernel_2_, kernel_2_amax, kernel_2_scale, kernel_2_scale_inv, fwd_dtype)
dense_1_out_amax = amax[gemm2_input_idx]
dense_1_out_amax = amax[gemm2_input_idx, 0:1]
dense_1_out_scale = scale[gemm2_input_idx]
dense_1_out_scale_inv = scale_inv[gemm2_input_idx]
gated_gelu_output_cast, gated_gelu_amax = gated_gelu_fp8(dense_1_output, dense_1_out_amax,
......@@ -354,7 +354,7 @@ def _fp8_mlp_bwd(
gemm2_input_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)
grad_amax = amax[gemm2_grad_idx]
grad_amax = amax[gemm2_grad_idx, 0:1]
grad_scale = scale[gemm2_grad_idx]
grad_scale_inv = scale_inv[gemm2_grad_idx]
......@@ -372,7 +372,7 @@ def _fp8_mlp_bwd(
gemm1_input_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)
dgrad_2_amax = amax[gemm1_grad_idx]
dgrad_2_amax = amax[gemm1_grad_idx, 0:1]
dgrad_2_scale = scale[gemm1_grad_idx]
dgrad_2_scale_inv = scale_inv[gemm1_grad_idx]
dgelu, dgelu_trans, dgelu_amax = dgated_gelu_cast_transpose(dgrad_2, dense_1_output,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment