Commit 58f19082 (unverified), authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

remove d2d copies (#64)



* remove d2d copies
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c126396b
...@@ -30,7 +30,7 @@ at::Tensor cast_to_fp8_ts(const at::Tensor &input, ...@@ -30,7 +30,7 @@ at::Tensor cast_to_fp8_ts(const at::Tensor &input,
amax[0][fp8_tensor], amax[0][fp8_tensor],
scale_inv[fp8_tensor], scale_inv[fp8_tensor],
otype_arg); otype_arg);
return output.clone(); return output;
} }
at::Tensor cast_from_fp8_ts(const at::Tensor &input, at::Tensor cast_from_fp8_ts(const at::Tensor &input,
...@@ -44,7 +44,7 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input, ...@@ -44,7 +44,7 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input,
scale_inv[fp8_tensor], scale_inv[fp8_tensor],
itype_arg, itype_arg,
otype_arg); otype_arg);
return output.clone(); return output;
} }
at::Tensor fp8_gelu_ts(at::Tensor input, at::Tensor fp8_gelu_ts(at::Tensor input,
...@@ -59,7 +59,7 @@ at::Tensor fp8_gelu_ts(at::Tensor input, ...@@ -59,7 +59,7 @@ at::Tensor fp8_gelu_ts(at::Tensor input,
amax[0][fp8_tensor], amax[0][fp8_tensor],
scale_inv[fp8_tensor], scale_inv[fp8_tensor],
otype_arg); otype_arg);
return output.clone(); return output;
} }
at::Tensor te_gemm_ts(at::Tensor A, at::Tensor te_gemm_ts(at::Tensor A,
...@@ -92,20 +92,18 @@ at::Tensor te_gemm_ts(at::Tensor A, ...@@ -92,20 +92,18 @@ at::Tensor te_gemm_ts(at::Tensor A,
bool accumulate_arg = static_cast<bool>(accumulate); bool accumulate_arg = static_cast<bool>(accumulate);
bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator); bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);
at::Tensor A_scale_inverse_arg = A_scale_inverse.clone();
if (A_scale_inverse.numel()) if (A_scale_inverse.numel())
A_scale_inverse_arg = A_scale_inverse[A_fp8_tensor]; A_scale_inverse = A_scale_inverse[A_fp8_tensor];
at::Tensor B_scale_inverse_arg = B_scale_inverse.clone();
if (B_scale_inverse.numel()) if (B_scale_inverse.numel())
B_scale_inverse_arg = B_scale_inverse[B_fp8_tensor]; B_scale_inverse = B_scale_inverse[B_fp8_tensor];
te_gemm(A, te_gemm(A,
A_scale_inverse_arg, A_scale_inverse,
A_type_arg, A_type_arg,
transa_arg, transa_arg,
B, B,
B_scale_inverse_arg, B_scale_inverse,
B_type_arg, B_type_arg,
transb_arg, transb_arg,
D, D,
...@@ -141,7 +139,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, ...@@ -141,7 +139,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
scale_inv, scale_inv,
otype_arg); otype_arg);
return output.clone(); return output;
} }
at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
...@@ -155,7 +153,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, ...@@ -155,7 +153,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
bias, bias,
eps_float); eps_float);
return output.clone(); return output;
} }
TORCH_LIBRARY(tex_ts, m) { TORCH_LIBRARY(tex_ts, m) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment