"vscode:/vscode.git/clone" did not exist on "e4b7e5ee25e54939204ae72126260060e2c0784c"
Unverified commit 58f19082, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

remove d2d copies (#64)



* remove d2d copies
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c126396b
......@@ -30,7 +30,7 @@ at::Tensor cast_to_fp8_ts(const at::Tensor &input,
amax[0][fp8_tensor],
scale_inv[fp8_tensor],
otype_arg);
return output.clone();
return output;
}
at::Tensor cast_from_fp8_ts(const at::Tensor &input,
......@@ -44,7 +44,7 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input,
scale_inv[fp8_tensor],
itype_arg,
otype_arg);
return output.clone();
return output;
}
at::Tensor fp8_gelu_ts(at::Tensor input,
......@@ -59,7 +59,7 @@ at::Tensor fp8_gelu_ts(at::Tensor input,
amax[0][fp8_tensor],
scale_inv[fp8_tensor],
otype_arg);
return output.clone();
return output;
}
at::Tensor te_gemm_ts(at::Tensor A,
......@@ -92,20 +92,18 @@ at::Tensor te_gemm_ts(at::Tensor A,
bool accumulate_arg = static_cast<bool>(accumulate);
bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);
at::Tensor A_scale_inverse_arg = A_scale_inverse.clone();
if (A_scale_inverse.numel())
A_scale_inverse_arg = A_scale_inverse[A_fp8_tensor];
A_scale_inverse = A_scale_inverse[A_fp8_tensor];
at::Tensor B_scale_inverse_arg = B_scale_inverse.clone();
if (B_scale_inverse.numel())
B_scale_inverse_arg = B_scale_inverse[B_fp8_tensor];
B_scale_inverse = B_scale_inverse[B_fp8_tensor];
te_gemm(A,
A_scale_inverse_arg,
A_scale_inverse,
A_type_arg,
transa_arg,
B,
B_scale_inverse_arg,
B_scale_inverse,
B_type_arg,
transb_arg,
D,
......@@ -141,7 +139,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
scale_inv,
otype_arg);
return output.clone();
return output;
}
at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
......@@ -155,7 +153,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
bias,
eps_float);
return output.clone();
return output;
}
TORCH_LIBRARY(tex_ts, m) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment