Unverified Commit 4478b044 authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Replace `int8_t` in Pybind11 extensions with `int64_t` (#882)



Replace int8_t in PyTorch extensions with int64_t
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent e9606077
......@@ -400,7 +400,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype,
const int8_t sm_margin,
const int64_t sm_margin,
const bool zero_centered_gamma) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
float eps_float = static_cast<float>(eps);
......@@ -424,7 +424,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
double eps,
const int8_t sm_margin,
const int64_t sm_margin,
const bool zero_centered_gamma) {
float eps_float = static_cast<float>(eps);
......@@ -447,7 +447,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype,
const int8_t sm_margin,
const int64_t sm_margin,
const bool zero_centered_gamma) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
float eps_float = static_cast<float>(eps);
......@@ -469,7 +469,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
double eps,
const int8_t sm_margin,
const int64_t sm_margin,
const bool zero_centered_gamma) {
float eps_float = static_cast<float>(eps);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment